BatchNorm

BatchNorm一般是指BatchNorm2d,因为在Pytorch中另外还有BatchNorm3d,是针对视频或者3D扫描图像等,也就是多了一个维度,这里只考虑2d的情况,即输入的X的shape为(batch_size, channel, height, width)。

def BatchNorm2d(X, eps = 1e-6):
    # X shape (batch_size, channel, height, width)
    alpha = nn.Parameter(torch.ones(X.shape[1])) # (channel,)
    beta = nn.Parameter(torch.zeros(X.shape[1])) # (channel,)

    mean = torch.mean(X, dim=[0,2,3], keepdim = True) # (batch_size, 1, height, width)
    var = torch.var(X, dim=[0,2,3], keepdim = True) # (batch_size, 1, height, width)

    output = (X - mean)/torch.sqrt(var + eps) * alpha.reshape([1,-1,1,1]) + beta.reshape([1,-1,1,1])

    return output

从代码中可以看出,实际上BatchNorm就是对整个Batch的所有数值进行标准化,同时维护两组参数alpha和beta。

LayerNorm

LayerNorm是NLP领域常用的数值标准化方法,其代码实现也很简单,如下:

def LayerNorm(X, eps = 1e-6):
    # X shape (batch_size, seq_len, hidden_dim)
    alpha = nn.Parameter(torch.ones(X.shape[-1])) # (hidden_dim,)
    beta = nn.Parameter(torch.zeros(X.shape[-1])) # (hidden_dim,)

    mean = torch.mean(X, dim = -1, keepdim = True) # (batch_size, seq_len, 1)
    var = torch.var(X, dim = -1, keepdim = True) # (batch_size, seq_len, 1)

    output = (X - mean)/torch.sqrt(var + eps) * alpha + beta

    return output

从代码中可以看出,其实LayerNorm和BatchNorm的区别就在于对哪些维度求均值和方差,同时维护的alpha和beta的维度是不同的。

RMSNorm

实际上,从Llama到现在的大模型更多采用的是RMSNorm,它的基本改进思想是移除LayerNorm中的平移量,因为发现平移量并不会使结果有更好的提升,其决定作用的是缩放量,代码实现如下:

def RMSNorm(X, eps = 1e-6):
    # X shape (batch_size, seq_len, hidden_dim)
    alpha = nn.Parameter(torch.ones(X.shape[-1])) # (hidden_dim,)

    mean = torch.mean(X.pow(2), dim = -1, keepdim = True) # (batch_size, seq_len, 1)

    output = X/torch.sqrt(mean + eps) * alpha

    return output

实际上就是LayerNorm去掉了beta,并且变成了X除以均方根,并乘上可学习缩放参数alpha (原始论文使用gamma命名)。

测试代码

随机初始化一个向量跑一下,看看输出维度是否一致。

X = torch.rand([2,3,4])
output = LayerNorm(X)
print(output.shape)

X = torch.rand([2,3,4,4])
output = BatchNorm2d(X)
print(output.shape)

X = torch.rand([2,3,4])
output = RMSNorm(X)
print(output.shape)

# 输出如下:
# torch.Size([2, 3, 4]) 
# torch.Size([2, 3, 4, 4])
# torch.Size([2, 3, 4])

BatchNorm和LayerNorm都不改变输入维度,符合要求。

理论思考

问:

  1. 为什么NLP领域常用LayerNorm而CV领域常用BatchNorm?
  2. 输入维度为(2,3,4)的数据到LayerNorm中,共需要计算几次均值/方差?此时LayerNorm对应存储的alpha和beta维度是怎么样的?
  3. 输入维度为(2,3,5,5)的数据到BatchNorm中,共需要计算几次均值/方差?此时BatchNorm对应存储的alpha和beta维度是怎么样的?
  4. 输入维度为(2,3,4)的数据到RMSNorm中,共需要计算几次均值?此时RMSNorm对应存储的alpha和beta维度是怎么样的?

答:

  1. 因为NLP的输入是(batch_size, seq_len, hidden_dim),不同句子的长度是不一致的,且不同句子之间不具备严格关联性,使用BatchNorm效果不好。CV常用BatchNorm是因为图像通道具有明确的物理意义(RGB通道)且同一通道的像素服从相似分布(例如天空区域的蓝色通道值集中),使用BatchNorm效果好。
  2. 计算次数:2*3=6次,shape为(4,)。
  3. 计算次数:3次,shape为(3,)。
  4. 计算次数:2*3=6次,shape为(4,),与LayerNorm一致,整体参数量减半。
文章作者: DB咕
本文链接:
版权声明: 本站所有文章除特别声明外,均采用 CC BY-NC-SA 4.0 许可协议。转载请注明来自 DB咕
Transformer & LLM LLM LayerNorm BatchNorm RMSNorm 手撕
喜欢就支持一下吧