产生原因
链式法则
深度学习模型为了获得很强的拟合能力和表达能力, 往往都使用深层网络. 其中每一层都可以看作是对其输入进行非线性转换, 即每一层都可以看作是一个非线性函数fl(x), 整个神经网络可以看做为嵌套的复合非线性函数:
F(x)=fL(⋯f3(f2(f1(x∗w1+b1)∗w2+b2)∗w3+b3)⋯) 更新网络参数使用梯度下降法, 从损失函数开始Loss=L(y,F(x)), 得到损失函数后, 根据链式法则从后向前更新每一层的参数. 假设第i层网络的输出是fi(x), 而第i层的输入正是第i−1层的输出, 即xi=fi−1(xi−1). 在求损失L对第i层参数wi的梯度时, 根据链式法则, 可以展开为:
∂wi∂Loss=∂fL∂Loss∂fL−1∂fL⋯∂fi∂fi+1∂wi∂fi 可以看到求第i层参数梯度的时候, 中间会连乘多个∂fi∂fi+1. 越靠前(靠近整个模型得输入位置)层的参数, 中间的连乘项越多.
把每一个连乘项进一步拆解, 网络中每一层的的输出可以表示为ai+1=f(zi+1)=f(ai∗wi+bi), 链式求导结果中的连乘项为∂ai∂ai+1=∂zi+1∂ai+1∂ai∂zi+1=f′(zi+1)∗wi. 即每一个连乘项的值由两部分确定:
因此如果这些连乘项的值都小于1, 随着层数的增多, 靠前层参数的梯度就会趋近于0, 发生了梯度消失; 如果连乘项的值都大于1, 靠前层的梯度将以指数形式叠加, 发生了梯度爆炸.
因此, 要深入分析梯度消失和梯度爆炸的原因, 需要从激活函数和权值大小两个角度进行.
激活函数
链式法则是梯度消失和梯度爆炸的根本原因, 而激活函数又进一步加剧了梯度消失的问题.
连乘项∂fi∂fi+1实际上等于当前层激活函数在x=fi−1情况下梯度值. 而使用sigmoid
, tanh
这些带有饱和区的激活函数时, 由于它们的梯度值都不会超过1, 因此注定了梯度消失的发生.
Sigmoid函数为σ(x)=1+e−x1, 对应的导数为σ′(x)=σ(1−σ), 因此导数的最大值只有0.25, 在深层网络中很容易发生梯度消失.
tanh函数为tanh(x)=ex−e−xex−e−x, 对应的导数为tanh′(x)=1−tanh2(x), 最大值为1.
正是这些梯度恒小于1, 且饱和区域很长且梯度为0的激活函数, 加重了梯度消失的问题.
权值初始化
网络中每一层的的输出可以表示为ai+1=f(zi+1)=f(ai∗wi+bi), 链式求导结果中的连乘项为∂ai∂ai+1=∂zi+1∂ai+1∂ai∂zi+1=f′(zi+1)∗wi. 可以看到, 连乘项中除了激活函数的导数外, 还有一项是本层的参数wi.
当初始化的参数值较大的时候, 参数的连乘会带来梯度爆炸的问题.
总结
梯度消失产生的原因为
根本原因在于链式法则, 靠前层的参数更新在连乘项叠加作用下, 整体梯度异常, 无法正常更新
激活函数的导数值恒小于1, 特别是sigmoid函数的梯度最大值只有0.25, 且梯度为0的饱和区域占绝大部分, 加剧了梯度消失的问题
梯度爆炸产生的原因为
根本原因在于链式法则, 靠前层的参数更新在连乘项叠加作用下, 整体梯度异常, 无法正常更新
权值初始化太大会在训练的一开始就带来梯度爆炸的问题
解决方法
梯度爆炸
WarmUp
极端的梯度爆炸会使得权重的变得非常大, 损失变为NaN
. 而梯度爆炸另一个常见的信号是:
模型不稳定, 训练过程中损失出现显著的变化, 大幅震荡的情况, 且在模型开始训练时最为常见
一种常用的优化方法是在训练的前N步使用WarmUp的方法进行训练, 使用一个极小的步长开始, 在一定的步数范围内, 逐步增大到正常的步长. 在这个过程中, 模型会以很小的梯度进行优化, 将权值参数优化到合理的范围内, 可以有效的消除震荡的情况.
梯度剪切
设置一个梯度剪切阈值, 更新梯度时, 如果梯度超过这个阈值, 那么就将其强制限制在这个范围之内.
权重正则
对参数权重使用正则化约束, L1正则或L2正则. 原理是如果发生梯度爆炸, 在爆炸梯度的作用下权值会变得很大, 因此整体损失就会变大, 从而在之后的更新中校正.
梯度消失
使用ReLU等激活函数
ReLU, Leaky ReLU, ELU等这类激活函数, 可以缓解梯度消失爆炸的问题. 原理很简单, 它们在正半轴的导数始终为1. 相当于连乘项各项的值在此时都为1, 解决了连乘带来的指数放大或缩小的问题.
Normalization
无论是Batch Normalization, 还是Layer Normalization, 它们都会将每个神经元的输出分布, 规范到稳定的分布中. 这种功能, 可以带来减轻梯度消失和梯度爆炸的好处.
在没有Normalization结构时, 我们可以将每一层的输出表示为输入的函数: ai+1=f(zi+1)=f(ai∗wi+bi), 因此链式求导结果中的连乘项为∂ai∂ai+1=∂zi+1∂ai+1∂ai∂zi+1=f′(zi+1)∗wi.
由于参数wi的存在, 当参数过小或过大时, 此连乘项的值就会很小/很大, 随后的连乘就会带来梯度的消失或爆炸.
加入了Normalization之后, 对应层的输出就会发生变化. 以Batch Normalization为例, 它作用在输入和权重计算后, 进入激活函数之前, 首先记:
μl: 列向量, 当前层输出(进入激活之前的输出)的平均值
σl: 列向量, 当前层输出的标准差
因此加入了BN之后的输出可以记为:
zl=BN(al−1wl)=σl1(al−1wl−μl) 考虑经过激活函数之后, 反向传播中的连乘项变为:
∂al−1∂al=∂zl∂al∂al−1∂zl=f′(zl)σlwl 相比之下, 最显著的变化, 是增加了一项σl1, 对权重wl进行缩放. 如果权重wl较小, 则zl=al−1wl必然会较小, 使得统计得到的σl就会较小, 这时使用σl1, 相当于在计算梯度时, 对参数wl进行了放大, 避免了梯度消失; 反之避免了梯度爆炸. 这时BN能够缓解梯度消失和爆炸的主要原因. 实际上Normalization具有权重伸缩不变性, 单纯地对权重进行缩放, 反向传播项∂al−1∂zl的大小不会发生任何变化, 详情参考Normalization综述.
再考虑连乘项中的激活函数导数部分. 如果使用的Sigmoid这种带有饱和区的激活函数, 它们在输入远离0的位置是不会产生梯度的, 只有输入在0附近, 这时的梯度才会达到最大水平. Normalization结构作用在每一层的输出进入到激活函数之前, 如果不考虑re-shift和re-scale的作用, Normalization结构会将进入到激活函数的输入调整为均值为0, 方差为1的分布, 从而使得输入落在激活函数的敏感区域, 可以产生较大的梯度, 避免了梯度消失. 而考虑了re-shift和re-scale, 只是将不同神经元的表现进行差异化, 会使得一部分神经元处在敏感状态, 一部分处在饱和状态, 相当于对整体模型结构做了正则化. 相当于选择了部分参数进行梯度消失, 降低了模型的复杂度.
因此, Normalization对梯度的传播的优化原理如下:
统计的滑动标准差σl在反向传播时调节不同参数wl大小带来的影响, 避免了梯度消失和梯度爆炸
在进入激活函数前, 将输入分布稳定, 使得输入落在激活函数的敏感区域, 在反向传播时能够提供较大的梯度, 避免了梯度消失
梯度消失、爆炸的原因及解决办法
残差结构
残差网络是由一系列的残差块组成, 每一个残差块可以表示为:
yl+1=xl+F(xl,Wl)xl+1=f(yl+1) 其中F(xl,Wl)是残差部分, 将本层的残差输出与本层的输入加和, 然后再经过一个激活函数f, 一般是ReLU, 得到本层的最终输出.
如果靠前的l层与靠后的L有直接的链接, 则可以表示为xL=xl+∑i=lL−1F(xi,Wi), 对应的损失函数关于xl的梯度可以记为:
∂xl∂ε=∂xL∂ε∂xl∂xL=∂xL∂ε(1+∂xl∂∑i=1L−1F(xi,Wi))=∂xL∂ε+∂xL∂ε∂xl∂∑i=1L−1F(xi,Wi) 在训练过程中, ∂xl∂∑i=1L−1F(xi,Wi)不可能一直都为-1, 因此如果L层有梯度∂xL∂ε, 则l层一定有梯度.
上式表明了L层的梯度可以直接传递到任意一个比它浅的l层, 不经过中间层的权重矩阵. 因此解决了梯度消失的问题.
总结
总的来说, 缓解梯度消失和梯度爆炸的方法, 主要有以下几种:
权重正则: 过大的梯度造成很大的权重, 在正则的约束下产生较大的损失, 迫使后续的迭代中参数变小, 避免梯度爆炸
激活函数: 选择导数为常数1的激活函数, 消除连乘项中激活函数带来的影响, 缓解梯度消失和梯度爆炸
合理的权值初始化: 初始化的权值过大过小都会造成问题, 合理的初始化, 能够缓解梯度消失和梯度爆炸的问题
Warm Up: 训练开始时, 使用较小的梯度, 避免过大的梯度让模型反复横跳, 损失震荡无法收敛, 使得模型可以得到训练, 进入到合理的优化路径中. 缓解了训练起始时梯度过大的问题
Normalization: 缩放平移的特性, 很大程度上缓解了梯度消失和梯度爆炸的问题
残差结构: 使得深层的梯度可以直接传递到浅层中, 保证了每一层参数更新时梯度的大小, 避免了梯度消失的问题