AdaFactor

AdaFactor

从Adam优化器谈起

tt为迭代步数, αt\alpha_t为当前步学习率, L(θ)L(\theta)是损失函数, θ\theta是待优化参数, ϵ\epsilon是防止溢出的小正数.

Adam优化器的更新过程如下:

{gt=θL(θt)mt=β1mt1+(1β1)gtvt=β2vt1+(1β2)gt2m^t=mt/(1β1t)v^t=vt/(1β2t)θt=θt1αtm^t/v^t+ϵ\left\{\begin{aligned}&g_t = \nabla_{\theta} L(\theta_t)\\ &m_t = \beta_1 m_{t-1} + \left(1 - \beta_1\right) g_t\\ &v_t = \beta_2 v_{t-1} + \left(1 - \beta_2\right) g_t^2\\ &\hat{m}_t = m_t\left/\left(1 - \beta_1^t\right)\right.\\ &\hat{v}_t = v_t\left/\left(1 - \beta_2^t\right)\right.\\ &\theta_t = \theta_{t-1} - \alpha_t \hat{m}_t\left/\sqrt{\hat{v}_t + \epsilon}\right. \end{aligned}\right.

作为目前最常用的优化器, 有一个缺点是占用的显存较大. 要省显存, 就首先得知道显存花在哪里的. 首先最后计算得到的梯度是占用显存的, 且这部分是任何优化器都无法节省的.

除此之外, Adam优化器使用了一阶梯度mm和二阶梯度vv, 且在每一步使用滑动平均计算, 需要进行缓存, 这两部分也要占用缓存, 且各自占用的大小同上面的梯度一致.

AdaFactor节省显存的思路

抛弃动量

Adam性能优秀, 很重要的一个点是每一个参数都有自适应学习率, 从上面的公式中也可以看出:

θt=θt1αtm^t/v^t+ϵ\theta_t = \theta_{t-1} - \alpha_t \hat{m}_t / \sqrt{\hat{v}_t + \epsilon}

从上式中可以看出, 每个参数的自适应学习率为αt/v^t+ϵ\alpha_t\left/\sqrt{\hat{v}_t + \epsilon}\right., 即通过SGD+二阶动量来实现的.

因此作为节省缓存的第一步, 考虑直接抛弃一阶动量mm, 这样显存的占用直接节省了1/3.

低秩分解

然后继续尝试压缩二阶动量vv的大小. AdaFactor使用到了低秩分解.

Adam中每个参数都会有各自独立的学习率, 但SGD中所有的参数共用一个学习率, 且SGD在很多任务或数据集中也能取得不错的效果, 带来的一个思路是精调每一个参数自己的学习率不是特别重要, 因此启发我们将v^t\hat{v}_t换一种参数更少的近似可能也就足够了. 因此使用低秩分解来实现.

对于m×nm \times n大小的矩阵CC, 希望找到大小为m×km \times k的矩阵AA和大小为k×nk \times n的矩阵BB, 使得:

ABCAB \approx C

使用一个比较小的kk值, 这样矩阵AABB中的参数之和会远小于原来CC中参数的数量, 且仍能取得近似的效果, 这就是上面说的不再精调每个参数的学习率, 让参数之间共享部分信息.

AdaFactor中取k=1k=1, 将显存节省到了极致, 即寻找{ai}i=1m\{a_i\}_{i=1}^m{bj}j=1n\{b_j\}_{j=1}^n, 使得:

aibjci,ja_i b_j \approx c_{i,j}

为了达到近似的效果, 需要一个距离度量标准来进行约束, 容易想到欧式距离i,j(aibjci,j)2\sum_{i,j} (a_i b_j - c_{i,j})^2, 但这样ai,bja_i,b_j没有解析解, 且在优化过程中ci,jc_{i,j}, 即对应于更新中的二阶梯度v^t\hat{v}_t应当是非负的, 但通过上面的目标函数优化得到的ci,jc_{i,j}无法保证非负.

因此AdaFactor使用了新的度量标准, 广义KL散度, 形式为:

l=i,jci,jlogci,jaibjci,j+aibjl = \sum_{i,j} c_{i,j}\log \frac{c_{i,j}}{a_i b_j} - c_{i,j} + a_i b_j

这个度量标准来自不等式xlogxx1(x>0)x\log x\geq x - 1(\forall x > 0), 当且仅当x=1x=1时等号成立. 将x=p/q(p,q>0)x = p / q\,(p,q > 0)带入到不等式当中, 然后两端乘以qq, 则有:

plogpqp+q0p\log \frac{p}{q} - p + q \geq 0

当且仅当p=qp=q时, 等号成立. 将pp替换成ci,jc_{i,j}, qq替换成aibja_i b_j, 并且对所有的分量进行求和, 就得到了上面的广义KL散度的公式. 由于有取最小值的条件, 在将aia_i, bjb_j, ci,jc_{i,j}带入后, 刚好有解析解:

ai=jci,j,bj=ici,ji,jci,ja_i = \sum\limits_{j}c_{i,j},\quad b_j = \frac{\sum\limits_{i}c_{i,j}}{\sum\limits_{i,j}c_{i,j}}

解析解也很形象, 就是行, 列分别求和, 然后相乘, 再除以全体的和. 推导过程参考推导过程.

因此我们就可以维护两组缓存变量vt(r)Rm,vt(c)Rnv^{(r)}_t\in \mathbb{R}^m,v^{(c)}_t\in\mathbb{R}^n, 代表gt2g_t^2低秩分解后的结果, 解析解保证了vt(r)vt(c)v^{(r)}_tv^{(c)}_t点乘于原始二阶动量gt2g_t^2之间在广义KL散度度量下的最大近似性. 因此AdaFactor优化器的计算流程如下:

{gi,j;t=θL(θi,j;t)vi;t(r)=β2vt1;i(r)+(1β2)j(gi,j;t2+ϵ)vj;t(c)=β2vt1;j(c)+(1β2)i(gi,j;t2+ϵ)vi,j;t=vi;t(r)vj;t(c)/jvj;t(c)v^t=vt/(1β2t)θt=θt1αtgt/v^t\left\{\begin{aligned}&g_{i,j;t} = \nabla_{\theta} L(\theta_{i,j;t})\\ &v^{(r)}_{i;t} = \beta_2 v^{(r)}_{t-1;i} + \left(1 - \beta_2\right) \sum\limits_{j}\left(g_{i,j;t}^2+\epsilon\right)\\ &v^{(c)}_{j;t} = \beta_2 v^{(c)}_{t-1;j} + \left(1 - \beta_2\right) \sum\limits_{i}\left(g_{i,j;t}^2+\epsilon\right)\\ &v_{i,j;t} = v^{(r)}_{i;t} v^{(c)}_{j;t}\left/\sum\limits_{j}v^{(c)}_{j;t}\right.\\ &\hat{v}_t = v_t\left/\left(1 - \beta_2^t\right)\right.\\ &\theta_t = \theta_{t-1} - \alpha_t g_t\left/\sqrt{\hat{v}_t}\right. \end{aligned}\right.

vt(r)Rm,vt(c)Rnv^{(r)}_t\in \mathbb{R}^m,v^{(c)}_t\in\mathbb{R}^n两变量的更新逻辑, 保证了两变量的非负性, 使得不等式xlogxx1(x>0)x\log x\geq x - 1(\forall x > 0)始终满足条件, 解析解成立, 得到理论保证的近似效果, 进一步保证了优化的效果.

参考资料

最后更新于

这有帮助吗?