Multi-Head-Attention的时间复杂度
Bert中的self-attention使用的是Multi-Head Attention.
记序列的长度为n, 每个token向量对应的hidden size为d. 整个Multi-Head Attention的计算过程如下.
生成Q, K, V
Self Attention的输入为一个(n,d)大小的矩阵. 这个矩阵需要分别经过三个不同的大小为(d,d)的Dense, 生成Q, K, V矩阵, 对应的大小都是(n,d).
做三次(n,d)大小的矩阵与(d,d)大小的矩阵相乘, 这一步的时间复杂度为O(nd2).
计算相似度矩阵
普通的Attention中, Q与K两个矩阵相乘, 得到相似度矩阵. 两者的大小都是(n,d), 矩阵相乘得到(n,n)大小的相似度矩阵, 对应的时间复杂度为O(n2d).
但Multi-Head Attention要先把Q, K, V分割为m个head, 每个head中的hidden size记为a=d/m. 然后每个head内部进行Attention的计算, 不同head之间互不干扰.
因此考虑上head, Q和K矩阵需要转换为(n,m,a)大小的tensor, 点积得到(n,n,m)大小的相似度矩阵, 因此这一步的时间复杂度为O(n2ma)=O(n2d).
Softmax计算
对得到的按head划分的, 大小为(n,n,m)的相似度矩阵计算softmax, 这一步的时间复杂度为O(n2m).
计算加权和
将(n,n,m)大小的权值矩阵与(n,m,a)大小的O点乘, 得到O(n,m,a)大小的输入. 这一步沿着n所在的维度进行点乘, 所用的时间复杂度为O(n2ma)=O(n2d).
再经过reshape变成最后的输出, 大小仍然为(n,d).
总结
生成Q, K, V的时间复杂度为O(nd2)
计算相似度矩阵的时间复杂度为O(n2md)
Softamax计算的时间复杂度为O(n2m)
计算加权和的时间复杂度为O(n2d)
整体的时间复杂度为O(n2d+nd2), 即序列长度与hidden size更大的一个占主导.
最后更新于