Multi-Head-Attention的时间复杂度

Bert中的self-attention使用的是Multi-Head Attention.

记序列的长度为nn, 每个token向量对应的hidden sizedd. 整个Multi-Head Attention的计算过程如下.

生成Q, K, V

Self Attention的输入为一个(n,d)(n, d)大小的矩阵. 这个矩阵需要分别经过三个不同的大小为(d,d)(d, d)的Dense, 生成Q, K, V矩阵, 对应的大小都是(n,d)(n, d).

做三次(n,d)(n, d)大小的矩阵与(d,d)(d, d)大小的矩阵相乘, 这一步的时间复杂度为O(nd2)O(nd^2).

计算相似度矩阵

普通的Attention中, Q与K两个矩阵相乘, 得到相似度矩阵. 两者的大小都是(n,d)(n, d), 矩阵相乘得到(n,n)(n, n)大小的相似度矩阵, 对应的时间复杂度为O(n2d)O(n^2d).

但Multi-Head Attention要先把Q, K, V分割为mmhead, 每个head中的hidden size记为a=d/ma=d/m. 然后每个head内部进行Attention的计算, 不同head之间互不干扰.

因此考虑上head, Q和K矩阵需要转换为(n,m,a)(n, m, a)大小的tensor, 点积得到(n,n,m)(n, n, m)大小的相似度矩阵, 因此这一步的时间复杂度为O(n2ma)=O(n2d)O(n^2ma)=O(n^2d).

Softmax计算

对得到的按head划分的, 大小为(n,n,m)(n, n, m)的相似度矩阵计算softmax, 这一步的时间复杂度为O(n2m)O(n^2m).

计算加权和

(n,n,m)(n, n, m)大小的权值矩阵与(n,m,a)(n, m, a)大小的O点乘, 得到O(n,m,a)O(n, m, a)大小的输入. 这一步沿着nn所在的维度进行点乘, 所用的时间复杂度为O(n2ma)=O(n2d)O(n^2ma)=O(n^2d).

再经过reshape变成最后的输出, 大小仍然为(n,d)(n, d).

总结

  • 生成Q, K, V的时间复杂度为O(nd2)O(nd^2)

  • 计算相似度矩阵的时间复杂度为O(n2md)O(n^2md)

  • Softamax计算的时间复杂度为O(n2m)O(n^2m)

  • 计算加权和的时间复杂度为O(n2d)O(n^2d)

整体的时间复杂度为O(n2d+nd2)O(n^2d+nd^2), 即序列长度与hidden size更大的一个占主导.

最后更新于

这有帮助吗?