Pytorch Learning Tour
This article records my process of learning Pytorch.
Chapter2
神经网络
神经网络的通用训练步骤
(1)定义一个包含可学习参数的神经网络。
(2)加载用于训练该网络的数据集。
(3)进行前向传播得到网络的输出结果,计算损失(网络输出结果与正确结果的差距)。
(4)进行反向传播,更新网络参数。
(5)保存网络模型。
定义网络
在定义网络时,模型需要继承nn.Module
,并实现它的forward方法。其中,网络里含有可学习参数的层应该放在构造函数__init__()
中,如果某一层(如ReLU)不含有可学习参数,那么它既可以放在构造函数中,又可以放在forward方法中。这里将这些不含有可学习参数的层放在forward方法中,并使用nn.functional
实现:
1 | import torch.nn as nn |
1 | Out:Net( |
用户只需要在nn.Module
的子类中定义了forward函数,backward函数就会自动实现(利用autograd)。在forward函数中不仅可以使用Tensor支持的任何函数,还可以使用if、for、print、log等Python语法,写法和标准的Python写法一致。
使用net.parameters()
可以得到网络的可学习参数,使用net.named_parameters()
可以同时得到网络的可学习参数及其名称。
注意:torch.nn
只支持输入mini-batch,不支持一次只输入一个样本。如果只输入一个样本,那么需要使用 input.unsqueeze(0)
将batch_size设为1。例如, nn.Conv2d
的输入必须是4维,形如$\text{nSamples} \times \text{nChannels} \times \text{Height} \times \text{Width}$。如果一次输入只有一个样本,那么可以将$\text{nSample}$设置为1,即$1 \times \text{nChannels} \times \text{Height} \times \text{Width}$。
损失函数
torch.nn
实现了神经网络中大多数的损失函数,例如nn.MSELoss
用来计算均方误差,nn.CrossEntropyLoss
用来计算交叉熵损失等。
1 | input = t.randn(1, 1, 32, 32) |
当调用loss.backward()
时,计算图会动态生成并自动微分,自动计算图中参数(parameters)的导数,示例如下:
1 | In: # 运行.backward,观察调用之前和调用之后的grad |
优化器
在完成反向传播中所有参数的梯度计算后,需要使用优化方法来更新网络的权重和参数。常用的随机梯度下降法(SGD)的更新策略如下:
1 | # weight = weight - learning_rate * gradient |
torch.optim
中实现了深度学习中大多数优化方法,例如RMSProp、Adam、SGD等,
- 随机梯度下降(SGD)
- 优点
- 简单易实现。
- 在大规模数据集和高维空间中依然有效,因为它每次更新只考虑一个样本,计算效率较高。
- 缺点
- 收敛速度可能比较慢,尤其是在梯度较小的平坦区域。
- 可能会陷入局部最优解。
- 动量(Momentum)SGD
- 优点
- 通过累积过去梯度的信息来加速SGD,在相关方向上加快学习速度,减缓在非相关方向上的学习速度,从而加快收敛。
- 缺点
- 需要选择额外的动量参数,增加了调参的复杂度。
- RMSProp
- 优点
- 通过调整学习率来加快训练速度,适合处理非平稳目标——对于RNN的效果很好。
- 能够在很多非凸优化问题中快速收敛。
- 缺点
- 和Momentum一样,需要设置更多的超参数(如学习率和衰减系数)。
- Adam(Adaptive Moment Estimation)
- 优点
- 结合了Momentum和RMSProp的优点,对学习率进行自适应调整。
- 通常在很多不同的深度学习模型中表现良好,被广泛使用。
- 缺点
- 相对于简单的SGD,计算资源消耗更大。
- 在某些情况下,Adam的自适应学习率可能导致收敛到次优解。
因此,通常情况下用户不需要手动实现上述代码。下面举例说明如何使用torch.optim
进行网络的参数更新:
1 | In: import torch.optim as optim |
数据加载与预处理
在深度学习中,数据加载及预处理是非常繁琐的过程。幸运的是,PyTorch提供了一些可以极大简化和加快数据处理流程的工具:Dataset
与DataLoader
。同时,对于常用的数据集,PyTorch提供了封装好的接口供用户快速调用,这些数据集主要保存在torchvision
中。torchvision
是一个视觉工具包,它提供了许多视觉图像处理的工具,主要包含以下三部分。
- datasets:提供了常用的数据集,如MNIST、CIFAR-10、ImageNet等。
- models:提供了深度学习中经典的网络结构与预训练模型,如ResNet、MobileNet等。
- transforms:提供了常用的数据预处理操作,主要包括对Tensor、PIL Image等的操作。
Chapter3
Pytorch Learning Tour