Introduction
Through a mathematical derivation and code verification, this post shows that a LayerNorm layer followed by a Linear layer can be rewritten as a LayerNorm without learnable parameters followed by a new Linear layer. The transformation removes the affine (scale-and-shift) operation from the LayerNorm, which reduces both the parameter count and the amount of computation.
Main text
Let $X$ be an $N \times d$ matrix. Its layer normalization over the last dimension (the common setting in NLP) is
$$\operatorname{LayerNorm}(X) = \vec{1}_N\vec{\gamma}^T \circ (X - \vec{\mu}\vec{1}_d^T) \circ (\vec{\sigma}^{\circ{2}} \vec{1}_d^T+ \epsilon \mathbf{1}_{N\times d})^{\circ-1/2} + \vec{1}_N \vec{\beta}^T$$
where $\vec{\mu}=\frac{1}{d}X\vec{1}_d$ and $\vec{\sigma}^{\circ 2} = \frac{1}{d}(X - \vec{\mu}\vec{1}_d^T)^{\circ 2}\vec{1}_d$ are of size $N \times 1$, and $\vec{\gamma}$ and $\vec{\beta}$ are the learnable parameters of LayerNorm, of size $d \times 1$. Here $\circ$ denotes the element-wise (Hadamard) product, and a $\circ$ in an exponent denotes an element-wise power.
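As a quick sanity check, the matrix form above can be compared against a direct row-by-row normalization. The following is a minimal NumPy sketch (the variable names and the use of broadcasting in place of the explicit outer products are ours).
import numpy as np

N, d, eps = 4, 3, 1e-5
X = np.random.rand(N, d)
gamma = np.random.rand(d, 1)
beta = np.random.rand(d, 1)

# Matrix form: 1_N gamma^T ∘ (X - mu 1_d^T) ∘ (sigma^2 1_d^T + eps)^(-1/2) + 1_N beta^T,
# with the outer products realized via broadcasting.
mu = X.mean(axis=1, keepdims=True)                 # N x 1
var = ((X - mu) ** 2).mean(axis=1, keepdims=True)  # N x 1, biased variance
y1 = gamma.T * (X - mu) / np.sqrt(var + eps) + beta.T

# Row-by-row definition of LayerNorm for comparison.
y2 = np.empty_like(X)
for i in range(N):
    row = X[i]
    y2[i] = gamma[:, 0] * (row - row.mean()) / np.sqrt(row.var() + eps) + beta[:, 0]

print(np.allclose(y1, y2))  # prints True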
Using the identity $\vec{1}_N\vec{\gamma}^T\circ A = A \operatorname{diag}{\vec{\gamma}}$ (which can be derived directly, and is also verified by the NumPy snippet below), the LayerNorm formula can be rewritten as:
$$\operatorname{LayerNorm}(X) = \left((X - \vec{\mu}\vec{1}_d^T) \circ (\vec{\sigma}^{\circ{2}} \vec{1}_d^T+ \epsilon \mathbf{1}_{N\times d})^{\circ-1/2}\right) \operatorname{diag}{\vec{\gamma}}+ \vec{1}_N \vec{\beta}^T$$
Letting $\hat{X} = (X - \vec{\mu}\vec{1}_d^T) \circ (\vec{\sigma}^{\circ{2}} \vec{1}_d^T+ \epsilon \mathbf{1}_{N\times d})^{\circ-1/2}$, the expression above can be abbreviated as:
$$\operatorname{LayerNorm}(X) = \hat{X} \operatorname{diag}{\vec{\gamma}}+ \vec{1}_N \vec{\beta}^T$$
import numpy as np

N, d = 4, 3
a = np.random.rand(N, d)                   # A: an arbitrary N x d matrix
gamma = np.random.rand(d, 1)               # gamma: d x 1
ones = np.ones((N, 1))                     # 1_N
y1 = a * np.dot(ones, gamma.T)             # 1_N gamma^T ∘ A
y2 = np.dot(a, np.diag(gamma.flatten()))   # A diag(gamma)
print(np.allclose(y1, y2))                 # prints True
If the LayerNorm layer is followed by a Linear layer with weight $W$ (of size $e \times d$) and bias $\vec{b}$ (of size $e \times 1$), then:
$$\begin{aligned} \operatorname{Linear}({\operatorname{LayerNorm}(X)}) &= (\hat{X}\operatorname{diag}{\vec{\gamma}} + \vec{1}_N \vec{\beta}^T) W^T + \vec{1}_N\vec{b}^T \\ &= \hat{X}\operatorname{diag}{\vec{\gamma}} W^T + \vec{1}_N \vec{\beta}^T W^T +\vec{1}_N\vec{b}^T \\ &= \hat{X} (W\operatorname{diag}{\vec{\gamma}} ) ^T + \vec{1}_N (W \vec{\beta} + \vec{b})^T \end{aligned}$$
The last step uses the fact that $\operatorname{diag}{\vec{\gamma}}$ is symmetric, so $\operatorname{diag}{\vec{\gamma}}\,W^T = (W\operatorname{diag}{\vec{\gamma}})^T$. The result shows that a LayerNorm layer followed by a Linear layer can be rewritten as a LayerNorm without learnable parameters followed by a new Linear layer whose weight is $W\operatorname{diag}{\vec{\gamma}}$ and whose bias is $W \vec{\beta} + \vec{b}$. The code below verifies this as well.
import torch
import torch.nn as nn


class PreNormLinear(nn.Module):
    """LayerNorm followed by a Linear layer."""

    def __init__(self, in_features, out_features, eps: float = 1e-5,
                 elementwise_affine=True, device=None, dtype=None):
        super(PreNormLinear, self).__init__()
        self.factory_kwargs = {'device': device, 'dtype': dtype}
        self.in_features = in_features
        self.out_features = out_features
        self.norm = nn.LayerNorm(in_features, eps=eps,
                                 elementwise_affine=elementwise_affine, **self.factory_kwargs)
        self.linear = nn.Linear(in_features, out_features, bias=True, **self.factory_kwargs)

    def forward(self, x):
        x = self.norm(x)
        x = self.linear(x)
        return x


if __name__ == '__main__':
    batch, sentence_length, embedding_dim, out_features = 20, 5, 10, 128
    # layer1: LayerNorm with affine parameters; layer2: LayerNorm without them.
    layer1 = PreNormLinear(embedding_dim, out_features)
    layer2 = PreNormLinear(embedding_dim, out_features, elementwise_affine=False)
    # Randomize gamma, beta, W, b so the comparison is non-trivial.
    nn.init.normal_(layer1.norm.weight, 0, 1)
    nn.init.normal_(layer1.norm.bias, 0, 1)
    nn.init.normal_(layer1.linear.weight, 0, 1)
    nn.init.normal_(layer1.linear.bias, 0, 1)
    # New weight W diag(gamma) and new bias W beta + b.
    layer2.linear.weight = nn.Parameter(layer1.linear.weight @ torch.diag(layer1.norm.weight))
    layer2.linear.bias = nn.Parameter(layer1.linear.weight @ layer1.norm.bias + layer1.linear.bias)
    embedding = torch.randn(batch, sentence_length, embedding_dim)
    output1 = layer1(embedding)
    output2 = layer2(embedding)
    print(torch.allclose(output1, output2, atol=1e-5))  # prints True
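As a practical aside, the same substitution can be wrapped in a small helper that fuses an already-initialized (or trained) LayerNorm/Linear pair. The sketch below is ours, not a PyTorch API; the name fuse_layernorm_linear is made up, and it assumes the LayerNorm has elementwise affine parameters and the Linear has a bias.
import torch
import torch.nn as nn


@torch.no_grad()
def fuse_layernorm_linear(norm: nn.LayerNorm, linear: nn.Linear):
    # Returns (parameter-free LayerNorm, new Linear) such that
    # fused_linear(fused_norm(x)) == linear(norm(x)).
    fused_norm = nn.LayerNorm(norm.normalized_shape, eps=norm.eps,
                              elementwise_affine=False)
    fused_linear = nn.Linear(linear.in_features, linear.out_features, bias=True)
    fused_linear.weight.copy_(linear.weight * norm.weight)            # W' = W diag(gamma): scale each column of W by gamma
    fused_linear.bias.copy_(linear.weight @ norm.bias + linear.bias)  # b' = W beta + b
    return fused_norm, fused_linear


norm, linear = nn.LayerNorm(10), nn.Linear(10, 128)
nn.init.normal_(norm.weight)  # make gamma, beta non-trivial so the check is meaningful
nn.init.normal_(norm.bias)
fused_norm, fused_linear = fuse_layernorm_linear(norm, linear)
x = torch.randn(20, 5, 10)
print(torch.allclose(linear(norm(x)), fused_linear(fused_norm(x)), atol=1e-5))  # prints True
After fusion, $\vec{\gamma}$ and $\vec{\beta}$ (and the corresponding multiply-add per element) disappear from the normalization step, which is exactly the saving claimed in the introduction.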
Revision history
- 20240415 Derived the formulas and published
- 20240416 Updated parts of the content
Copyright notice
Attribution-NonCommercial-ShareAlike 4.0 International License