重计算技巧
前向重计算
重计算的是用时间换空间的思想, 在前向时只保存部分中间节点, 在反向时重新计算没保存的部分.
显存占用原因
网络的一次训练包含前向计算, 反向计算, 优化参数三个步骤.
在前向计算过程中, 每个Operation都会输出一个Tensor结果, 整个模型就会输出大量的隐层变量Tensor. 随着模型层数的加深, 这些Tensor在显存中累积会占用大量的显存.
下图是Bert large模型在一次训练迭代过程中显存使用的情况.

可以看到在前向计算过程中, 显存快速累积; 在反向计算过程中, 这些累积的Tensor又被快速消耗; 最后的参数的优化更新过程阶段, 显存一直处于一个低位.
那么为什么要在前向计算过程中累积这些Tensor呢? 这是因为在反向传播过程计算梯度时, 前向的结果Tensor会被使用到.
例如Dense层中的输入与参数的点乘运算, 其中的输入X就是上一层的输出:
Y=WX
对W求梯度, 得到的结果是:
dWdL=dYdLXT
可以看到要计算本层参数的梯度, 需要上一层输出的Tensor. 如下面的中间图, 完成前向计算每层的输出的计算后进入到反向传播阶段, 以相反的过程依次计算每层的梯度, 然后才能销掉使用到的上一层的前向输出. 因此看到上面图中前向显存累积, 后项逐渐消耗的过程.

解决方法
前向时每一层的得到的Tensor, 能否暂时抛弃掉, 在反向传播计算梯度需要时, 再重新生成呢? 这正是Training Deep Nets with Sublinear Memory Cost论文的思路.
重计算就是让每个训练迭代过程做两次前向计算. 核心思想是将前向计算分割成多个段, 每个段的起始Tensor称为这个段的检查点(checkpoints). 前向计算时, 除了检查点以外的其他隐层Tensor占有的显存可以及时释放. 反向计算用到这些隐层Tensor时, 从前一个检查点开始, 重新进行这个段的前向计算, 就可以重新获得隐层Tensor.
上图中最右侧的图中, 蓝色点对应的Tensor, 即relu-ff层的输出和sigmoid-ff层的输出作为检查点, 前向过程只保存这两个Tensor. 在反向传播计算第一个fc-bp时, 从第一个检查点relu-ff开始, 重新走一遍前向传播过程, 直到计算得到最后的relu-rc, 缓存这个过程中的结果Tensor, 供反向传播使用.
使用这种方法, 前向传播过程占用的显存就会大幅减小, 而在重新计算前向过程时, 只会计算两个检查点之间的结果, 产生的占用相对整个模型也会很小, 因此陡峭的显存上升直线变成了缓慢增长的Z形曲线, 如下图所示.

这样大幅缩小了训练过程中显存的使用上限, 进而可以显著增大模型训练的batch size.
理论上不使用重计算相当于每一层的输出都是检查点, 因此检查点越少, 显存再用的越少, 但同时消耗的实现越多, 因为带来了更多的前向重新计算量.
性能
在BERT Base版本下, batch_size可以增大为原来的3倍左右
在BERT Large版本下, batch_size可以增大为原来的4倍左右
平均每个样本的训练时间大约增加25%
理论上, 层数越多, batch_size可以增大的倍数越大
keras实现
参考bert4keras中的实现, 见其中的recompute_grad函数.
参考资料
最后更新于