ChatBot Training Tour
This article records my process of training an AI chatbot.
weibo.pth: 4435959 (41X)
Corpus.pth: 106478 (batch_size = 2048)
Summary
| Version | Epochs | Loss | Parameters |
| --- | --- | --- | --- |
| V14-4 | 68 | 4.748 | Same as before |
| V14-3 | 32 | 4.753 | Same as before |
| V14-2 | 92 | 4.744 | Same as before |
| V14-1 | 13 | 4.796 | Same as before |
| V14 | 42 | 4.773 | batch_size = 13750 |
| V13-9 | 18 | 4.829 | Same as before |
| V13-8 | 21 | 4.823 | Same as before |
| V13-7 | 18 | 4.837 | Same as before |
| V13-6 | 17 | 4.860 | Same as before |
| V13-5 | 16 | 4.835 | Same as before |
| V13-4 | 20 | 4.866 | Same as before |
| V13-3 | 6 | 4.889 | Same as before |
| V13-2 | 16 | 4.881 | Same as before |
| V13-1 | 10 | 4.874 | Same as before |
| V13 | 20 | 4.876 | batch_size = 11000 |
| V12-2 | 21 | 4.938 | Same as before |
| V12-1 | 19 | 4.914 | Same as before |
| V12 | 17 | 4.970 | batch_size = 10000, learning_rate = 0.0001 |
| V11 | 6 | 4.965 | batch_size = 12000 |
| V10 | 14 | 4.976 | batch_size = 10000 |
| V9 | 7 | 5.059 | batch_size = 5000, learning_rate = 0.001 |
| V8 | 10 | 5.092 | batch_size = 4096, learning_rate = 0.001 |
| V7 | 11 | 5.181 | batch_size = 2048, learning_rate = 0.0001 |
| V6 | 7 | 5.198 | batch_size = 2048, learning_rate = 0.0005 |
| V5 | 22 | 5.226 | batch_size = 2048, learning_rate = 0.001 |
| V4 | 8 | 5.917 | dropout = 0.2, teacher_forcing_ratio = 0.85, batch_size = 2048, learning_rate = 0.0001 |
| V3 | 14 | 5.966 | dropout = 0.2, clip = 60.0, teacher_forcing_ratio = 0.9, batch_size = 2048, learning_rate = 0.0001 |
| V2 | 44 | 5.055 | Origin |
2024.04.23
Config.py: https://github.com/OnlyourMiracle/Chinese-Chatbot-PyTorch-Implementation/blob/master/config.py
(SHA: d89c96166e3ca903aed3bd960c7d8dab8fdaa5f0) Training platform: Baidu PaddlePaddle
Epoch: 68
Average loss: 4.748495200444133
2024.04.23
Config.py: https://github.com/OnlyourMiracle/Chinese-Chatbot-PyTorch-Implementation/blob/master/config.py
(SHA: d89c96166e3ca903aed3bd960c7d8dab8fdaa5f0) Training platform: Baidu PaddlePaddle
Epoch: 32
Average loss: 4.753096600420531
2024.04.22
Config.py: https://github.com/OnlyourMiracle/Chinese-Chatbot-PyTorch-Implementation/blob/master/config.py
(SHA: d89c96166e3ca903aed3bd960c7d8dab8fdaa5f0) Training platform: Baidu PaddlePaddle
Epoch: 92
Average loss: 4.743738563202844
2024.04.21
Config.py: https://github.com/OnlyourMiracle/Chinese-Chatbot-PyTorch-Implementation/blob/master/config.py
(SHA: d89c96166e3ca903aed3bd960c7d8dab8fdaa5f0) Training platform: Baidu PaddlePaddle
Epoch: 13
Average loss: 4.796416420097512
Config.py: https://github.com/OnlyourMiracle/Chinese-Chatbot-PyTorch-Implementation/blob/master/config.py
(SHA: d89c96166e3ca903aed3bd960c7d8dab8fdaa5f0) Training platform: Baidu PaddlePaddle
Epoch: 42
Average loss: 4.772935219662937
2024.04.20
Config.py: https://github.com/OnlyourMiracle/Chinese-Chatbot-PyTorch-Implementation/blob/master/config.py
(SHA: 87160e2723ede287d60a68b82700e63bdc1a5c2b) Training platform: Baidu PaddlePaddle
Epoch: 18
Average loss: 4.828708153485162
2024.04.19
Config.py: https://github.com/OnlyourMiracle/Chinese-Chatbot-PyTorch-Implementation/blob/master/config.py
(SHA: 87160e2723ede287d60a68b82700e63bdc1a5c2b) Training platform: Baidu PaddlePaddle
Epoch: 21
Average loss: 4.822652317165321
Config.py: https://github.com/OnlyourMiracle/Chinese-Chatbot-PyTorch-Implementation/blob/master/config.py
(SHA: 87160e2723ede287d60a68b82700e63bdc1a5c2b) Training platform: Baidu PaddlePaddle
Epoch: 18
Average loss: 4.837234509962312
2024.04.16
Config.py: https://github.com/OnlyourMiracle/Chinese-Chatbot-PyTorch-Implementation/blob/master/config.py
(SHA: 87160e2723ede287d60a68b82700e63bdc1a5c2b) Training platform: Baidu PaddlePaddle
Epoch: 17
Average loss: 4.860292780773252
2024.04.15
Config.py: https://github.com/OnlyourMiracle/Chinese-Chatbot-PyTorch-Implementation/blob/master/config.py
(SHA: 87160e2723ede287d60a68b82700e63bdc1a5c2b) Training platform: Baidu PaddlePaddle
Epoch: 16
Average loss: 4.83524939768576
2024.04.14
Config.py: https://github.com/OnlyourMiracle/Chinese-Chatbot-PyTorch-Implementation/blob/master/config.py
(SHA: 87160e2723ede287d60a68b82700e63bdc1a5c2b) Training platform: Baidu PaddlePaddle
Epoch: 20
Average loss: 4.86629382181837
2024.04.07
Config.py: https://github.com/OnlyourMiracle/Chinese-Chatbot-PyTorch-Implementation/blob/master/config.py
(SHA: 87160e2723ede287d60a68b82700e63bdc1a5c2b) Training platform: Baidu PaddlePaddle
Epoch: 6
Average loss: 4.889181085683352
2024.04.07
Config.py: https://github.com/OnlyourMiracle/Chinese-Chatbot-PyTorch-Implementation/blob/master/config.py
(SHA: 87160e2723ede287d60a68b82700e63bdc1a5c2b) Training platform: Baidu PaddlePaddle
Epoch: 16
Average loss: 4.881189099620858
2024.04.07
Config.py: https://github.com/OnlyourMiracle/Chinese-Chatbot-PyTorch-Implementation/blob/master/config.py
(SHA: 87160e2723ede287d60a68b82700e63bdc1a5c2b) Training platform: Baidu PaddlePaddle
Epoch: 10
Average loss: 4.8736686256889294
2024.04.04
Config.py: https://github.com/OnlyourMiracle/Chinese-Chatbot-PyTorch-Implementation/blob/master/config.py
(SHA: 87160e2723ede287d60a68b82700e63bdc1a5c2b) Training platform: Baidu PaddlePaddle
Epoch: 20
Average loss: 4.876422007442719
2024.03.29
Config.py: https://github.com/OnlyourMiracle/Chinese-Chatbot-PyTorch-Implementation/blob/master/config.py
(SHA: 8e5ffe59fa69e32fb938f1883db340d87b261f7a) Training platform: Baidu PaddlePaddle
Epoch: 21
Average loss: 4.9384859099953164
2024.03.28
Config.py: https://github.com/OnlyourMiracle/Chinese-Chatbot-PyTorch-Implementation/blob/master/config.py
(SHA: 8e5ffe59fa69e32fb938f1883db340d87b261f7a) Training platform: Baidu PaddlePaddle
Epoch: 19
Average loss: 4.914344686694272
2024.03.27
Config.py: https://github.com/OnlyourMiracle/Chinese-Chatbot-PyTorch-Implementation/blob/master/config.py
(SHA: 8e5ffe59fa69e32fb938f1883db340d87b261f7a) Training platform: Baidu PaddlePaddle
Epoch: 17
Average loss: 4.969914259826938
2024.03.26
Config.py: https://github.com/OnlyourMiracle/Chinese-Chatbot-PyTorch-Implementation/blob/master/config.py
(SHA: cb49184c8e7092ca1edfee4e3e8275607b975b07) Training platform: vc
Epoch: 6
Average loss: 4.96543633802544
Config.py: https://github.com/OnlyourMiracle/Chinese-Chatbot-PyTorch-Implementation/blob/master/config.py
(SHA: faa387b6875bbcdf8e3adc250684ae55a2de3812) Training platform: Baidu PaddlePaddle
Epoch: 14
Average loss: 4.97629819865179
Config.py: https://github.com/OnlyourMiracle/Chinese-Chatbot-PyTorch-Implementation/blob/master/config.py
(SHA: ee22690c8e35bce2f99fb3f4b190908187256fbe)
Training platform: Baidu PaddlePaddle
Epoch: 7
Average loss: 5.059292034180287
2024.03.25
Config.py: https://github.com/OnlyourMiracle/Chinese-Chatbot-PyTorch-Implementation/blob/master/config.py
(SHA: 88a3fe81e66cba7f529182d17ef886b444720542)
Training platform: Baidu PaddlePaddle
Epoch: 10
Average loss: 5.092367821139052
Config.py: https://github.com/OnlyourMiracle/Chinese-Chatbot-PyTorch-Implementation/blob/master/config.py
(SHA: 1ce617ee78368acfaa489b5dd1c8021d2f506e27)
Training platform: Baidu PaddlePaddle
Epoch: 11
Average loss: 5.18057608768361
Config.py: https://github.com/OnlyourMiracle/Chinese-Chatbot-PyTorch-Implementation/blob/master/config.py
(SHA: 6bba047016a0f05a2fe1e3dda8ae4389f018c9ca)
Training platform: Baidu PaddlePaddle
Epoch: 7
Average loss: 5.197861174456669
2024.03.24
Config.py: https://github.com/OnlyourMiracle/Chinese-Chatbot-PyTorch-Implementation/blob/master/config.py
(SHA: 6435d7d7bb7bf52b4dc6e1fe46c76f55a63d8c98)
Training platform: Baidu PaddlePaddle
Epoch: 22
Average loss: 5.2257363110720165
The loss curve you’ve provided suggests that the loss has decreased significantly and then plateaued, which is common in training deep learning models as they begin to converge. When the loss curve plateaus, it often means that the model has reached a point where it’s not learning much more from the data given the current settings.
Here are some strategies you could consider to further decrease the loss:
Learning Rate Adjustment: If you’re using a constant learning rate, you might want to switch to a learning rate schedule that decreases the learning rate as you continue to train. If you’re already using a learning rate scheduler, consider adjusting it to decrease the learning rate more slowly or to use a learning rate warm-up period.
Early Stopping: If the loss isn’t improving for a significant number of epochs, it might be time to stop training to prevent overfitting.
Regularization Techniques: Implement or increase regularization to prevent overfitting if you suspect that’s an issue. This can include techniques like dropout, L2 regularization, or using a more robust optimizer like AdamW which incorporates weight decay.
Model Architecture: Consider whether your model might be too simple or too complex for the problem. Adding layers or units can increase model capacity, while reducing them can help the model generalize better if it’s overfitting.
Data Augmentation or Cleaning: More or cleaner data can help improve model performance. If you’re already using all your data, consider data augmentation techniques. Sometimes, reviewing and cleaning the dataset can also yield improvements, especially if there are noisy labels or irrelevant features.
Batch Size: Adjust the batch size if possible. A smaller batch size can provide a regularizing effect and larger batch sizes can provide more stable gradient estimates.
Gradient Clipping: This prevents the exploding gradient problem by capping the gradients during backpropagation to avoid very large updates to the weights.
Different Optimizer: If you’re not already, consider using an optimizer like Adam, which adapts the learning rate for each parameter and is generally good for large datasets and complex models.
Loss Function: Consider if the loss function you’re using is appropriate for your data and problem. Sometimes, tweaking the loss function or using a custom one can yield better results.
Hyperparameter Tuning: Use techniques like grid search, random search, or Bayesian optimization to systematically explore different hyperparameter settings.
To apply these strategies, you would need to adjust the parameters in your training code (train_eval.py, config.py, etc.). Since I can't inspect the code within those files directly, I would recommend looking for the following in your code (a rough sketch of the learning-rate and gradient-clipping adjustments follows the list):
- Where the learning rate is defined and how it’s updated during training.
- Where regularization settings are configured.
- Where the model’s architecture is defined.
- How data is preprocessed and augmented.
- The batch size settings.
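For concreteness, here is a minimal PyTorch sketch of the first and seventh strategies: a plateau-based learning-rate schedule combined with gradient-norm clipping. It is a generic illustration under assumed names; model, train_loader, and the hyperparameter values are placeholders, not the actual code in train_eval.py or config.py.

```python
# Generic illustration only: names and values are placeholders, not the repo's code.
from torch import nn, optim

def train(model, train_loader, num_epochs=50, lr=1e-4, clip=50.0, device="cuda"):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    # Strategy 1: halve the learning rate when the epoch-average loss stops improving.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=3)

    model.to(device)
    for epoch in range(num_epochs):
        total_loss = 0.0
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            loss = criterion(model(inputs), targets)
            loss.backward()
            # Strategy 7: cap the gradient norm so one bad batch cannot blow up the weights.
            nn.utils.clip_grad_norm_(model.parameters(), clip)
            optimizer.step()
            total_loss += loss.item()
        avg_loss = total_loss / len(train_loader)
        scheduler.step(avg_loss)  # feed the epoch-average loss to the plateau scheduler
        print(f"Epoch {epoch}: average loss {avg_loss:.4f}")
```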
Config.py: https://github.com/OnlyourMiracle/Chinese-Chatbot-PyTorch-Implementation/blob/master/config.py
(SHA: 23108d169e8bcb8e777fdc4a21e6165d82dbb0a8)
Training platform: Baidu PaddlePaddle
Epoch: 306
Average loss: 1.412546370036195
Some sample dialogues produced with the chatbot_0324_1543 model
Config.py: https://github.com/OnlyourMiracle/ChatBot/blob/main/ChatBotV4/config.py
(SHA: 921086296f841570df5b210c57e2054232e0c874)
Training platform: Baidu PaddlePaddle
Epoch: 8
Average loss: 5.917065527910856
2024.03.22
Config.py: https://github.com/OnlyourMiracle/ChatBot/blob/main/ChatBotV4/config.py
(SHA: e747321aa47db35adfcccfcb2f5b24788c13b1ca) Training platform: Baidu PaddlePaddle
Epoch: 14
Average loss: 5.96649232743882
2024.03.21
Github: https://github.com/OnlyourMiracle/ChatBot/tree/main/ChatBotV3
Latest ChatBot training code (V3)
Training platform: Baidu PaddlePaddle
Epoch: 44
Average loss: 5.055076121505371
Observations:
- Rapid-descent phase: the loss drops quickly at first, indicating that the model learns the dataset's salient features during the early stage of training.
- Plateau phase: as the iterations increase, the decrease slows and the loss fluctuates around 5.1, suggesting the model may have hit a performance ceiling under the current parameters and data configuration.
Suggested improvements:
- Adjust the learning rate: if the loss keeps oscillating around a fixed level, the learning rate may be set too high or too low. Try a learning-rate decay schedule, or sweep over more learning-rate values.
- Introduce or increase dropout: the current model does not use dropout, which helps prevent overfitting and improves generalization. Consider adding dropout between the RNN layers and/or the fully connected layers.
- Gradient clipping: the clipping threshold is currently set to 50, which is mainly there to prevent exploding gradients. If the model has shown no signs of gradient explosion, the threshold could be relaxed.
- Change the optimizer: if the current optimizer (likely Adam, the common choice with a learning rate of 1e-3) has hit a bottleneck, try another optimizer such as SGD or Adagrad.
- Change the teacher forcing ratio: the ratio is currently 1.0, meaning the ground-truth output is always fed as the next-step input. Gradually lowering this ratio as training progresses may help the model learn to rely more confidently on its own predictions (a small sketch of this follows below).
- Model capacity: if the model is too simple, it may not capture the full complexity of the data. Try increasing the hidden size or adding more RNN layers.
- Batch size: the current batch size is relatively large (2048). Larger batches give more stable gradient estimates, but they can also reduce the exploratory behaviour of the optimization. Reducing the batch size may help the model find a lower loss.
- Early stopping: if performance on the validation set stops improving, stop training to avoid overfitting.
- Save checkpoints promptly: given the training speed on Baidu PaddlePaddle and the daily GPU usage quota, the save_every parameter should be reduced; 10 is a reasonable value.
When trying these improvements, it is important to change only one parameter at a time and monitor its effect on the training dynamics. Also make sure there is a stable validation set to track how these changes affect the model's performance on unseen data.
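As a concrete sketch of the teacher-forcing suggestion above, here is a linear anneal of the ratio across epochs. The function name and the 1.0-to-0.5 schedule are placeholder choices of mine, not something taken from the repository.

```python
# Placeholder sketch: linearly anneal the teacher forcing ratio from `start` to `end`
# over the first `decay_epochs` epochs. Neither the function nor the schedule
# comes from the repository.
def teacher_forcing_ratio(epoch, start=1.0, end=0.5, decay_epochs=20):
    if epoch >= decay_epochs:
        return end
    return start + (end - start) * (epoch / decay_epochs)

# Inside the decoder loop, each target step would then look roughly like:
#   use_teacher_forcing = random.random() < teacher_forcing_ratio(epoch)
#   next_input = target_token if use_teacher_forcing else predicted_token
```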
2024.03.20
For the past few days a problem kept occurring: when training reached epoch 118 the loss became NaN. Today I finally found the cause: an error in the data preprocessing (datapreprocess) step. The original code:
```python
for line in combined_lines:
```
After finding the problem, I changed it to the following code, which successfully resolved the issue.
```python
for line in combined_lines:
```
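The actual change is not reproduced above. Purely as an illustration of the kind of guard that keeps empty or malformed lines out of preprocessing (a common source of NaN losses), here is a hypothetical sketch; the tab-separated format and variable names are my own assumptions, not the fix from this post.

```python
# Hypothetical illustration only -- not the actual fix described above.
# Assumes combined_lines is a list of raw "question<TAB>answer" strings.
cleaned_pairs = []
for line in combined_lines:
    line = line.strip()
    if not line:
        continue  # blank lines would otherwise become empty sequences
    parts = line.split("\t")
    if len(parts) != 2 or not all(parts):
        continue  # drop lines that do not form a complete question/answer pair
    cleaned_pairs.append(parts)
```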