最后更新于
最后更新于
所谓梯度累积(accumulate gradients), 优化器梯度下降所用的梯度, 实际上是多个样本算出来的梯度的平均值. 以batch size为128为例, 可以一次性算出128个样本的梯度, 然后求平均值, 也可以一次只算16个样本的平均梯度, 缓存下来, 继续计算下面16个样本的平均梯度, 与之前缓存的数值累加, 等算够8次之后, 再把累加得到的值除以8, 再执行梯度更新, 这样得到的效果与直接计算128个样本是相同的, 但此时我们在一次计算中只需要计算16个样本, 这样就把batch size降到了16.
延续上例, 以batch size为16, 计算8次平均梯度, 执行一次参数更新, 也就是说前7次都的更新量是0, 第8次才是真正的参数更新, 因此要根据当前的迭代轮数选择执行路径.
但Tensorflow后台不存在只执行一个分支的条件写法, 使用switch
, case
, cond
等方法实际上也执行了每个条件分支, 只是最后得到的结果根据条件从各个分支中选择.
我们声明一个条件矩阵cond(这里是一个条件标量)来显式地控制每一步迭代中参数的更新. 假设我们需要的累积迭代的数量为grad_accum_steps, 在迭代数能被这个数整除时, 说明我们需要更新参数了, 其他情况下参数不变. 因此cond
的计算方法为:
其中的self.iterations
指迭代步数. 这样得到的cond
在要更新参数的一步为1, 其他步时为0.
然后需要初始化一个缓存累积梯度的列表, 列表的长度等于所有可训练tensor的数量, 存储每一个参数tensor在每次迭代中累积的梯度, 记为accum_grads. 所有元素初始化为值为0, 大小与对应参数一致的tensor:
考虑如何更新这个缓存. 假设当前步样本计算得到的平均梯度为grads, 那么在更新参数的这一步, 我们也需要将累积梯度清零, 然后重新累积本轮梯度; 其他步继续累积即可:
最后考虑更新参数. Keras框架中参数, 或中间缓存的更新都是用K.update
方法实现的. 由于累积, 需要将当前累积的梯度除以累积的步数, 才是真正的平均梯度用来更新. 又因为实际上我们只在指定的累积迭代步数grad_accum_steps
的整数倍时执行更新, 因此update
方法要更换为:
其中x
要更新的参数, 将new_x
值写入到x
中实现更新. 在计算得到参数新值new_x
时, 使用到的梯度为:
使用梯度累积的前提是, 模型不包含Batch Normalization. 因为Batch Normalization在梯度下降的时候必须用整个batch的均值方差, 但每次迭代只有整个batch的部分样本, 不一致.
如果网络中用到了Batch Normalization, 就不能使用梯度累积. 如果要增大batch size, 就只能扩大显存.
详细的实现参考中extend_with_gradient_accumulation
函数.