The LLaMA2 Model

  • Following the "small model + big data" design philosophy, LLaMA2 builds on LLaMA1 by further refining and expanding the training data, growing the corpus to roughly 7TB and covering a richer range of languages and domains

  • After pre-training, LLaMA2 applies reinforcement learning from human feedback (RLHF) to further improve the model:

    1. The model is first fine-tuned in a supervised fashion on large, publicly available instruction-tuning datasets
    2. LLaMA2 also trains an RLHF reward model and updates the policy with Proximal Policy Optimization (PPO) and rejection sampling
  • In terms of model architecture, LLaMA2 inherits the LLaMA1 architecture

[figure]

  • LLaMA2-34B and LLaMA2-70B additionally change how the KV cache is handled in the decode phase by adopting grouped-query attention (Grouped-Query Attention, GQA) to improve compute efficiency

[figure]

  • Under grouped-query attention, keys and values are no longer paired one-to-one with queries; instead, a group of queries shares the same key and value, which effectively reduces memory usage and the total number of model parameters

Grouped-Query Attention

  • MQA (Multi-Query Attention) and GQA (Grouped-Query Attention, a variant of MQA proposed by Google) are both variants of multi-head attention (MHA); in essence they are optimizations that share the KV cache across query heads

  • The principles of and differences between the three KV-cache schemes, MHA, MQA, and GQA, are as follows (a small sketch of the resulting cache sizes follows this list):

    1. MHA (Multi-Head Attention): Q, K, and V have the same number of heads, paired one-to-one. During attention, each head's Q/K/V interact only with each other, and the self-attention outputs of all heads are concatenated at the end
    2. MQA keeps the original number of Q heads but uses a single K head and a single V head; all Q heads share that one K/V pair, hence "Multi-Query". This shrinks the KV cache by a factor of head_num
    3. GQA is a compromise between MHA and MQA: the Q heads are split into groups, and each group shares one KV head. With 64 Q heads split into 8 groups of 8, there are 8 KV heads, so the KV cache shrinks by a factor of 8
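  • To make the memory savings concrete, here is a minimal sketch (hypothetical sizes: 64 query heads, head_dim 128, batch 1, 4096 cached positions) comparing the per-layer KV-cache size under MHA, MQA, and GQA:

n_heads, head_dim, batch, seq = 64, 128, 1, 4096

def kv_cache_elems(n_kv_heads: int) -> int:
    # One K cache plus one V cache, each of shape (batch, seq, n_kv_heads, head_dim).
    return 2 * batch * seq * n_kv_heads * head_dim

mha = kv_cache_elems(n_heads)       # MHA: one KV head per query head
gqa = kv_cache_elems(n_heads // 8)  # GQA: 8 groups -> 8 KV heads
mqa = kv_cache_elems(1)             # MQA: a single shared KV head
print(mha // gqa, mha // mqa)       # 8, 64: GQA is 8x smaller, MQA is 64x smaller
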
  • The three schemes are compared visually in the figure below:

[Figure: visual comparison of MHA, MQA, and GQA]

  • The (simplified) official LLaMA2 implementation of GQA, including the KV cache, is shown below:
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
    bs, slen, n_kv_heads, head_dim = x.shape
    # expand the KV heads n_rep times
    if n_rep == 1:
        return x
    return (
        x[:, :, :, None, :]
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
    )


class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        ...
        self.n_local_heads = args.n_heads // model_parallel_size        # number of Q heads
        self.n_local_kv_heads = self.n_kv_heads // model_parallel_size  # number of KV heads
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        ...
        self.wq = ColumnParallelLinear(args.dim, args.n_heads * self.head_dim,    # Q heads * head_dim
                                       ...)
        self.wk = ColumnParallelLinear(args.dim, self.n_kv_heads * self.head_dim,  # K heads * head_dim
                                       ...)
        self.wv = ColumnParallelLinear(args.dim, self.n_kv_heads * self.head_dim,  # V heads * head_dim
                                       ...)
        self.wo = RowParallelLinear(args.n_heads * self.head_dim, args.dim, ...)

        self.cache_k = torch.zeros((args.max_batch_size, args.max_seq_len, self.n_local_kv_heads,  # KV heads
                                    self.head_dim,)).cuda()
        self.cache_v = torch.zeros((args.max_batch_size, args.max_seq_len, self.n_local_kv_heads,  # KV heads
                                    self.head_dim,)).cuda()

    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor],
    ):
        bsz, seqlen, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)  # apply RoPE position encoding
        ...
        # write this step's k/v into the cache at the current positions;
        # in the prompt (prefill) phase seqlen >= 1, during generation seqlen == 1
        self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
        self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv
        # read back all cached k/v up to the current position
        keys = self.cache_k[:bsz, : start_pos + seqlen]
        values = self.cache_v[:bsz, : start_pos + seqlen]

        # repeat k/v heads if n_kv_heads < n_heads
        keys = repeat_kv(keys, self.n_rep)      # (bs, cache_len + seqlen, n_local_heads, head_dim)
        values = repeat_kv(values, self.n_rep)  # (bs, cache_len + seqlen, n_local_heads, head_dim)

        xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)
        # compute q @ k^T
        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
        if mask is not None:
            # causal mask: positions a token may not attend to receive -inf,
            # so their attention weight becomes 0 after softmax
            scores = scores + mask  # (bs, n_local_heads, seqlen, cache_len + seqlen)
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        output = torch.matmul(scores, values)  # (bs, n_local_heads, seqlen, head_dim)
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        return self.wo(output)

Proximal Policy Optimization (PPO)

  • PPO (Proximal Policy Optimization) is the third algorithm presented in OpenAI Spinning Up

  • PPO was created to address the same problem as TRPO: in policy-gradient methods, how do we choose the step size α so that each update moves toward the best possible policy, without stepping so far that performance collapses?

  • TRPO optimizes a second-order expansion of the objective, whereas PPO uses only a first-order (Taylor) expansion plus a few tricks that keep the new policy from drifting too far from the old one

  • PPO matches TRPO's results while being simpler and faster

  • PPO has two variants:

    1. Penalty-based PPO (PPO-Penalty): instead of treating the KL divergence as a hard constraint, it is added to the objective as a penalty term
    2. Clip-based PPO (PPO-Clip): the objective is rewritten around a clip function; in practice this works better than the penalty variant
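
  • For reference, the penalty variant's surrogate objective is usually written as follows (this is the standard PPO-Penalty objective, quoted for completeness rather than taken from the figures in the original notes):

$$ L^{\mathrm{KLPEN}}(\theta) = \mathbb{E}_t\!\left[ \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{\mathrm{old}}}(a_t \mid s_t)}\,\hat{A}_t \;-\; \beta\,\mathrm{KL}\!\big[\pi_{\theta_{\mathrm{old}}}(\cdot \mid s_t),\ \pi_\theta(\cdot \mid s_t)\big] \right] $$

where β is the penalty coefficient, typically adapted during training.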

Clip-based PPO (PPO-Clip)

  • The central question for policy-gradient methods is how to turn "maximize the policy's return" into an objective over the policy parameters, so we start with PPO's policy optimization objective, which can be written as:

$$ L^{\mathrm{CLIP}}(\theta) = \mathbb{E}_t\!\left[ \min\!\Big( r_t(\theta)\,\hat{A}_t,\ \mathrm{clip}\big(r_t(\theta),\,1-\epsilon,\,1+\epsilon\big)\,\hat{A}_t \Big) \right], \qquad r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{\mathrm{old}}}(a_t \mid s_t)} $$

where $r_t(\theta)$ is the probability ratio between the new and old policies and $\hat{A}_t$ is the advantage estimate.

  • Here ε is a hyperparameter controlling how far the new policy is allowed to move from the old one. The effect of the clip term is easiest to see in the figure below:

[Figures: the clipped objective as a function of the probability ratio, for positive and negative advantages]

  • With the explicit KL constraint gone, we can update θ directly by optimizing this objective

  • At this point, taking a first-order expansion simply means updating with gradients, so the policy can be updated through an ordinary loss function, which brings us back to the basic policy-gradient setup

  • The PPO procedure can be described as follows:

    1. Initialize the environment and the policy parameters
    2. Sample trajectories (states, actions, rewards, etc.) by running the policy in the environment
    3. Compute the discounted cumulative returns
    4. Use the critic model to compute state values, and use those values to compute the advantage function (this reduces the variance of the estimate)
    5. Compute the loss from the advantages and the policy probabilities (very similar to vanilla policy gradient, except for the clip term)
    6. Update the policy with this loss
    7. Update the critic using the discounted returns and the critic's state-value estimates (then go back to step 2)
  • The pseudocode is as follows:

[Figure: PPO-Clip pseudocode]
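
  • As a complement to the pseudocode, here is a minimal PyTorch sketch of the clipped surrogate loss from step 5. The tensor names (log_probs, old_log_probs, advantages) and the toy data are illustrative assumptions, not part of the LLaMA2 code:

import torch

def ppo_clip_loss(log_probs, old_log_probs, advantages, clip_eps: float = 0.2):
    """Clipped surrogate objective: -E[min(r * A, clip(r, 1 - eps, 1 + eps) * A)]."""
    ratio = torch.exp(log_probs - old_log_probs)  # r_t(theta) = pi_theta / pi_old
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
    return -torch.min(unclipped, clipped).mean()  # negated because optimizers minimize

# Toy usage with random tensors standing in for one rollout batch.
log_probs = torch.randn(256, requires_grad=True)
old_log_probs = log_probs.detach() + 0.1 * torch.randn(256)
advantages = torch.randn(256)
loss = ppo_clip_loss(log_probs, old_log_probs, advantages)
loss.backward()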

Rejection Sampling

What is rejection sampling?

  • AI-generated answers are not always correct; the model sometimes produces nonsense, logical errors, or meaningless reasoning chains

  • Without filtering, these bad answers can contaminate the learning process and even teach the model incorrect reasoning patterns

  • To address this, rejection sampling (RS) lets the model pick the best of its own outputs during training, keeping only the highest-quality answers and thereby improving overall reasoning ability

The core idea of rejection sampling

  • The steps of rejection sampling are as follows:

[Figure: rejection sampling workflow]
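
  • Since the step-by-step figure is not reproduced here, the following best-of-N sketch only illustrates the general idea; the generate and reward_model callables are hypothetical placeholders, not LLaMA2 APIs:

from typing import Callable, List

def rejection_sample(
    prompt: str,
    generate: Callable[[str], str],             # samples one candidate answer from the current policy
    reward_model: Callable[[str, str], float],  # scores a (prompt, answer) pair
    n_samples: int = 8,
    threshold: float = 0.0,
) -> List[str]:
    """Sample n candidates, drop those below the reward threshold, return the rest best-first."""
    candidates = [generate(prompt) for _ in range(n_samples)]
    scored = sorted(((reward_model(prompt, c), c) for c in candidates), reverse=True)
    return [c for score, c in scored if score >= threshold]  # kept answers can be reused for fine-tuning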


LLaMA2 model.py

Introduction of Grouped-Query Attention (GQA)

  • Code location: the ModelArgs class, the Attention class, and the repeat_kv function

  • Changes:

    1. LLaMA2 adds an n_kv_heads parameter, allowing the number of key/value heads (KV heads) to be smaller than the number of query heads (Q heads)
    2. When n_kv_heads < n_heads, the repeat_kv function repeats the KV heads to match the number of Q heads (i.e. one KV head serves several Q heads), reducing GPU memory usage and improving compute efficiency
# ModelArgs adds an n_kv_heads parameter
@dataclass
class ModelArgs:
    n_kv_heads: Optional[int] = None  # number of KV heads can be set independently

# GQA bookkeeping in Attention.__init__
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
self.n_rep = self.n_local_heads // self.n_local_kv_heads  # how many times each KV head is repeated

# repeat_kv copies each KV head n_rep times
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    return x[:, :, :, None, :].expand(...).reshape(...)  # expand + reshape to duplicate the KV heads
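
  • A quick shape check of repeat_kv with small, hypothetical sizes, showing 8 KV heads being expanded to match 32 query heads:

import torch

def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (x[:, :, :, None, :]
            .expand(bs, slen, n_kv_heads, n_rep, head_dim)
            .reshape(bs, slen, n_kv_heads * n_rep, head_dim))

n_heads, n_kv_heads = 32, 8
n_rep = n_heads // n_kv_heads             # 4 query heads share each KV head
k = torch.randn(2, 16, n_kv_heads, 128)   # (bs, seqlen, n_kv_heads, head_dim)
print(repeat_kv(k, n_rep).shape)          # torch.Size([2, 16, 32, 128])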

Dynamic FFN dimension adjustment

  • Code location: the FeedForward class

  • Changes:

    1. A new ffn_dim_multiplier parameter allows customizing the scaling of the FFN intermediate dimension (the base hidden dimension passed in is 4*dim; LLaMA1 had no such multiplier)
    2. ffn_dim_multiplier makes model capacity tunable: a larger multiplier increases capacity, a smaller one reduces compute
# FeedForward.__init__ computes the hidden dimension dynamically
if ffn_dim_multiplier is not None:
    hidden_dim = int(ffn_dim_multiplier * hidden_dim)  # custom scaling
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)  # round up to a multiple of multiple_of
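
  • As a concrete example, with the default ModelArgs values (dim = 4096, multiple_of = 256, no ffn_dim_multiplier), the rounding rule above yields an FFN hidden dimension of 11008:

dim, multiple_of, ffn_dim_multiplier = 4096, 256, None

hidden_dim = 4 * dim                  # 16384, the value passed in by TransformerBlock
hidden_dim = int(2 * hidden_dim / 3)  # 10922, the SwiGLU 2/3 reduction
if ffn_dim_multiplier is not None:
    hidden_dim = int(ffn_dim_multiplier * hidden_dim)
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
print(hidden_dim)                     # 11008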

Rotary Position Embedding (RoPE) optimizations

  • Code location: precompute_freqs_cis and apply_rotary_emb

  • Changes:

    1. Longer precomputed frequency table: precompute_freqs_cis computes frequencies for max_seq_len * 2 positions, leaving headroom for longer contexts or position-interpolation-style extensions
    2. Tensor reshaping: apply_rotary_emb merges the trailing dimensions with flatten(3), keeping the complex-multiplication implementation efficient
# LLaMA2 precomputes a longer frequency table
self.freqs_cis = precompute_freqs_cis(
    self.params.dim // self.params.n_heads,
    self.params.max_seq_len * 2,  # twice the maximum sequence length
)

# more efficient tensor manipulation
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)  # merge the last two dimensions
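
  • A small sketch of the shapes involved (hypothetical sizes), using the same complex-number trick as apply_rotary_emb:

import torch

head_dim, max_seq_len = 128, 2048

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: dim // 2].float() / dim))
    t = torch.arange(end).float()
    return torch.polar(torch.ones(end, dim // 2), torch.outer(t, freqs))  # complex64

freqs_cis = precompute_freqs_cis(head_dim, max_seq_len * 2)
print(freqs_cis.shape)  # torch.Size([4096, 64]): one complex phase per position and channel pair

# Rotating one head: pair up adjacent channels as complex numbers and multiply by the phases.
x = torch.randn(1, 16, 1, head_dim)                     # (bs, seqlen, n_heads, head_dim)
x_ = torch.view_as_complex(x.reshape(1, 16, 1, -1, 2))  # (bs, seqlen, n_heads, head_dim // 2)
x_rot = torch.view_as_real(x_ * freqs_cis[:16].view(1, 16, 1, -1)).flatten(3)
print(x_rot.shape)  # torch.Size([1, 16, 1, 128])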

Model parallelism and parameter initialization

  • Code location: ColumnParallelLinear and RowParallelLinear

  • Changes:

    1. Explicit initialization hook: the parallel linear layers pass init_method=lambda x: x, an identity hook that leaves the weights as created (in practice they are loaded from a checkpoint afterwards)
    2. Finer-grained parallel partitioning: the number of local heads (n_local_heads) is derived from model_parallel_size, improving multi-GPU efficiency
# parallel linear layers pass a lambda as init_method
self.wq = ColumnParallelLinear(
    args.dim, args.n_heads * self.head_dim,
    init_method=lambda x: x,  # identity hook: weights are left as created
)

# number of local heads per model-parallel rank
model_parallel_size = fs_init.get_model_parallel_world_size()
self.n_local_heads = args.n_heads // model_parallel_size

Inference optimizations

  • Code location: Transformer.forward

  • Change: the forward pass is decorated with @torch.inference_mode() instead of torch.no_grad(); inference_mode disables autograd bookkeeping more thoroughly, reducing memory use and speeding up inference

@torch.inference_mode()  # more thorough than torch.no_grad() for pure inference
def forward(self, tokens: torch.Tensor, start_pos: int):
    ...

Mask generation

  • Code location: the mask construction in Transformer.forward

  • Change: the causal mask (torch.triu) is built dynamically and only when seqlen > 1, which supports variable-length inputs and avoids redundant work during single-token decoding

mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device=tokens.device)
mask = torch.triu(mask, diagonal=start_pos + 1)  # dynamically built upper-triangular (causal) mask
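
  • For example, with seqlen = 3 and start_pos = 0, the resulting mask looks like this (a standalone sketch, not taken from the original file):

import torch

seqlen, start_pos = 3, 0
mask = torch.full((1, 1, seqlen, seqlen), float("-inf"))
mask = torch.triu(mask, diagonal=start_pos + 1)
print(mask[0, 0])
# tensor([[0., -inf, -inf],
#         [0., 0., -inf],
#         [0., 0., 0.]])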

Key module walkthrough

  • Model configuration (ModelArgs)
@dataclass
class ModelArgs:
    dim: int = 4096                             # model dimension (per-token representation size)
    n_layers: int = 32                          # number of Transformer layers
    n_heads: int = 32                           # total number of attention heads
    n_kv_heads: Optional[int] = None            # number of key/value heads (GQA feature)
    vocab_size: int = -1                        # vocabulary size (set by the tokenizer)
    multiple_of: int = 256                      # FFN hidden dim is rounded up to a multiple of this
    ffn_dim_multiplier: Optional[float] = None  # scaling factor for the FFN hidden dim
    norm_eps: float = 1e-5                      # epsilon for RMSNorm

    max_batch_size: int = 32                    # maximum batch size
    max_seq_len: int = 2048                     # maximum sequence length
  • Normalization layer (RMSNorm)
class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))  # learnable scaling parameter

    def _norm(self, x):
        # core computation: x / sqrt(mean(x^2) + eps)
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)  # normalize
        return output * self.weight                # apply the learnable scale
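
  • A quick numerical check of the normalization (standalone; same formula as _norm above):

import torch

x = torch.randn(2, 8, 4096)
eps = 1e-6
normed = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
print(normed.pow(2).mean(-1)[0, 0])  # ~1.0: each token vector now has unit RMS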
  • Rotary positional embedding (RoPE)
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    # precompute the rotation angles (in complex form)
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)
    freqs = torch.outer(t, freqs)  # outer product -> frequency matrix
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # convert to complex form
    return freqs_cis  # (seq_len, dim // 2)

def apply_rotary_emb(xq, xk, freqs_cis):
    # view the inputs as complex numbers
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))

    # reshape the frequency tensor for broadcasting against the inputs
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)

    # apply the rotation (complex multiplication)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)
  • Attention mechanism (Grouped-Query Attention)
class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        # key derived parameters
        self.n_kv_heads = args.n_kv_heads or args.n_heads         # defaults to n_heads if unset
        self.n_rep = self.n_local_heads // self.n_local_kv_heads  # repetitions per KV head

        # linear projections (model parallel)
        self.wq = ColumnParallelLinear(...)  # query projection
        self.wk = ColumnParallelLinear(...)  # key projection
        self.wv = ColumnParallelLinear(...)  # value projection
        self.wo = RowParallelLinear(...)     # output projection

        # KV cache (for autoregressive inference)
        self.cache_k = torch.zeros(...).cuda()  # (batch, seq, n_kv_heads, head_dim)
        self.cache_v = torch.zeros(...).cuda()

    def forward(self, x, start_pos, freqs_cis, mask):
        # step 1: linear projections to Q/K/V
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

        # step 2: reshape to (bs, seq_len, n_heads, head_dim)
        xq = xq.view(...)
        xk = xk.view(...)
        xv = xv.view(...)

        # step 3: apply rotary position embedding
        xq, xk = apply_rotary_emb(xq, xk, freqs_cis)

        # step 4: update the KV cache (for incremental generation)
        self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
        self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv

        # step 5: repeat KV heads to match the number of Q heads (the core of GQA)
        keys = repeat_kv(keys, self.n_rep)
        values = repeat_kv(values, self.n_rep)

        # step 6: attention scores
        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores += mask  # apply the causal mask
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)

        # step 7: weighted sum over values, then output projection
        output = torch.matmul(scores, values)
        return self.wo(output)
  • Feed-forward network (FeedForward)
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, multiple_of, ffn_dim_multiplier):
        # dynamically compute the intermediate dimension
        hidden_dim = int(2 * hidden_dim / 3)
        if ffn_dim_multiplier:  # optional custom scaling
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        # three linear layers (SwiGLU structure)
        self.w1 = ColumnParallelLinear(...)  # gate branch
        self.w2 = RowParallelLinear(...)     # output layer
        self.w3 = ColumnParallelLinear(...)  # main branch

    def forward(self, x):
        # SwiGLU activation: silu(w1(x)) * w3(x)
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
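
  • For readers without fairscale installed, the same SwiGLU block can be sketched with plain nn.Linear layers; this is a simplification, not the official implementation (which uses the parallel layers above):

import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLUFeedForward(nn.Module):
    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)  # gate branch
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)  # main branch
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)  # output projection

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

ffn = SwiGLUFeedForward(dim=4096, hidden_dim=11008)
print(ffn(torch.randn(1, 16, 4096)).shape)  # torch.Size([1, 16, 4096])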
  • Transformer block (TransformerBlock)
class TransformerBlock(nn.Module):
    def __init__(self, layer_id: int, args: ModelArgs):
        # main components
        self.attention = Attention(args)      # attention layer
        self.feed_forward = FeedForward(...)  # FFN layer
        self.attention_norm = RMSNorm(...)    # pre-attention normalization
        self.ffn_norm = RMSNorm(...)          # pre-FFN normalization

    def forward(self, x, start_pos, freqs_cis, mask):
        # residual connections around attention and FFN
        h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
        out = h + self.feed_forward(self.ffn_norm(h))
        return out
  • Overall model architecture (Transformer)
class Transformer(nn.Module):
    def __init__(self, params: ModelArgs):
        # core components
        self.tok_embeddings = ParallelEmbedding(...)          # token embedding layer
        self.layers = nn.ModuleList([...])                    # stacked Transformer blocks
        self.norm = RMSNorm(params.dim, eps=params.norm_eps)  # final normalization
        self.output = ColumnParallelLinear(...)               # output projection

        # precompute rotary frequencies (for 2 * max_seq_len positions)
        self.freqs_cis = precompute_freqs_cis(..., self.params.max_seq_len * 2)

    @torch.inference_mode()
    def forward(self, tokens: torch.Tensor, start_pos: int):
        # step 1: token embedding
        h = self.tok_embeddings(tokens)

        # step 2: slice the position encodings for the current tokens
        freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]

        # step 3: build the causal mask (only when processing more than one new token)
        if seqlen > 1:
            mask = torch.triu(torch.full(...), diagonal=start_pos + 1)

        # step 4: run every layer
        for layer in self.layers:
            h = layer(h, start_pos, freqs_cis, mask)

        # step 5: final normalization + output projection
        return self.output(self.norm(h))

Complete code

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import math
from dataclasses import dataclass
from typing import Optional, Tuple

import fairscale.nn.model_parallel.initialize as fs_init
import torch
import torch.nn.functional as F
from fairscale.nn.model_parallel.layers import (
    ColumnParallelLinear,
    ParallelEmbedding,
    RowParallelLinear,
)
from torch import nn


@dataclass
class ModelArgs:
    dim: int = 4096
    n_layers: int = 32
    n_heads: int = 32
    n_kv_heads: Optional[int] = None
    vocab_size: int = -1  # defined later by tokenizer
    multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
    ffn_dim_multiplier: Optional[float] = None
    norm_eps: float = 1e-5

    max_batch_size: int = 32
    max_seq_len: int = 2048


class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        """
        Initialize the RMSNorm normalization layer.

        Args:
            dim (int): The dimension of the input tensor.
            eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.

        Attributes:
            eps (float): A small value added to the denominator for numerical stability.
            weight (nn.Parameter): Learnable scaling parameter.
        """
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        """
        Apply the RMSNorm normalization to the input tensor.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The normalized tensor.
        """
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        """
        Forward pass through the RMSNorm layer.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The output tensor after applying RMSNorm.
        """
        output = self._norm(x.float()).type_as(x)
        return output * self.weight


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """
    Precompute the frequency tensor for complex exponentials (cis) with given dimensions.

    This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
    and the end index 'end'. The 'theta' parameter scales the frequencies.
    The returned tensor contains complex values in complex64 data type.

    Args:
        dim (int): Dimension of the frequency tensor.
        end (int): End index for precomputing frequencies.
        theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.

    Returns:
        torch.Tensor: Precomputed frequency tensor with complex exponentials.
    """
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # type: ignore
    freqs = torch.outer(t, freqs).float()  # type: ignore
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return freqs_cis


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """
    Reshape frequency tensor for broadcasting it with another tensor.

    This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
    for the purpose of broadcasting the frequency tensor during element-wise operations.

    Args:
        freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
        x (torch.Tensor): Target tensor for broadcasting compatibility.

    Returns:
        torch.Tensor: Reshaped frequency tensor.

    Raises:
        AssertionError: If the frequency tensor doesn't match the expected shape.
        AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
    """
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary embeddings to input tensors using the given frequency tensor.

    This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
    frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
    is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
    returned as real tensors.

    Args:
        xq (torch.Tensor): Query tensor to apply rotary embeddings.
        xk (torch.Tensor): Key tensor to apply rotary embeddings.
        freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
    """
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)


def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, :, None, :]
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
    )


class Attention(nn.Module):
    """Multi-head attention module."""

    def __init__(self, args: ModelArgs):
        """
        Initialize the Attention module.

        Args:
            args (ModelArgs): Model configuration parameters.

        Attributes:
            n_kv_heads (int): Number of key and value heads.
            n_local_heads (int): Number of local query heads.
            n_local_kv_heads (int): Number of local key and value heads.
            n_rep (int): Number of repetitions for local heads.
            head_dim (int): Dimension size of each attention head.
            wq (ColumnParallelLinear): Linear transformation for queries.
            wk (ColumnParallelLinear): Linear transformation for keys.
            wv (ColumnParallelLinear): Linear transformation for values.
            wo (RowParallelLinear): Linear transformation for output.
            cache_k (torch.Tensor): Cached keys for attention.
            cache_v (torch.Tensor): Cached values for attention.
        """
        super().__init__()
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        model_parallel_size = fs_init.get_model_parallel_world_size()
        self.n_local_heads = args.n_heads // model_parallel_size
        self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.dim // args.n_heads

        self.wq = ColumnParallelLinear(
            args.dim,
            args.n_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wk = ColumnParallelLinear(
            args.dim,
            self.n_kv_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wv = ColumnParallelLinear(
            args.dim,
            self.n_kv_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wo = RowParallelLinear(
            args.n_heads * self.head_dim,
            args.dim,
            bias=False,
            input_is_parallel=True,
            init_method=lambda x: x,
        )

        self.cache_k = torch.zeros(
            (
                args.max_batch_size,
                args.max_seq_len,
                self.n_local_kv_heads,
                self.head_dim,
            )
        ).cuda()
        self.cache_v = torch.zeros(
            (
                args.max_batch_size,
                args.max_seq_len,
                self.n_local_kv_heads,
                self.head_dim,
            )
        ).cuda()

    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor],
    ):
        """
        Forward pass of the attention module.

        Args:
            x (torch.Tensor): Input tensor.
            start_pos (int): Starting position for caching.
            freqs_cis (torch.Tensor): Precomputed frequency tensor.
            mask (torch.Tensor, optional): Attention mask tensor.

        Returns:
            torch.Tensor: Output tensor after attention.
        """
        bsz, seqlen, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

        self.cache_k = self.cache_k.to(xq)
        self.cache_v = self.cache_v.to(xq)

        self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
        self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv

        keys = self.cache_k[:bsz, : start_pos + seqlen]
        values = self.cache_v[:bsz, : start_pos + seqlen]

        # repeat k/v heads if n_kv_heads < n_heads
        keys = repeat_kv(keys, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
        values = repeat_kv(values, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)

        xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)
        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores + mask  # (bs, n_local_heads, seqlen, cache_len + seqlen)
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        output = torch.matmul(scores, values)  # (bs, n_local_heads, seqlen, head_dim)
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        return self.wo(output)


class FeedForward(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
        ffn_dim_multiplier: Optional[float],
    ):
        """
        Initialize the FeedForward module.

        Args:
            dim (int): Input dimension.
            hidden_dim (int): Hidden dimension of the feedforward layer.
            multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
            ffn_dim_multiplier (float, optional): Custom multiplier for hidden dimension. Defaults to None.

        Attributes:
            w1 (ColumnParallelLinear): Linear transformation for the first layer.
            w2 (RowParallelLinear): Linear transformation for the second layer.
            w3 (ColumnParallelLinear): Linear transformation for the third layer.
        """
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3)
        # custom dim factor multiplier
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.w1 = ColumnParallelLinear(
            dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
        )
        self.w2 = RowParallelLinear(
            hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x
        )
        self.w3 = ColumnParallelLinear(
            dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
        )

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class TransformerBlock(nn.Module):
    def __init__(self, layer_id: int, args: ModelArgs):
        """
        Initialize a TransformerBlock.

        Args:
            layer_id (int): Identifier for the layer.
            args (ModelArgs): Model configuration parameters.

        Attributes:
            n_heads (int): Number of attention heads.
            dim (int): Dimension size of the model.
            head_dim (int): Dimension size of each attention head.
            attention (Attention): Attention module.
            feed_forward (FeedForward): FeedForward module.
            layer_id (int): Identifier for the layer.
            attention_norm (RMSNorm): Layer normalization for attention output.
            ffn_norm (RMSNorm): Layer normalization for feedforward output.
        """
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads
        self.attention = Attention(args)
        self.feed_forward = FeedForward(
            dim=args.dim,
            hidden_dim=4 * args.dim,
            multiple_of=args.multiple_of,
            ffn_dim_multiplier=args.ffn_dim_multiplier,
        )
        self.layer_id = layer_id
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor],
    ):
        """
        Perform a forward pass through the TransformerBlock.

        Args:
            x (torch.Tensor): Input tensor.
            start_pos (int): Starting position for attention caching.
            freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies.
            mask (torch.Tensor, optional): Masking tensor for attention. Defaults to None.

        Returns:
            torch.Tensor: Output tensor after applying attention and feedforward layers.
        """
        h = x + self.attention.forward(
            self.attention_norm(x), start_pos, freqs_cis, mask
        )
        out = h + self.feed_forward.forward(self.ffn_norm(h))
        return out


class Transformer(nn.Module):
    def __init__(self, params: ModelArgs):
        """
        Initialize a Transformer model.

        Args:
            params (ModelArgs): Model configuration parameters.

        Attributes:
            params (ModelArgs): Model configuration parameters.
            vocab_size (int): Vocabulary size.
            n_layers (int): Number of layers in the model.
            tok_embeddings (ParallelEmbedding): Token embeddings.
            layers (torch.nn.ModuleList): List of Transformer blocks.
            norm (RMSNorm): Layer normalization for the model output.
            output (ColumnParallelLinear): Linear layer for final output.
            freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies.
        """
        super().__init__()
        self.params = params
        self.vocab_size = params.vocab_size
        self.n_layers = params.n_layers

        self.tok_embeddings = ParallelEmbedding(
            params.vocab_size, params.dim, init_method=lambda x: x
        )

        self.layers = torch.nn.ModuleList()
        for layer_id in range(params.n_layers):
            self.layers.append(TransformerBlock(layer_id, params))

        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
        self.output = ColumnParallelLinear(
            params.dim, params.vocab_size, bias=False, init_method=lambda x: x
        )

        self.freqs_cis = precompute_freqs_cis(
            self.params.dim // self.params.n_heads, self.params.max_seq_len * 2
        )

    @torch.inference_mode()
    def forward(self, tokens: torch.Tensor, start_pos: int):
        """
        Perform a forward pass through the Transformer model.

        Args:
            tokens (torch.Tensor): Input token indices.
            start_pos (int): Starting position for attention caching.

        Returns:
            torch.Tensor: Output logits after applying the Transformer model.
        """
        _bsz, seqlen = tokens.shape
        h = self.tok_embeddings(tokens)
        self.freqs_cis = self.freqs_cis.to(h.device)
        freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]

        mask = None
        if seqlen > 1:
            mask = torch.full(
                (1, 1, seqlen, seqlen), float("-inf"), device=tokens.device
            )
            mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)

        for layer in self.layers:
            h = layer(h, start_pos, freqs_cis, mask)
        h = self.norm(h)
        output = self.output(h).float()
        return output