ChatBot Training Tour

This article records my process of training an AI chatbot.

weibo.pth: 4,435,959 (about 41× the size of Corpus.pth)

Corpus.pth: 106,478 (batch_size = 2048)

Summary

| Version | Epoch | Loss | Parameters |
|---------|-------|------|------------|
| V14-4 | 68 | 4.748 | Same as before |
| V14-3 | 32 | 4.753 | Same as before |
| V14-2 | 92 | 4.744 | Same as before |
| V14-1 | 13 | 4.796 | Same as before |
| V14 | 42 | 4.773 | batch_size = 13750 |
| V13-9 | 18 | 4.829 | Same as before |
| V13-8 | 21 | 4.823 | Same as before |
| V13-7 | 18 | 4.837 | Same as before |
| V13-6 | 17 | 4.860 | Same as before |
| V13-5 | 16 | 4.835 | Same as before |
| V13-4 | 20 | 4.866 | Same as before |
| V13-3 | 6 | 4.889 | Same as before |
| V13-2 | 16 | 4.881 | Same as before |
| V13-1 | 10 | 4.874 | Same as before |
| V13 | 20 | 4.876 | batch_size = 11000 |
| V12-2 | 21 | 4.938 | Same as before |
| V12-1 | 19 | 4.914 | Same as before |
| V12 | 17 | 4.970 | batch_size = 10000, learning_rate = 0.0001 |
| V11 | 6 | 4.965 | batch_size = 12000 |
| V10 | 14 | 4.976 | batch_size = 10000 |
| V9 | 7 | 5.059 | batch_size = 5000, learning_rate = 0.001 |
| V8 | 10 | 5.092 | batch_size = 4096, learning_rate = 0.001 |
| V7 | 11 | 5.181 | batch_size = 2048, learning_rate = 0.0001 |
| V6 | 7 | 5.198 | batch_size = 2048, learning_rate = 0.0005 |
| V5 | 22 | 5.226 | batch_size = 2048, learning_rate = 0.001 |
| V4 | 8 | 5.917 | dropout = 0.2, teacher_forcing_ratio = 0.85, batch_size = 2048, learning_rate = 0.0001 |
| V3 | 14 | 5.966 | dropout = 0.2, clip = 60.0, teacher_forcing_ratio = 0.9, batch_size = 2048, learning_rate = 0.0001 |
| V2 | 44 | 5.055 | Origin |

2024.04.23

Config.py: https://github.com/OnlyourMiracle/Chinese-Chatbot-PyTorch-Implementation/blob/master/config.py
(SHA: d89c96166e3ca903aed3bd960c7d8dab8fdaa5f0)

Training platform: 百度飞桨 (Baidu PaddlePaddle)

Epoch: 68

Average loss: 4.748495200444133

2024.04.23

Config.py: https://github.com/OnlyourMiracle/Chinese-Chatbot-PyTorch-Implementation/blob/master/config.py
(SHA: d89c96166e3ca903aed3bd960c7d8dab8fdaa5f0)

Training platform: 百度飞桨 (Baidu PaddlePaddle)

Epoch: 32

Average loss: 4.753096600420531

2024.04.22

Config.py: https://github.com/OnlyourMiracle/Chinese-Chatbot-PyTorch-Implementation/blob/master/config.py
(SHA: d89c96166e3ca903aed3bd960c7d8dab8fdaa5f0)

Training platform: 百度飞桨 (Baidu PaddlePaddle)

Epoch: 92

Average loss: 4.743738563202844

2024.04.21

Config.py: https://github.com/OnlyourMiracle/Chinese-Chatbot-PyTorch-Implementation/blob/master/config.py
(SHA: d89c96166e3ca903aed3bd960c7d8dab8fdaa5f0)

Training platform: 百度飞桨 (Baidu PaddlePaddle)

Epoch: 13

Average loss: 4.796416420097512

Config.py: https://github.com/OnlyourMiracle/Chinese-Chatbot-PyTorch-Implementation/blob/master/config.py
(SHA: d89c96166e3ca903aed3bd960c7d8dab8fdaa5f0)

Training platform: 百度飞桨 (Baidu PaddlePaddle)

Epoch: 42

Average loss: 4.772935219662937

2024.04.20

Config.py: https://github.com/OnlyourMiracle/Chinese-Chatbot-PyTorch-Implementation/blob/master/config.py
(SHA: 87160e2723ede287d60a68b82700e63bdc1a5c2b)

Training platform: 百度飞桨 (Baidu PaddlePaddle)

Epoch: 18

Average loss: 4.828708153485162

2024.04.19

Config.py: https://github.com/OnlyourMiracle/Chinese-Chatbot-PyTorch-Implementation/blob/master/config.py
(SHA: 87160e2723ede287d60a68b82700e63bdc1a5c2b)

Training platform: 百度飞桨 (Baidu PaddlePaddle)

Epoch: 21

Average loss: 4.822652317165321

Config.py: https://github.com/OnlyourMiracle/Chinese-Chatbot-PyTorch-Implementation/blob/master/config.py
(SHA: 87160e2723ede287d60a68b82700e63bdc1a5c2b)

Training platform: 百度飞桨 (Baidu PaddlePaddle)

Epoch: 18

Average loss: 4.837234509962312

2024.04.16

Config.py: https://github.com/OnlyourMiracle/Chinese-Chatbot-PyTorch-Implementation/blob/master/config.py
(SHA: 87160e2723ede287d60a68b82700e63bdc1a5c2b)

Training platform: 百度飞桨 (Baidu PaddlePaddle)

Epoch: 17

Average loss: 4.860292780773252

2024.04.15

Config.py: https://github.com/OnlyourMiracle/Chinese-Chatbot-PyTorch-Implementation/blob/master/config.py
(SHA: 87160e2723ede287d60a68b82700e63bdc1a5c2b)

Training platform: 百度飞桨 (Baidu PaddlePaddle)

Epoch: 16

Average loss: 4.83524939768576

2024.04.14

Config.py: https://github.com/OnlyourMiracle/Chinese-Chatbot-PyTorch-Implementation/blob/master/config.py
(SHA: 87160e2723ede287d60a68b82700e63bdc1a5c2b)

Training platform: 百度飞桨 (Baidu PaddlePaddle)

Epoch: 20

Average loss: 4.86629382181837

2024.04.07

Config.py: https://github.com/OnlyourMiracle/Chinese-Chatbot-PyTorch-Implementation/blob/master/config.py
(SHA: 87160e2723ede287d60a68b82700e63bdc1a5c2b)

Training platform: 百度飞桨 (Baidu PaddlePaddle)

Epoch: 6

Average loss: 4.889181085683352

2024.04.07

Config.py: https://github.com/OnlyourMiracle/Chinese-Chatbot-PyTorch-Implementation/blob/master/config.py
(SHA: 87160e2723ede287d60a68b82700e63bdc1a5c2b)

Training platform: 百度飞桨 (Baidu PaddlePaddle)

Epoch: 16

Average loss: 4.881189099620858

2024.04.07

Config.py: https://github.com/OnlyourMiracle/Chinese-Chatbot-PyTorch-Implementation/blob/master/config.py
(SHA: 87160e2723ede287d60a68b82700e63bdc1a5c2b)

Training platform: 百度飞桨 (Baidu PaddlePaddle)

Epoch: 10

Average loss: 4.8736686256889294

2024.04.04

Config.py: https://github.com/OnlyourMiracle/Chinese-Chatbot-PyTorch-Implementation/blob/master/config.py
(SHA: 87160e2723ede287d60a68b82700e63bdc1a5c2b)

Training platform: 百度飞桨 (Baidu PaddlePaddle)

Epoch: 20

Average loss: 4.876422007442719

2024.03.29

Config.py: https://github.com/OnlyourMiracle/Chinese-Chatbot-PyTorch-Implementation/blob/master/config.py
(SHA: 8e5ffe59fa69e32fb938f1883db340d87b261f7a)

Training platform: 百度飞桨 (Baidu PaddlePaddle)

Epoch: 21

Average loss: 4.9384859099953164

2024.03.28

Config.py: https://github.com/OnlyourMiracle/Chinese-Chatbot-PyTorch-Implementation/blob/master/config.py
(SHA: 8e5ffe59fa69e32fb938f1883db340d87b261f7a)

Training platform: 百度飞桨 (Baidu PaddlePaddle)

Epoch: 19

Average loss: 4.914344686694272

2024.03.27

Config.py: https://github.com/OnlyourMiracle/Chinese-Chatbot-PyTorch-Implementation/blob/master/config.py
(SHA: 8e5ffe59fa69e32fb938f1883db340d87b261f7a)

Training platform: 百度飞桨 (Baidu PaddlePaddle)

Epoch: 17

Average loss: 4.969914259826938

2024.03.26

Config.py: https://github.com/OnlyourMiracle/Chinese-Chatbot-PyTorch-Implementation/blob/master/config.py
(SHA: cb49184c8e7092ca1edfee4e3e8275607b975b07)

Training platform: vc

Epoch: 6

Average loss: 4.96543633802544

Config.py: https://github.com/OnlyourMiracle/Chinese-Chatbot-PyTorch-Implementation/blob/master/config.py
(SHA: faa387b6875bbcdf8e3adc250684ae55a2de3812)

Training platform: 百度飞桨 (Baidu PaddlePaddle)

Epoch: 14

Average loss: 4.97629819865179

Config.py: https://github.com/OnlyourMiracle/Chinese-Chatbot-PyTorch-Implementation/blob/master/config.py
(SHA: ee22690c8e35bce2f99fb3f4b190908187256fbe)

Training platform: 百度飞桨 (Baidu PaddlePaddle)

Epoch: 7

Average loss: 5.059292034180287

2024.03.25

Config.py: https://github.com/OnlyourMiracle/Chinese-Chatbot-PyTorch-Implementation/blob/master/config.py
(SHA: 88a3fe81e66cba7f529182d17ef886b444720542)

Training platform: 百度飞桨 (Baidu PaddlePaddle)

Epoch: 10

Average loss: 5.092367821139052

Config.py: https://github.com/OnlyourMiracle/Chinese-Chatbot-PyTorch-Implementation/blob/master/config.py
(SHA: 1ce617ee78368acfaa489b5dd1c8021d2f506e27)

Training platform: 百度飞桨 (Baidu PaddlePaddle)

Epoch: 11

Average loss: 5.18057608768361

Config.py: https://github.com/OnlyourMiracle/Chinese-Chatbot-PyTorch-Implementation/blob/master/config.py
(SHA: 6bba047016a0f05a2fe1e3dda8ae4389f018c9ca)

Training platform: 百度飞桨 (Baidu PaddlePaddle)

Epoch: 7

Average loss: 5.197861174456669

2024.03.24

Config.py: https://github.com/OnlyourMiracle/Chinese-Chatbot-PyTorch-Implementation/blob/master/config.py
(SHA: 6435d7d7bb7bf52b4dc6e1fe46c76f55a63d8c98)

Training platform: 百度飞桨 (Baidu PaddlePaddle)

Epoch: 22

Average loss: 5.2257363110720165

The loss curve suggests that the loss decreased significantly and then plateaued, which is common as a deep learning model begins to converge. A plateau usually means the model has reached a point where it is not learning much more from the data under the current settings.

Here are some strategies you could consider to further decrease the loss:

  1. Learning Rate Adjustment: If you’re using a constant learning rate, you might want to switch to a learning rate schedule that decreases the learning rate as you continue to train. If you’re already using a learning rate scheduler, consider adjusting it to decrease the learning rate more slowly or to use a learning rate warm-up period.

  2. Early Stopping: If the loss isn’t improving for a significant number of epochs, it might be time to stop training to prevent overfitting.

  3. Regularization Techniques: Implement or increase regularization to prevent overfitting if you suspect that’s an issue. This can include techniques like dropout, L2 regularization, or using a more robust optimizer like AdamW which incorporates weight decay.

  4. Model Architecture: Consider whether your model might be too simple or too complex for the problem. Adding layers or units can increase model capacity, while reducing them can help the model generalize better if it’s overfitting.

  5. Data Augmentation or Cleaning: More or cleaner data can help improve model performance. If you’re already using all your data, consider data augmentation techniques. Sometimes, reviewing and cleaning the dataset can also yield improvements, especially if there are noisy labels or irrelevant features.

  6. Batch Size: Adjust the batch size if possible. A smaller batch size can provide a regularizing effect and larger batch sizes can provide more stable gradient estimates.

  7. Gradient Clipping: This prevents the exploding gradient problem by capping the gradients during backpropagation to avoid very large updates to the weights.

  8. Different Optimizer: If you’re not already, consider using an optimizer like Adam, which adapts the learning rate for each parameter and is generally good for large datasets and complex models.

  9. Loss Function: Consider if the loss function you’re using is appropriate for your data and problem. Sometimes, tweaking the loss function or using a custom one can yield better results.

  10. Hyperparameter Tuning: Use techniques like grid search, random search, or Bayesian optimization to systematically explore different hyperparameter settings.

To apply these strategies, adjust the parameters in your training code (train_eval.py, config.py, etc.). In particular, check the following:

  • Where the learning rate is defined and how it’s updated during training.
  • Where regularization settings are configured.
  • Where the model’s architecture is defined.
  • How data is preprocessed and augmented.
  • The batch size settings.
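
To make a few of these points concrete, here is a minimal, hypothetical PyTorch sketch of strategies 1, 3, 7, and 8 (a plateau-based learning-rate scheduler, AdamW weight decay, gradient clipping, and an adaptive optimizer). It is not the repo's actual train_eval.py; `model` and `train_loader` are assumed to already exist.

```python
import torch
import torch.nn as nn

def train(model, train_loader, num_epochs, device="cuda"):
    criterion = nn.CrossEntropyLoss()
    # AdamW gives decoupled weight decay (strategies 3 and 8).
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)
    # Halve the learning rate whenever the epoch-average loss plateaus (strategy 1).
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.5, patience=2)

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            logits = model(inputs)
            loss = criterion(logits.view(-1, logits.size(-1)), targets.view(-1))
            loss.backward()
            # Cap gradient norms to avoid exploding gradients (strategy 7).
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
            optimizer.step()
            total_loss += loss.item()

        avg_loss = total_loss / len(train_loader)
        scheduler.step(avg_loss)  # the scheduler watches the epoch-average loss
        print(f"Epoch {epoch}: average loss {avg_loss:.4f}")
```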

Config.py: https://github.com/OnlyourMiracle/Chinese-Chatbot-PyTorch-Implementation/blob/master/config.py
(SHA: 23108d169e8bcb8e777fdc4a21e6165d82dbb0a8)

Training platform: 百度飞桨 (Baidu PaddlePaddle)

Epoch: 306

Average loss: 1.412546370036195

Some sample dialogues generated with the chatbot_0324_1543 model.

Config.py: https://github.com/OnlyourMiracle/ChatBot/blob/main/ChatBotV4/config.py
(SHA: 921086296f841570df5b210c57e2054232e0c874)

Training platform: 百度飞桨 (Baidu PaddlePaddle)

Epoch: 8

Average loss: 5.917065527910856

2024.03.22

Config.py: https://github.com/OnlyourMiracle/ChatBot/blob/main/ChatBotV4/config.py
(SHA: e747321aa47db35adfcccfcb2f5b24788c13b1ca)

Training platform: 百度飞桨 (Baidu PaddlePaddle)

Epoch: 14

Average loss: 5.96649232743882

2024.03.21

GitHub: https://github.com/OnlyourMiracle/ChatBot/tree/main/ChatBotV3

Latest ChatBot training code (V3)

Training platform: 百度飞桨 (Baidu PaddlePaddle)

Epoch: 44

Average loss: 5.055076121505371

Observations:

  1. Rapid-descent phase: the loss drops quickly at the start, which indicates that the model learns the salient features of the dataset in the early stage of training.
  2. Plateau phase: as the number of iterations increases, the decline slows and the loss fluctuates around 5.1. This suggests the model may have hit a performance ceiling under the current parameters and data configuration.

Suggestions for improvement:

  1. Adjust the learning rate: if the loss keeps oscillating around a flat line, the learning rate may be set too high or too low. Try a learning-rate decay schedule, or experiment with more learning-rate values.
  2. Introduce or increase dropout: the current model does not use dropout; dropout helps prevent overfitting and improves generalization. Consider adding dropout between the RNN layers and/or the fully connected layers (see the sketch after this list).
  3. Gradient clipping: the clipping threshold is set to 50, which is normally used to prevent exploding gradients. Check whether the model has shown any signs of exploding gradients; if not, the threshold can be relaxed.
  4. Change the optimizer: if the current optimizer (probably Adam, the usual choice with a learning rate of 1e-3) has hit a bottleneck, try another optimizer such as SGD or Adagrad.
  5. Adjust the teacher-forcing ratio: the ratio is currently 1.0, meaning the ground-truth output is always used as the next decoder input. Gradually lowering this ratio as training proceeds may help the model learn to rely more confidently on its own predictions (see the sketch after this list).
  6. Model capacity: if the model is too simple, it may fail to capture the full complexity of the data. Try increasing the hidden size or adding more RNN layers.
  7. Batch size: the current batch size is relatively large (2048). Larger batches give more stable gradient estimates, but they can also make the optimization less exploratory. Reducing the batch size may help the model reach a better loss.
  8. Early stopping: if performance on the validation set stops improving, stop training to avoid overfitting.
  9. Save checkpoints promptly: given the training speed on 百度飞桨 (Baidu PaddlePaddle) and the daily GPU usage quota, the "save_every" parameter should be adjusted; setting it to 10 seems reasonable.

When trying these improvements, it is important to change only one parameter at a time and monitor its effect on the training dynamics. Also, make sure there is a stable validation set to track how each change affects performance on unseen data.
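
As a rough illustration of suggestions 2 and 5 above (adding dropout and gradually lowering the teacher-forcing ratio), a minimal PyTorch sketch follows. The `Decoder` module, the linear decay schedule, and the helper names are my own illustrative assumptions, not the repo's actual code.

```python
import random
import torch
import torch.nn as nn

class Decoder(nn.Module):
    """GRU decoder with dropout applied to the embeddings (suggestion 2)."""
    def __init__(self, vocab_size, embed_size, hidden_size, dropout=0.2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.dropout = nn.Dropout(dropout)
        self.gru = nn.GRU(embed_size, hidden_size, batch_first=True)
        self.out = nn.Linear(hidden_size, vocab_size)

    def forward(self, token, hidden):
        emb = self.dropout(self.embedding(token))   # (batch, 1, embed_size)
        output, hidden = self.gru(emb, hidden)
        return self.out(output), hidden              # logits: (batch, 1, vocab)

def teacher_forcing_ratio(epoch, start=1.0, end=0.5, decay_epochs=20):
    """Linearly decay the teacher-forcing ratio from `start` to `end` (suggestion 5)."""
    return max(end, start - (start - end) * epoch / decay_epochs)

def next_decoder_input(logits, gold_token, epoch):
    """Feed the gold token with probability `ratio`, else the model's own prediction."""
    if random.random() < teacher_forcing_ratio(epoch):
        return gold_token                 # teacher forcing
    return logits.argmax(dim=-1)          # model's previous prediction
```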

2024.03.20

For the past few days I kept running into a problem: when training reached epoch 118 the loss became NaN. Today I finally found the cause: the data-preprocessing step was wrong. The original code:

```python
for line in combined_lines:
    sentences = []
    for value in line:
        # Problem: str.split() cannot segment Chinese text (it contains no
        # spaces), so each sentence is left as a single long token.
        sentence = cop.sub("", value).split()
        # sentence = jieba.lcut(cop.sub("", value))
        sentence = sentence[:max_sentence_length] + [eos]
        sentences.append(sentence)
    data.append(sentences)
```

After locating the problem today, I changed it to the following code, which solved the issue.

```python
for line in combined_lines:
    sentences = []
    for value in line:
        sentence = cop.sub("", value)      # remove characters matched by the regex `cop`
        sentence = jieba.lcut(sentence)    # segment the Chinese sentence into word tokens
        sentence = sentence[:max_sentence_length] + [eos]
        sentences.append(sentence)
    data.append(sentences)
```
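
A quick way to see why this matters (the sentence below is just an example of mine, and the exact segmentation may vary with the jieba version and dictionary): Chinese text has no spaces, so `str.split()` cannot break it into words, while `jieba.lcut` produces word-level tokens.

```python
import jieba

sentence = "今天天气怎么样"
print(sentence.split())      # ['今天天气怎么样']  -> one giant "token"
print(jieba.lcut(sentence))  # e.g. ['今天', '天气', '怎么样'] -> proper word tokens
```
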
Author: OnlyourMiracle · Posted on 2024-03-20 · Updated on 2024-04-24
