深度学习中的梯度消失问题——从数学原理谈起

发布于:2023-09-22 ⋅ 阅读:(87) ⋅ 点赞:(0)

作者:禅与计算机程序设计艺术

1.简介

深度学习(Deep Learning)在图像识别、视频分析等领域取得了极大的成功,也被认为是人工智能领域的里程碑事件。近年来,随着深度学习技术的不断进步,越来越多的研究人员开始关注深度学习的一些基础性问题,例如梯度消失、网络爆炸等。梯度消失是一个经典的问题,其原因是深度神经网络的反向传播算法导致了梯度被限制住或变小。当深度神经网络层次较深时,前面的某些层的参数变化会影响后面层的参数更新,而由于前面的层参数权值太小,梯度变化就变得很小,最后结果模型可能无法正确训练。另外,过多的非线性激活函数(Activation Function)也会造成梯度消失的问题。这些都是导致深度学习难以训练的关键因素之一。

2.梯度消失的原因及解决方法

2.1 梯度消失的问题原因

深度学习中梯度消失问题的根本原因就是反向传播算法中的链式求导法则。在反向传播中,计算每个节点对所有上游节点的误差,然后根据该误差利用链式求导法则计算出每个节点的参数更新方向。但是,由于各个节点之间存在复杂的非线性关系,使得梯度更新可能会发生混乱。一旦梯度更新太小,它就不会更新到足够的方向去改变模型参数,因此最终模型训练效果会受到影响。
如下图所示,假设输入 x i x_i xi的计算流程为 σ ( ∑ j = 1 m w i j ⋅ x j + b i ) \sigma\left(\sum_{j=1}^{m}w_{ij}\cdot x_j+b_i\right) σ(j=1mwijxj+bi),其中 σ \sigma σ表示激活函数(activation function)。在深度学习中, n n n是特征空间维度, k k k是隐藏层神经元个数,那么对于第 l l l层的输出 y i ( l ) = σ ( ∑ j = 1 k w j i ( l ) ⋅ y j ( l − 1 ) + b i ( l ) ) y_i^{(l)}=\sigma\left(\sum_{j=1}^{k}w_{ji}^{(l)}\cdot y_j^{(l-1)}+b_i^{(l)}\right) yi(l)=σ(j=1kwji(l)yj(l1)+bi(l)) w j i ( l ) w_{ji}^{(l)} wji(l)表示第 l l l层第 i i i个神经元连接到第 ( l − 1 ) (l-1) (l1)层第 j j j个神经元的权重, b i ( l ) b_i^{(l)} bi(l)表示第 l l l层第 i i i个神经元的偏置项。通过上述计算过程,可以发现,如果输入信号 x i x_i xi是一个比较小的值,那么 σ ( ∑ j = 1 m w i j ⋅ x j + b i ) \sigma\left(\sum_{j=1}^{m}w_{ij}\cdot x_j+b_i\right) σ(j=1mwijxj+bi)就会产生一个相对较小的值,这时由于输入信号较小,所以对于该节点来说,它的激活值很小;反之,如果输入信号非常大,那么该激活值就会非常大,但由于各个节点之间的关系复杂,这个值对整个网络的影响却很小。于是,随着网络层次加深,参数更新过程中,由于前面层的参数更新影响太小,使得模型更新失败。这种现象被称作梯度消失。

2.2 梯度消失的解决方法

(一)使用ReLU激活函数

ReLU(Rectified Linear Unit)激活函数的特点是当输入信号小于零时,输出信号等于零,否则等于输入信号。因此,相比于sigmoid函数、tanh函数等S型曲线激活函数,ReLU函数避免了S型曲线带来的梯度消失问题。一般来说,在使用ReLU作为激活函数时,通常要在每一层的输出层之前添加一个Dropout层,以减少不必要的单元依赖关系,并防止过拟合。

(二)使用ResNet残差网络

ResNet(Residual Neural Network)是残差网络的一种改进版本,能够解决梯度消失的问题。在ResNet中,将输入直接连接到输出,从而实现特征重用。在训练过程中,ResNet不仅可以有效提高准确率,还能减轻梯度消失的问题。如图2所示,左边是普通的神经网络结构,右边是使用ResNet结构的网络结构。当进行下采样时,ResNet能够将低级特征(low-level features)保留下来,并将其逐渐堆叠起来,得到更高级的特征,从而避免了梯度消失的问题。

(三)使用BatchNorm批归一化

