Transforming a LayerNorm Layer Followed by a Linear Layer

Introduction

Through mathematical derivation and code verification, this article shows that a LayerNorm layer followed by a Linear layer can be transformed into a LayerNorm without learnable parameters followed by a new Linear layer. This transformation removes the affine operation inside LayerNorm, thereby reducing both the number of parameters and the amount of computation.

Main Text

Let $X$ be an $N \times d$ matrix. Its layer normalization over the last dimension (a common operation in NLP) is

$$\operatorname{LayerNorm}(X) = \vec{1}_N\vec{\gamma}^T \circ (X - \vec{\mu}\vec{1}_d^T) \circ (\vec{\sigma}^{\circ{2}} \vec{1}_d^T+ \epsilon \mathbf{1}_{N\times d})^{\circ-1/2} + \vec{1}_N \vec{\beta}^T$$

where $\vec{\mu}=\frac{1}{d}X\vec{1}_d$ and $\vec{\sigma}^{\circ 2} = \frac{1}{d}(X - \vec{\mu}\vec{1}_d^T)^{\circ 2}\vec{1}_d$ have size $N \times 1$; $\vec{\gamma}$ and $\vec{\beta}$ are the learnable parameters of LayerNorm, with size $d \times 1$.
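
The matrix form above can be checked numerically. The sketch below is illustrative only (the sizes and $\epsilon$ value are arbitrary assumptions, not from the original): it implements the formula literally in NumPy and compares it against a straightforward row-by-row normalization.

import numpy as np

N, d, eps = 4, 3, 1e-5
X = np.random.rand(N, d)
gamma = np.random.rand(d, 1)
beta = np.random.rand(d, 1)
ones_N = np.ones((N, 1))
ones_d = np.ones((d, 1))

# Matrix form: mu and sigma^2 are N x 1 column vectors of per-row statistics
mu = X @ ones_d / d
var = ((X - mu @ ones_d.T) ** 2) @ ones_d / d
ln_matrix = (ones_N @ gamma.T) * (X - mu @ ones_d.T) * (var @ ones_d.T + eps) ** -0.5 \
            + ones_N @ beta.T

# Row-by-row reference implementation
ln_ref = np.empty_like(X)
for i in range(N):
    row = X[i]
    ln_ref[i] = gamma.ravel() * (row - row.mean()) / np.sqrt(row.var() + eps) + beta.ravel()

print(np.allclose(ln_matrix, ln_ref))  # expected: True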

Using the identity $\vec{1}_N\vec{\gamma}^T\circ A = A \operatorname{diag}{\vec{\gamma}}$ (which can be derived directly, and is also verified by the code below), the LayerNorm formula can be rewritten as:

$$\operatorname{LayerNorm}(X) = \left((X - \vec{\mu}\vec{1}_d^T) \circ (\vec{\sigma}^{\circ{2}} \vec{1}_d^T+ \epsilon \mathbf{1}_{N\times d})^{\circ-1/2}\right) \operatorname{diag}{\vec{\gamma}}+ \vec{1}_N \vec{\beta}^T$$

Writing $\hat{X} = (X - \vec{\mu}\vec{1}_d^T) \circ (\vec{\sigma}^{\circ{2}} \vec{1}_d^T+ \epsilon \mathbf{1}_{N\times d})^{\circ-1/2}$, the expression above can be abbreviated as:

$$\operatorname{LayerNorm}(X) = \hat{X} \operatorname{diag}{\vec{\gamma}}+ \vec{1}_N \vec{\beta}^T$$

import numpy as np

# Verify the identity 1_N gamma^T ∘ A = A diag(gamma)
N, d = 4, 3
a = np.random.rand(N, d)
gamma = np.random.rand(d, 1)
ones = np.ones((N, 1))
y1 = a * np.dot(ones, gamma.T)            # Hadamard product with row-broadcast gamma^T
y2 = np.dot(a, np.diag(gamma.flatten()))  # right-multiplication by diag(gamma)
print(np.allclose(y1, y2))                # expected: True

If the LayerNorm layer is followed by a Linear layer with weight $W$ (of size $e \times d$) and bias $\vec{b}$ (of size $e \times 1$), then:

$$\begin{aligned} \operatorname{Linear}({\operatorname{LayerNorm}(X)}) &= (\hat{X}\operatorname{diag}{\vec{\gamma}} + \vec{1}_N \vec{\beta}^T) W^T + \vec{1}_N\vec{b}^T \\ &= \hat{X}\operatorname{diag}{\vec{\gamma}} W^T + \vec{1}_N \vec{\beta}^T W^T +\vec{1}_N\vec{b}^T \\ &= \hat{X} (W\operatorname{diag}{\vec{\gamma}} ) ^T + \vec{1}_N (W \vec{\beta} + \vec{b})^T \end{aligned}$$

The above shows that a LayerNorm layer followed by a Linear layer can be transformed into a LayerNorm without learnable parameters followed by a new Linear layer, whose weight is $W\operatorname{diag}{\vec{\gamma}}$ and whose bias is $W \vec{\beta} + \vec{b}$. The PyTorch code further below verifies this as well.
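
Before wiring the result into PyTorch modules, the three-line algebra above can also be checked on raw matrices. The following NumPy sketch (with arbitrarily chosen sizes, an assumption for illustration) compares the left-hand side with the folded right-hand side.

import numpy as np

N, d, e = 4, 3, 5
X_hat = np.random.rand(N, d)  # stands in for the normalized input X_hat
gamma = np.random.rand(d)
beta = np.random.rand(d)
W = np.random.rand(e, d)
b = np.random.rand(e)
ones_N = np.ones((N, 1))

# Left-hand side: Linear applied to the affine LayerNorm output
lhs = (X_hat @ np.diag(gamma) + ones_N @ beta.reshape(1, -1)) @ W.T + ones_N @ b.reshape(1, -1)
# Right-hand side: new Linear with folded weight W diag(gamma) and bias W beta + b
rhs = X_hat @ (W @ np.diag(gamma)).T + ones_N @ (W @ beta + b).reshape(1, -1)
print(np.allclose(lhs, rhs))  # expected: True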

import torch
import torch.nn as nn


class PreNormLinear(nn.Module):
    def __init__(self, in_features, out_features, eps: float = 1e-5, 
                 elementwise_affine=True, device=None, dtype=None):
        super(PreNormLinear, self).__init__()
        self.factory_kwargs = {'device': device, 'dtype': dtype}
        self.in_features = in_features
        self.out_features = out_features
        self.norm = nn.LayerNorm(in_features, eps=eps, 
                                 elementwise_affine=elementwise_affine, **self.factory_kwargs)
        self.linear = nn.Linear(in_features, out_features, bias=True, **self.factory_kwargs)

    def forward(self, x):
        x = self.norm(x)
        x = self.linear(x)
        return x


if __name__ == '__main__':
    batch, sentence_length, embedding_dim, out_features = 20, 5, 10, 128
    layer1 = PreNormLinear(embedding_dim, out_features)
    layer2 = PreNormLinear(embedding_dim, out_features, elementwise_affine=False)

    # Randomly initialize layer1's affine LayerNorm and Linear parameters
    nn.init.normal_(layer1.norm.weight, 0, 1)
    nn.init.normal_(layer1.norm.bias, 0, 1)
    nn.init.normal_(layer1.linear.weight, 0, 1)
    nn.init.normal_(layer1.linear.bias, 0, 1)
    # Fold the affine parameters into layer2's Linear:
    # new weight = W diag(gamma), new bias = W beta + b
    layer2.linear.weight = nn.Parameter(layer1.linear.weight @ torch.diag(layer1.norm.weight))
    layer2.linear.bias = nn.Parameter(layer1.linear.weight @ layer1.norm.bias + layer1.linear.bias)

    embedding = torch.randn(batch, sentence_length, embedding_dim)
    output1 = layer1(embedding)
    output2 = layer2(embedding)
    print(torch.allclose(output1, output2, atol=1e-5))  # expected: True
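
In practice, the same folding can be packaged as a small utility that takes a trained nn.LayerNorm / nn.Linear pair and returns the parameter-free equivalent. The sketch below is one possible way to do this; the name fold_layernorm_linear and its exact interface are assumptions for illustration, not from the original article, and it assumes normalization over the last dimension only (1-D $\vec{\gamma}$ and $\vec{\beta}$).

import torch
import torch.nn as nn


def fold_layernorm_linear(norm: nn.LayerNorm, linear: nn.Linear):
    """Fold an affine LayerNorm into the Linear layer that follows it.

    Returns a parameter-free LayerNorm and a new Linear whose weight is
    W diag(gamma) and whose bias is W beta + b, per the derivation above.
    """
    assert norm.elementwise_affine, "expected a LayerNorm with affine parameters"
    new_norm = nn.LayerNorm(norm.normalized_shape, eps=norm.eps, elementwise_affine=False)
    new_linear = nn.Linear(linear.in_features, linear.out_features, bias=True)
    with torch.no_grad():
        # new weight = W diag(gamma); requires a 1-D gamma (last-dimension LayerNorm)
        new_linear.weight.copy_(linear.weight @ torch.diag(norm.weight))
        # new bias = W beta + b (or just W beta if the original Linear has no bias)
        folded_bias = linear.weight @ norm.bias
        if linear.bias is not None:
            folded_bias = folded_bias + linear.bias
        new_linear.bias.copy_(folded_bias)
    return new_norm, new_linear

Applied to layer1 from the script above, the returned pair should reproduce layer1's outputs within floating-point tolerance, mirroring the check performed by the script.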

Revision History

  • 20240415 Derived the formulas and published the article
  • 20240416 Updated parts of the content

Copyright Notice

Attribution-NonCommercial-ShareAlike 4.0 International License

Copyright: 采石工
Article link: https://quarryman.cn/article/20240415
Copyright notice: Unless otherwise stated, articles are licensed under the Attribution-NonCommercial-ShareAlike 4.0 International license.