理论部分:

实际上,Transformer中的位置编码大致可分为三类:

  1. 绝对位置编码
    绝对位置编码中,又分为可学习位置编码固定位置编码
  • 可学习位置编码:即简单使用一层nn.embedding加到原始embedding中,让模型自己去学习位置关系(Bert中使用)
  • 固定位置编码:最常用的即为正余弦位置编码,即原始Transformer所采用的,其公式表示为:emb_{pos,2i} = sin(\frac{pos}{base^{\frac{2i}{d\_model}}})emb_{pos,2i+1} = cos(\frac{pos}{base^{\frac{2i}{d\_model}}})
  1. 相对位置编码
    在T5中采用,本质上是通过计算位置偏移量动态生成注意力偏置(估计绝对位置编码中的后置偏移量)。
  2. 旋转位置编码(RoPE)
    将位置信息编码为旋转矩阵,本质上是通过绝对位置编码的方式实现了相对位置编码(讲正余弦位置编码通过相乘的方式进行结合)。
    计算时,大致思路是先构造freq,即f_i = \frac{1}{base^{(2i/d\_model)}},然后构造pos序列(即从0~seq_len的位置序列),构造sin和cos矩阵。然后对X分解为奇数位和偶数位,通过组合计算方式重新计算输出结果。

代码实现 (RoPE)

def get_sin_cos(seq_len, hidden_dim, base):
    freq = 1.0 / (base ** (torch.arange(0, hidden_dim, 2, dtype=torch.float) / hidden_dim))
    pos = torch.arange(0, seq_len, 1, dtype=torch.float)
    #freq: (hidden_dim//2,) pos: (seq_len,) -> (seq_len, hidden_dim//2)
    input = pos.unsqueeze(1) @ freq.unsqueeze(0)
    cos = torch.cos(input)
    sin = torch.sin(input)
    return cos, sin

def apply(x, base):
    batch_size, seq_len, hidden_dim = x.shape
    cos, sin = get_sin_cos(seq_len, hidden_dim, base) #(seq_len, hidden_dim//2)
    
    x1 = x[...,0::2] # (batch_size, seq_len, dim//2)
    x2 = x[...,1::2]
    cos = cos.unsqueeze(0)
    sin = sin.unsqueeze(0)

    output1 = x1 * cos - x2 * sin
    output2 = x1 * sin + x2 * cos

    output = torch.ones_like(x)
    output[...,0::2] = output1
    output[...,1::2] = output2
    return output

x = torch.ones([2,3,4])
res = apply(x,10000) 
res

输出如下:

tensor([[[ 1.0000,  1.0000,  1.0000,  1.0000],
          [-0.3012,  1.3818,  0.9900,  1.0099],
          [-1.3254,  0.4932,  0.9798,  1.0198]],
          [[ 1.0000,  1.0000,  1.0000,  1.0000],
          [-0.3012,  1.3818,  0.9900,  1.0099],
          [-1.3254,  0.4932,  0.9798,  1.0198]]])
文章作者: DB咕
本文链接:
版权声明: 本站所有文章除特别声明外,均采用 CC BY-NC-SA 4.0 许可协议。转载请注明来自 DB咕
Transformer & LLM NLP 位置编码 RoPE 手撕 LLM
喜欢就支持一下吧