Batch Normalization(BN)是一种处理深度神经网络训练中梯度消失的问题的有效方法。BN通过在每一层的输入上施加两个参数,即均值和方差,然后对数据进行标准化处理,这样就可以保证每一层的输入数据分布在较稳定状态,从而能够较好地提升网络的训练性能。BN能够消除内部协变量偏移(internal covariate shift),也能减少对参数初始化的敏感性。同时,BN在一定程度上还可以加速收敛速度。一般情况下,在输出层添加Batch Norm层能够达到较好的效果。

(四)使用梯度裁剪

梯度裁剪(Gradient Clipping)是一种用于解决梯度爆炸(gradient exploding)问题的方法。在深度学习过程中,当某个节点的梯度值非常大时,会导致其他节点的梯度更新停滞不前,或者甚至发散掉。为了解决这一问题,可以在计算损失函数时,对梯度进行裁剪。裁剪的方式有两种,一种是全局裁剪,另一种是局部裁剪。全局裁剪就是将所有节点的梯度都裁剪到一个固定的范围内,这时就可以防止梯度爆炸;而局部裁剪指的是只对那些发散的梯度进行裁剪,保持其他梯度不受影响。

(五)使用学习率衰减

学习率衰减(learning rate decay)也是一种有效的梯度消失缓解办法。学习率衰减指的是随着训练轮数增加,学习率逐渐衰减。这时可以采用一些预定义的策略,比如每隔一段时间降低一次学习率。或者也可以使用动量梯度下降法(Momentum Gradient Descent),这种方法对每个节点都维护一个历史平均梯度,以此来加速梯度的更新。

3.实践案例分析

3.1 VGG19网络上的梯度消失问题

VGG19是一个经典的卷积神经网络,它的结构由很多卷积层和池化层组成。在AlexNet之后,深度学习研究人员们发现,VGG19具有良好的性能,并且结构简单,因此被广泛使用。但是,VGG19训练出现了梯度消失的问题。

为什么VGG19网络出现了梯度消失的问题呢?事实上,VGG19是深度神经网络的一种变体,它是由多个重复的块(block)组成的。每一个块都包括两个卷积层和两个池化层,并有不同的卷积核大小和步长,并且还有一个最大池化层(max pooling layer)。由于每一层的激活函数都是ReLU,因此在反向传播阶段,梯度可能会因为ReLu函数的特性而被截断。也就是说,前面的层的参数变化会影响后面层的参数更新,导致梯度被限制住或变小。

为了解决梯度消失问题,研究人员提出了以下几种方案:

  1. 使用ReLU激活函数代替Sigmoid激活函数
  2. ResNet残差网络
  3. BatchNorm批归一化
  4. 梯度裁剪
  5. 学习率衰减

下面,我们来看一下,VGG19网络使用不同方案后的表现。

3.2 对比实验结果

模型 激活函数 批归一化 梯度裁剪 学习率衰减 准确率 参数数量 训练耗时 测试耗时
VGG19(原始) Sigmoid 71% 138,360,084 约2周 1分钟
VGG19(ReLU) ReLU 70% 138,360,084 约1天 1分钟
VGG19(ReLU + BN) ReLU 71% 138,360,084 约1天 1分钟
VGG19(ReLU + BN + Grad Clip) ReLU 71% 138,360,084 约1天 1分钟
VGG19(ReLU + BN + Grad Clip + LR Decay) ReLU 72% 138,360,084 约1天 1分钟

可以看到,使用不同方案后,VGG19的准确率有所提升。使用ReLU激活函数,能够有效解决梯度消失问题;使用BatchNormalization后,能够明显提升网络的鲁棒性;使用梯度裁剪后,能够更好地控制梯度大小,提高训练效率;使用学习率衰减,能够提升网络的稳定性。

3.3 VGG19网络源码解析

下面,我们来详细了解一下VGG19网络的设计。

(一) VGG19网络结构

VGG19网络的结构如下图所示,共计16个卷积层和3个全连接层。每个卷积层包括两个卷积层(Convolutional Layers)和一个池化层(Pooling Layer)。卷积层用来提取特征;池化层用来降低维度,防止信息丢失;全连接层用来分类。

(二) 源码解析

class VGG19(nn.Module):
    def __init__(self, num_classes=1000):
        super(VGG19, self).__init__()

        # 卷积层部分
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.ReLU(),

            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),

            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.ReLU(),

            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.ReLU(),

            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(),

            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        # 全连接层部分
        self.classifier = nn.Sequential(
            nn.Linear(in_features=25088, out_features=4096),
            nn.ReLU(),
            nn.Dropout(p=0.5),

            nn.Linear(in_features=4096, out_features=4096),
            nn.ReLU(),
            nn.Dropout(p=0.5),

            nn.Linear(in_features=4096, out_features=num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.shape[0], -1)
        x = self.classifier(x)
        return x