符号标记
$Q, K, V\in \mathbb{R}^{n\times d_{model}}$
$W^Q, W^K, W^V, W^O \in \mathbb{R}^{d_{model} \times d_{model}}$, 可以将 $W^Q, W^K, W^V$ 分为 $h$ 个头, 即
$W_i^Q, W_i^K, W_i^V \in \mathbb{R}^{d_{model} \times d_k}$, 其中 $d_k = d_{model} / h$, $i=1,2,...,h$.
这里只考虑矩阵乘的计算量.
多头投影步骤计算量
$QW_i^Q$, $KW_i^K$, $VW_i^V$ 的计算量均为 $n \cdot d_{model} \cdot d_{k}$.
共 $h$ 个头, 所以该步骤总计算量为 $3 \cdot h \cdot n \cdot d_{model} \cdot d_{k}$, 因为 $d_k = d_{model} / h$, 所以计算量可以转化为 $3 \cdot n \cdot d_{model}^2$.
注意力计算步骤计算量
令 $Q'_i=QW_i^Q, K'_i=KW_i^K, V'_i=VW_i^V$, 它们的尺寸均为 $n\times d_{k}$.
$\mathrm{attn}_i = \mathrm{Attention}(Q'_i, K'_i, V'_i)$ (其尺寸也是 $n\times d_{k}$) 的计算量为 $2 \cdot n^2 \cdot d_k$, 因为
$Q'_iK'^T_i$ 的计算量为 $n^2 \cdot d_k$, 其经过 softmax 等操作后乘上 $V'_i$ 的计算量也为 $n^2 \cdot d_k$.
共 $h$ 个头, 所以该步骤的总计算量为 $2 \cdot h \cdot n^2 \cdot d_k$, 即 $2 \cdot n^2 \cdot d_{model}$.
输出投影步骤计算量
令 $S = \mathrm{Concat}(\mathrm{attn}_1, ..., \mathrm{attn}_h)$, 其尺寸为 $n\times hd_{k}$, 即 $n\times d_{model}$.
该步骤 $S W_O$ 的计算量为 $n \cdot d_{model}^2$.
总结
综上多头自注意力的 (所有矩阵乘运算的) 计算量为 $4nd_{model}^2 + 2n^2d_{model}$, 计算复杂度为 $O(nd_{model}^2 + n^2d_{model})$. 可见多头注意力的时间复杂度与 $n$ (序列长度) 和 $d_{model}$ (特征维度) 的平方成正比, 与头的个数或者每个头的特征维度无关.
对于短序列 ($n < d_{model}$): 计算复杂度为 $O(nd_{model}^2)$ 主导.
对于长序列 ($n > d_{model}$): 计算复杂度为 $O(n^2d_{model})$ 主导. 这种情况在实际应用中较为常见.