最后更新于
最后更新于
重计算的是用时间换空间的思想, 在前向时只保存部分中间节点, 在反向时重新计算没保存的部分.
网络的一次训练包含前向计算, 反向计算, 优化参数三个步骤.
在前向计算过程中, 每个Operation
都会输出一个Tensor结果, 整个模型就会输出大量的隐层变量Tensor. 随着模型层数的加深, 这些Tensor在显存中累积会占用大量的显存.
下图是Bert large模型在一次训练迭代过程中显存使用的情况.
可以看到在前向计算过程中, 显存快速累积; 在反向计算过程中, 这些累积的Tensor又被快速消耗; 最后的参数的优化更新过程阶段, 显存一直处于一个低位.
那么为什么要在前向计算过程中累积这些Tensor呢? 这是因为在反向传播过程计算梯度时, 前向的结果Tensor会被使用到.
可以看到要计算本层参数的梯度, 需要上一层输出的Tensor. 如下面的中间图, 完成前向计算每层的输出的计算后进入到反向传播阶段, 以相反的过程依次计算每层的梯度, 然后才能销掉使用到的上一层的前向输出. 因此看到上面图中前向显存累积, 后项逐渐消耗的过程.
重计算就是让每个训练迭代过程做两次前向计算. 核心思想是将前向计算分割成多个段, 每个段的起始Tensor称为这个段的检查点(checkpoints). 前向计算时, 除了检查点以外的其他隐层Tensor占有的显存可以及时释放. 反向计算用到这些隐层Tensor时, 从前一个检查点开始, 重新进行这个段的前向计算, 就可以重新获得隐层Tensor.
上图中最右侧的图中, 蓝色点对应的Tensor, 即relu-ff
层的输出和sigmoid-ff
层的输出作为检查点, 前向过程只保存这两个Tensor. 在反向传播计算第一个fc-bp
时, 从第一个检查点relu-ff
开始, 重新走一遍前向传播过程, 直到计算得到最后的relu-rc
, 缓存这个过程中的结果Tensor, 供反向传播使用.
使用这种方法, 前向传播过程占用的显存就会大幅减小, 而在重新计算前向过程时, 只会计算两个检查点之间的结果, 产生的占用相对整个模型也会很小, 因此陡峭的显存上升直线变成了缓慢增长的Z形曲线, 如下图所示.
这样大幅缩小了训练过程中显存的使用上限, 进而可以显著增大模型训练的batch size.
理论上不使用重计算相当于每一层的输出都是检查点, 因此检查点越少, 显存再用的越少, 但同时消耗的实现越多, 因为带来了更多的前向重新计算量.
在BERT Base版本下, batch_size可以增大为原来的3倍左右
在BERT Large版本下, batch_size可以增大为原来的4倍左右
平均每个样本的训练时间大约增加25%
理论上, 层数越多, batch_size可以增大的倍数越大
例如Dense层中的输入与参数的点乘运算, 其中的输入就是上一层的输出:
对求梯度, 得到的结果是:
前向时每一层的得到的Tensor, 能否暂时抛弃掉, 在反向传播计算梯度需要时, 再重新生成呢? 这正是论文的思路.
参考中的实现, 见其中的recompute_grad
函数.