【LLM手撕系列】(二)RoPE旋转位置编码的代码实现
理论部分:
实际上,Transformer中的位置编码大致可分为三类:
- 绝对位置编码
绝对位置编码中,又分为可学习位置编码和固定位置编码
- 可学习位置编码:即简单使用一层nn.embedding加到原始embedding中,让模型自己去学习位置关系(Bert中使用)
- 固定位置编码:最常用的即为正余弦位置编码,即原始Transformer所采用的,其公式表示为:emb_{pos,2i} = sin(\frac{pos}{base^{\frac{2i}{d\_model}}}),emb_{pos,2i+1} = cos(\frac{pos}{base^{\frac{2i}{d\_model}}})
- 相对位置编码
在T5中采用,本质上是通过计算位置偏移量动态生成注意力偏置(估计绝对位置编码中的后置偏移量)。 - 旋转位置编码(RoPE)
将位置信息编码为旋转矩阵,本质上是通过绝对位置编码的方式实现了相对位置编码(讲正余弦位置编码通过相乘的方式进行结合)。
计算时,大致思路是先构造freq,即f_i = \frac{1}{base^{(2i/d\_model)}},然后构造pos序列(即从0~seq_len的位置序列),构造sin和cos矩阵。然后对X分解为奇数位和偶数位,通过组合计算方式重新计算输出结果。
代码实现 (RoPE)
def get_sin_cos(seq_len, hidden_dim, base):
freq = 1.0 / (base ** (torch.arange(0, hidden_dim, 2, dtype=torch.float) / hidden_dim))
pos = torch.arange(0, seq_len, 1, dtype=torch.float)
#freq: (hidden_dim//2,) pos: (seq_len,) -> (seq_len, hidden_dim//2)
input = pos.unsqueeze(1) @ freq.unsqueeze(0)
cos = torch.cos(input)
sin = torch.sin(input)
return cos, sin
def apply(x, base):
batch_size, seq_len, hidden_dim = x.shape
cos, sin = get_sin_cos(seq_len, hidden_dim, base) #(seq_len, hidden_dim//2)
x1 = x[...,0::2] # (batch_size, seq_len, dim//2)
x2 = x[...,1::2]
cos = cos.unsqueeze(0)
sin = sin.unsqueeze(0)
output1 = x1 * cos - x2 * sin
output2 = x1 * sin + x2 * cos
output = torch.ones_like(x)
output[...,0::2] = output1
output[...,1::2] = output2
return output
x = torch.ones([2,3,4])
res = apply(x,10000)
res
输出如下:
tensor([[[ 1.0000, 1.0000, 1.0000, 1.0000],
[-0.3012, 1.3818, 0.9900, 1.0099],
[-1.3254, 0.4932, 0.9798, 1.0198]],
[[ 1.0000, 1.0000, 1.0000, 1.0000],
[-0.3012, 1.3818, 0.9900, 1.0099],
[-1.3254, 0.4932, 0.9798, 1.0198]]])
本文链接:
/archives/rope-by-hands
版权声明:
本站所有文章除特别声明外,均采用 CC BY-NC-SA 4.0 许可协议。转载请注明来自
DB咕!
喜欢就支持一下吧