# 重计算技巧

## 前向重计算

重计算的是**用时间换空间**的思想, **在前向时只保存部分中间节点, 在反向时重新计算没保存的部分**.

### 显存占用原因

网络的一次训练包含**前向计算**, **反向计算**, **优化参数**三个步骤.

在**前向计算**过程中, 每个`Operation`都会输出一个Tensor结果, 整个模型就会输出大量的隐层变量Tensor. 随着模型层数的加深, 这些Tensor在显存中累积会占用大量的显存.

下图是Bert large模型在一次训练迭代过程中显存使用的情况.

![](https://1942165044-files.gitbook.io/~/files/v0/b/gitbook-legacy-files/o/assets%2F-MI7KRyeBH5dlW-CkUtn%2Fsync%2Fa1149c07dcfdcb43a4f7d4a6081ee0bffbaf5748.jpg?generation=1611708012143888\&alt=media)

可以看到在前向计算过程中, 显存快速累积; 在反向计算过程中, 这些累积的Tensor又被快速消耗; 最后的参数的优化更新过程阶段, 显存一直处于一个低位.

那么为什么要在前向计算过程中累积这些Tensor呢? 这是因为**在反向传播过程计算梯度时, 前向的结果Tensor会被使用到**.

例如Dense层中的输入与参数的点乘运算, 其中的输入$$X$$就是上一层的输出:

$$Y=WX$$

对$$W$$求梯度, 得到的结果是:

$$\frac{d L}{d W}=\frac{d L}{d Y} X^{T}$$

可以看到要计算本层参数的梯度, 需要上一层输出的Tensor. 如下面的中间图, 完成前向计算每层的输出的计算后进入到反向传播阶段, 以相反的过程依次计算每层的梯度, 然后才能销掉使用到的上一层的前向输出. 因此看到上面图中前向显存累积, 后项逐渐消耗的过程.

![](https://1942165044-files.gitbook.io/~/files/v0/b/gitbook-legacy-files/o/assets%2F-MI7KRyeBH5dlW-CkUtn%2Fsync%2Fa97c2ea5d3d992a68a7023080312b083bd374655.jpg?generation=1611708011747983\&alt=media)

### 解决方法

前向时每一层的得到的Tensor, 能否暂时抛弃掉, 在反向传播计算梯度需要时, **再重新生成**呢? 这正是[Training Deep Nets with Sublinear Memory Cost](https://arxiv.org/abs/1604.06174)论文的思路.

**重计算**就是**让每个训练迭代过程做两次前向计算**. 核心思想是**将前向计算分割成多个段**, 每个段的起始Tensor称为这个段的检查点(checkpoints). 前向计算时, **除了检查点以外的其他隐层Tensor占有的显存可以及时释放**. 反向计算用到这些隐层Tensor时, 从前一个检查点开始, 重新进行这个段的前向计算, 就可以重新获得隐层Tensor.

上图中最右侧的图中, 蓝色点对应的Tensor, 即`relu-ff`层的输出和`sigmoid-ff`层的输出作为检查点, 前向过程只保存这两个Tensor. 在反向传播计算第一个`fc-bp`时, 从第一个检查点`relu-ff`开始, 重新走一遍前向传播过程, 直到计算得到最后的`relu-rc`, 缓存这个过程中的结果Tensor, 供反向传播使用.

使用这种方法, 前向传播过程占用的显存就会大幅减小, 而在重新计算前向过程时, 只会计算两个检查点之间的结果, 产生的占用相对整个模型也会很小, 因此**陡峭的显存上升直线变成了缓慢增长的Z形曲线**, 如下图所示.

![](https://1942165044-files.gitbook.io/~/files/v0/b/gitbook-legacy-files/o/assets%2F-MI7KRyeBH5dlW-CkUtn%2Fsync%2F66fbac719dd4455590e18d2ff63a0059d8486f8f.jpg?generation=1611708011665271\&alt=media)

这样大幅缩小了训练过程中显存的使用上限, 进而可以显著增大模型训练的batch size.

理论上不使用重计算相当于每一层的输出都是检查点, 因此检查点越少, 显存再用的越少, 但同时消耗的实现越多, 因为带来了更多的前向重新计算量.

### 性能

* 在BERT Base版本下, batch\_size可以增大为原来的3倍左右
* 在BERT Large版本下, batch\_size可以增大为原来的4倍左右
* 平均每个样本的训练时间大约增加25%
* 理论上, 层数越多, batch\_size可以增大的倍数越大

### keras实现

参考[bert4keras](https://github.com/bojone/bert4keras/blob/master/bert4keras/backend.py)中的实现, 见其中的`recompute_grad`函数.

## 参考资料

* [Training Deep Nets with Sublinear Memory Cost](https://arxiv.org/abs/1604.06174)
* [BERT重计算：用22.5%的训练时间节省5倍的显存开销](https://mp.weixin.qq.com/s/CmIVwGFqrSD0wcSN_hgH1A)
* [节省显存的重计算技巧也有了Keras版了](https://kexue.fm/archives/7367)
