重计算技巧

前向重计算

重计算的是用时间换空间的思想, 在前向时只保存部分中间节点, 在反向时重新计算没保存的部分.

显存占用原因

网络的一次训练包含前向计算, 反向计算, 优化参数三个步骤.

前向计算过程中, 每个Operation都会输出一个Tensor结果, 整个模型就会输出大量的隐层变量Tensor. 随着模型层数的加深, 这些Tensor在显存中累积会占用大量的显存.

下图是Bert large模型在一次训练迭代过程中显存使用的情况.

可以看到在前向计算过程中, 显存快速累积; 在反向计算过程中, 这些累积的Tensor又被快速消耗; 最后的参数的优化更新过程阶段, 显存一直处于一个低位.

那么为什么要在前向计算过程中累积这些Tensor呢? 这是因为在反向传播过程计算梯度时, 前向的结果Tensor会被使用到.

例如Dense层中的输入与参数的点乘运算, 其中的输入XX就是上一层的输出:

Y=WXY=WX

WW求梯度, 得到的结果是:

dLdW=dLdYXT\frac{d L}{d W}=\frac{d L}{d Y} X^{T}

可以看到要计算本层参数的梯度, 需要上一层输出的Tensor. 如下面的中间图, 完成前向计算每层的输出的计算后进入到反向传播阶段, 以相反的过程依次计算每层的梯度, 然后才能销掉使用到的上一层的前向输出. 因此看到上面图中前向显存累积, 后项逐渐消耗的过程.

解决方法

前向时每一层的得到的Tensor, 能否暂时抛弃掉, 在反向传播计算梯度需要时, 再重新生成呢? 这正是Training Deep Nets with Sublinear Memory Cost论文的思路.

重计算就是让每个训练迭代过程做两次前向计算. 核心思想是将前向计算分割成多个段, 每个段的起始Tensor称为这个段的检查点(checkpoints). 前向计算时, 除了检查点以外的其他隐层Tensor占有的显存可以及时释放. 反向计算用到这些隐层Tensor时, 从前一个检查点开始, 重新进行这个段的前向计算, 就可以重新获得隐层Tensor.

上图中最右侧的图中, 蓝色点对应的Tensor, 即relu-ff层的输出和sigmoid-ff层的输出作为检查点, 前向过程只保存这两个Tensor. 在反向传播计算第一个fc-bp时, 从第一个检查点relu-ff开始, 重新走一遍前向传播过程, 直到计算得到最后的relu-rc, 缓存这个过程中的结果Tensor, 供反向传播使用.

使用这种方法, 前向传播过程占用的显存就会大幅减小, 而在重新计算前向过程时, 只会计算两个检查点之间的结果, 产生的占用相对整个模型也会很小, 因此陡峭的显存上升直线变成了缓慢增长的Z形曲线, 如下图所示.

这样大幅缩小了训练过程中显存的使用上限, 进而可以显著增大模型训练的batch size.

理论上不使用重计算相当于每一层的输出都是检查点, 因此检查点越少, 显存再用的越少, 但同时消耗的实现越多, 因为带来了更多的前向重新计算量.

性能

  • 在BERT Base版本下, batch_size可以增大为原来的3倍左右

  • 在BERT Large版本下, batch_size可以增大为原来的4倍左右

  • 平均每个样本的训练时间大约增加25%

  • 理论上, 层数越多, batch_size可以增大的倍数越大

keras实现

参考bert4keras中的实现, 见其中的recompute_grad函数.

参考资料

最后更新于

这有帮助吗?