【LLM手撕系列】(一)BatchNorm/LayerNorm/RMSNorm的代码实现
BatchNorm
BatchNorm一般是指BatchNorm2d,因为在Pytorch中另外还有BatchNorm3d,是针对视频或者3D扫描图像等,也就是多了一个维度,这里只考虑2d的情况,即输入的X的shape为(batch_size, channel, height, width)。
def BatchNorm2d(X, eps = 1e-6):
# X shape (batch_size, channel, height, width)
alpha = nn.Parameter(torch.ones(X.shape[1])) # (channel,)
beta = nn.Parameter(torch.zeros(X.shape[1])) # (channel,)
mean = torch.mean(X, dim=[0,2,3], keepdim = True) # (batch_size, 1, height, width)
var = torch.var(X, dim=[0,2,3], keepdim = True) # (batch_size, 1, height, width)
output = (X - mean)/torch.sqrt(var + eps) * alpha.reshape([1,-1,1,1]) + beta.reshape([1,-1,1,1])
return output
从代码中可以看出,实际上BatchNorm就是对整个Batch的所有数值进行标准化,同时维护两组参数alpha和beta。
LayerNorm
LayerNorm是NLP领域常用的数值标准化方法,其代码实现也很简单,如下:
def LayerNorm(X, eps = 1e-6):
# X shape (batch_size, seq_len, hidden_dim)
alpha = nn.Parameter(torch.ones(X.shape[-1])) # (hidden_dim,)
beta = nn.Parameter(torch.zeros(X.shape[-1])) # (hidden_dim,)
mean = torch.mean(X, dim = -1, keepdim = True) # (batch_size, seq_len, 1)
var = torch.var(X, dim = -1, keepdim = True) # (batch_size, seq_len, 1)
output = (X - mean)/torch.sqrt(var + eps) * alpha + beta
return output
从代码中可以看出,其实LayerNorm和BatchNorm的区别就在于对哪些维度求均值和方差,同时维护的alpha和beta的维度是不同的。
RMSNorm
实际上,从Llama到现在的大模型更多采用的是RMSNorm,它的基本改进思想是移除LayerNorm中的平移量,因为发现平移量并不会使结果有更好的提升,其决定作用的是缩放量,代码实现如下:
def RMSNorm(X, eps = 1e-6):
# X shape (batch_size, seq_len, hidden_dim)
alpha = nn.Parameter(torch.ones(X.shape[-1])) # (hidden_dim,)
mean = torch.mean(X.pow(2), dim = -1, keepdim = True) # (batch_size, seq_len, 1)
output = X/torch.sqrt(mean + eps) * alpha
return output
实际上就是LayerNorm去掉了beta,并且变成了X除以均方根,并乘上可学习缩放参数alpha (原始论文使用gamma命名)。
测试代码
随机初始化一个向量跑一下,看看输出维度是否一致。
X = torch.rand([2,3,4])
output = LayerNorm(X)
print(output.shape)
X = torch.rand([2,3,4,4])
output = BatchNorm2d(X)
print(output.shape)
X = torch.rand([2,3,4])
output = RMSNorm(X)
print(output.shape)
# 输出如下:
# torch.Size([2, 3, 4])
# torch.Size([2, 3, 4, 4])
# torch.Size([2, 3, 4])
BatchNorm和LayerNorm都不改变输入维度,符合要求。
理论思考
问:
- 为什么NLP领域常用LayerNorm而CV领域常用BatchNorm?
- 输入维度为(2,3,4)的数据到LayerNorm中,共需要计算几次均值/方差?此时LayerNorm对应存储的alpha和beta维度是怎么样的?
- 输入维度为(2,3,5,5)的数据到BatchNorm中,共需要计算几次均值/方差?此时BatchNorm对应存储的alpha和beta维度是怎么样的?
- 输入维度为(2,3,4)的数据到RMSNorm中,共需要计算几次均值?此时RMSNorm对应存储的alpha和beta维度是怎么样的?
答:
- 因为NLP的输入是(batch_size, seq_len, hidden_dim),不同句子的长度是不一致的,且不同句子之间不具备严格关联性,使用BatchNorm效果不好。CV常用BatchNorm是因为图像通道具有明确的物理意义(RGB通道)且同一通道的像素服从相似分布(例如天空区域的蓝色通道值集中),使用BatchNorm效果好。
- 计算次数:2*3=6次,shape为(4,)。
- 计算次数:3次,shape为(3,)。
- 计算次数:2*3=6次,shape为(4,),与LayerNorm一致,整体参数量减半。
本文链接:
/archives/NorminLLM
版权声明:
本站所有文章除特别声明外,均采用 CC BY-NC-SA 4.0 许可协议。转载请注明来自
DB咕!
喜欢就支持一下吧