多头自注意力的计算量

符号标记

$Q, K, V\in \mathbb{R}^{n\times d_{model}}$

$W^Q, W^K, W^V, W^O \in \mathbb{R}^{d_{model} \times d_{model}}$, 可以将 $W^Q, W^K, W^V$ 分为 $h$ 个头, 即

$W_i^Q, W_i^K, W_i^V \in \mathbb{R}^{d_{model} \times d_k}$, 其中 $d_k = d_{model} / h$, $i=1,2,...,h$.

这里只考虑矩阵乘的计算量.

多头投影步骤计算量

$QW_i^Q$, $KW_i^K$, $VW_i^V$ 的计算量均为 $n \cdot d_{model} \cdot d_{k}$.

共 $h$ 个头, 所以该步骤总计算量为 $3 \cdot h \cdot n \cdot d_{model} \cdot d_{k}$, 因为 $d_k = d_{model} / h$, 所以计算量可以转化为 $3 \cdot n \cdot d_{model}^2$.

注意力计算步骤计算量

令 $Q'_i=QW_i^Q, K'_i=KW_i^K, V'_i=VW_i^V$, 它们的尺寸均为 $n\times d_{k}$.

$\mathrm{attn}_i = \mathrm{Attention}(Q'_i, K'_i, V'_i)$ (其尺寸也是 $n\times d_{k}$) 的计算量为 $2 \cdot n^2 \cdot d_k$, 因为
$Q'_iK'^T_i$ 的计算量为 $n^2 \cdot d_k$, 其经过 softmax 等操作后乘上 $V'_i$ 的计算量也为 $n^2 \cdot d_k$.

共 $h$ 个头, 所以该步骤的总计算量为 $2 \cdot h \cdot n^2 \cdot d_k$, 即 $2 \cdot n^2 \cdot d_{model}$.

输出投影步骤计算量

令 $S = \mathrm{Concat}(\mathrm{attn}_1, ..., \mathrm{attn}_h)$, 其尺寸为 $n\times hd_{k}$, 即 $n\times d_{model}$.

该步骤 $S W_O$ 的计算量为 $n \cdot d_{model}^2$.

总结

综上多头自注意力的 (所有矩阵乘运算的) 计算量为 $4nd_{model}^2 + 2n^2d_{model}$, 计算复杂度为 $O(nd_{model}^2 + n^2d_{model})$. 可见多头注意力的时间复杂度与 $n$ (序列长度) 和 $d_{model}$ (特征维度) 的平方成正比, 与头的个数或者每个头的特征维度无关.

对于短序列 ($n < d_{model}$): 计算复杂度为 $O(nd_{model}^2)$ 主导.

对于长序列 ($n > d_{model}$): 计算复杂度为 $O(n^2d_{model})$ 主导. 这种情况在实际应用中较为常见.

参考

版权归属: 采石工
本文链接: https://quarryman.cn/article/20250212
版权声明: 除特别声明外, 文章采用《署名-非商业性使用-相同方式共享 4.0 国际》许可协议.