简打卡
51.76M · 2026-03-28
graph TB
subgraph 问题[三大瓶颈]
P1[" 推理速度慢<br/>自回归逐词生成<br/>大量重复计算"]
P2[" 显存占用高<br/>KV矩阵随序列长度增长<br/>多头存储冗余"]
P3[" 序列长度受限<br/>O(n²)复杂度<br/>长文本处理困难"]
end
subgraph 解决方案
S1[" KV Cache<br/>缓存已计算的KV"]
S2[" MQA/GQA<br/>共享KV降低显存"]
S3[" Sparse Attention<br/>稀疏注意力模式"]
end
P1 --> S1
P2 --> S2
P3 --> S3
style P1 fill:#ffcdd2
style P2 fill:#ffccbc
style P3 fill:#ffab91
style S1 fill:#a5d6a7
style S2 fill:#81c784
style S3 fill:#66bb6a
| 模型 | KV Cache | MQA/GQA | Sparse Attn | MoE | 上下文长度 |
|---|---|---|---|---|---|
| GPT-3 | ✓ | | | | 2K |
| LLaMA | ✓ | | | | 4K |
| LLaMA2 | ✓ | GQA | | | 4K |
| GPT-4 | ✓ | | 部分 | 推测 | 32K/128K |
| Mixtral 8x7B | ✓ | GQA | | ✓ | 32K |
| Claude 3 | ✓ | | ? | | 200K |
场景:GPT模型生成"我爱学习AI"
sequenceDiagram
participant Input
participant Model
participant Output
Note over Input,Output: Step 1: 生成"我"
Input->>Model: [START]
Model->>Output: "我"
Note over Input,Output: Step 2: 生成"爱"
Input->>Model: [START, 我]
Note right of Model: 重新计算"我"的KV
Model->>Output: "爱"
Note over Input,Output: Step 3: 生成"学习"
Input->>Model: [START, 我, 爱]
Note right of Model: 重新计算"我""爱"的KV
Model->>Output: "学习"
Note over Input,Output: Step 4: 生成"AI"
Input->>Model: [START, 我, 爱, 学习]
Note right of Model: 重新计算所有历史KV
Model->>Output: "AI"
问题分析:
总计算量:$\sum_{t=1}^{n} O(t) = O(n^2)$,重复计算随生成长度 $n$ 呈平方增长。
核心思想:缓存已经计算过的Key和Value矩阵,新token只需计算自己的KV。
graph TB
subgraph 无Cache[Without KV Cache]
S1["Step 1<br/>计算: [START]"]
S2["Step 2<br/>计算: [START, 我]<br/> 重复计算START"]
S3["Step 3<br/>计算: [START, 我, 爱]<br/> 重复计算START,我"]
end
subgraph 有Cache[With KV Cache]
C1["Step 1<br/>计算&缓存: [START]"]
C2["Step 2<br/> 读取: [START]<br/>计算&缓存: [我]"]
C3["Step 3<br/> 读取: [START, 我]<br/>计算&缓存: [爱]"]
end
S1 --> S2 --> S3
C1 --> C2 --> C3
style S2 fill:#ffcdd2
style S3 fill:#ffcdd2
style C2 fill:#a5d6a7
style C3 fill:#a5d6a7
加速效果:
标准Attention:
在第 $t$ 步:
缓存更新:
# Pseudo-code
cache_K = []  # initialise the key cache (one entry per generated token)
cache_V = []  # initialise the value cache
for t in range(max_len):
    # 1. Compute K/V for the current token ONLY — past tokens are cached.
    k_t = compute_key(x_t)
    v_t = compute_value(x_t)
    # 2. Append the new K/V to the cache.
    cache_K.append(k_t)
    cache_V.append(v_t)
    # 3. Attend over the full cache (past + current).
    # NOTE(review): illustrative only — x_t / compute_* are undefined here and
    # a Python list has no .T; real code stacks the cache into a tensor first.
    q_t = compute_query(x_t)
    attention = softmax(q_t @ cache_K.T) @ cache_V
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class MultiHeadAttentionWithCache(nn.Module):
    """Multi-head attention with optional KV caching.

    During autoregressive decoding the keys/values of already-generated
    tokens are kept in a cache, so each step only projects the new token
    and concatenates it onto the cached K/V.
    """

    def __init__(self, d_model, n_heads):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.W_Q = nn.Linear(d_model, d_model)
        self.W_K = nn.Linear(d_model, d_model)
        self.W_V = nn.Linear(d_model, d_model)
        self.W_O = nn.Linear(d_model, d_model)

    def forward(self, x, cache=None, use_cache=False):
        """Run attention over `x`, optionally reusing/returning a KV cache.

        Args:
            x: [batch_size, seq_len, d_model]
            cache: optional dict
                {'key':   [batch, n_heads, past_len, d_k],
                 'value': [batch, n_heads, past_len, d_k]}
            use_cache: when True, also return the updated cache.

        Returns:
            output [batch, seq_len, d_model]; plus the new cache dict when
            use_cache is True.
        """
        bsz, steps, _ = x.size()

        def split_heads(proj):
            # [batch, steps, d_model] -> [batch, n_heads, steps, d_k]
            return proj.view(bsz, steps, self.n_heads, self.d_k).transpose(1, 2)

        # Project only the current input; cached history is reused below.
        q = split_heads(self.W_Q(x))
        k = split_heads(self.W_K(x))
        v = split_heads(self.W_V(x))

        # Prepend cached K/V along the sequence dimension.
        if cache is not None:
            k = torch.cat([cache['key'], k], dim=2)
            v = torch.cat([cache['value'], v], dim=2)

        # Scaled dot-product attention.
        weights = F.softmax(q @ k.transpose(-2, -1) / math.sqrt(self.d_k), dim=-1)
        ctx = weights @ v

        # Merge heads back into d_model and apply the output projection.
        ctx = ctx.transpose(1, 2).contiguous().view(bsz, steps, self.d_model)
        out = self.W_O(ctx)

        if use_cache:
            return out, {'key': k, 'value': v}
        return out
# Usage example: simulate autoregressive generation.
d_model = 512
n_heads = 8
max_len = 10
mha = MultiHeadAttentionWithCache(d_model, n_heads)
# Initialise: the cache starts empty and grows by one time step per token.
cache = None
all_outputs = []
for t in range(max_len):
    # Current token (in real decoding this is the previous step's output).
    current_token = torch.randn(1, 1, d_model)  # [batch=1, seq_len=1, d_model]
    # Forward pass with cache.
    output, cache = mha(current_token, cache=cache, use_cache=True)
    all_outputs.append(output)
    print(f"Step {t+1}:")
    print(f" Cache K shape: {cache['key'].shape}")
    print(f" Cache V shape: {cache['value'].shape}")
# Example output:
# Step 1:
# Cache K shape: torch.Size([1, 8, 1, 64])
# Cache V shape: torch.Size([1, 8, 1, 64])
# Step 2:
# Cache K shape: torch.Size([1, 8, 2, 64])  <- length grows by one each step
# Cache V shape: torch.Size([1, 8, 2, 64])
# ...
分析:对于单个样本,KV Cache 显存 $= 2 \times n_{layers} \times n_{heads} \times seq\_len \times d_k \times$ 每元素字节数(K 与 V 各存一份)。
示例:LLaMA2-7B
单个序列就需要2GB显存! 这就是为什么需要MQA/GQA优化。
问题:在多头注意力中,每个头都有独立的KV矩阵,造成显存冗余。
graph TB
subgraph 标准MHA[Multi-Head Attention]
Q1["Q1"] --> H1["Head 1"]
K1["K1"] --> H1
V1["V1"] --> H1
Q2["Q2"] --> H2["Head 2"]
K2["K2"] --> H2
V2["V2"] --> H2
Qn["Qn"] --> Hn["Head n"]
Kn["Kn"] --> Hn
Vn["Vn"] --> Hn
end
subgraph MQA[Multi-Query Attention]
Q1m["Q1"] --> H1m["Head 1"]
SharedKV["共享 K, V"] --> H1m
SharedKV --> H2m["Head 2"]
SharedKV --> Hnm["Head n"]
Q2m["Q2"] --> H2m
Qnm["Qn"] --> Hnm
end
style SharedKV fill:#a5d6a7
style K1 fill:#ffcdd2
style K2 fill:#ffcdd2
style Kn fill:#ffcdd2
核心思想:所有注意力头共享同一组Key和Value,只有Query独立。
标准MHA:
MQA:
注意:$K$ 和 $V$ 在所有头之间共享。
参数量对比:
| 配置 | MHA | MQA | 节省 |
|---|---|---|---|
| Q权重 | $d \times d$ | $d \times d$ | 0 |
| K权重 | $d \times d$ | $d \times d_k$ | $\approx (h-1)/h$ |
| V权重 | $d \times d$ | $d \times d_k$ | $\approx (h-1)/h$ |
示例(h=32):
class MultiQueryAttention(nn.Module):
    """Multi-Query Attention (MQA): per-head queries, one shared K/V head.

    All query heads attend against a single shared key/value projection,
    which shrinks the KV projections from d_model x d_model down to
    d_model x d_k.
    """

    def __init__(self, d_model, n_heads):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        # Independent query projection per head.
        self.W_Q = nn.Linear(d_model, d_model)
        # Shared key and value projections — note the reduced output width.
        self.W_K = nn.Linear(d_model, self.d_k)
        self.W_V = nn.Linear(d_model, self.d_k)
        self.W_O = nn.Linear(d_model, d_model)

    def forward(self, x):
        bsz, n, _ = x.size()
        # Per-head queries: [batch, n_heads, n, d_k].
        q = self.W_Q(x).view(bsz, n, self.n_heads, self.d_k).transpose(1, 2)
        # Single shared K/V, broadcast across the head dimension.
        k = self.W_K(x).unsqueeze(1)  # [batch, 1, n, d_k]
        v = self.W_V(x).unsqueeze(1)
        # Scaled dot-product attention (K/V broadcast over heads).
        attn = F.softmax(q @ k.transpose(-2, -1) / math.sqrt(self.d_k), dim=-1)
        ctx = attn @ v  # [batch, n_heads, n, d_k]
        # Merge heads and project out.
        ctx = ctx.transpose(1, 2).contiguous().view(bsz, n, self.d_model)
        return self.W_O(ctx)
# Compare parameter counts.
# Fix: the original referenced `MultiHeadAttention`, which is not defined in
# this file (NameError). `MultiHeadAttentionWithCache` has the same four
# d_model x d_model projections (W_Q/W_K/W_V/W_O), so the count is identical.
d_model = 512
n_heads = 8
mha = MultiHeadAttentionWithCache(d_model, n_heads)
mqa = MultiQueryAttention(d_model, n_heads)
print(f"MHA 参数量: {sum(p.numel() for p in mha.parameters())}")
print(f"MQA 参数量: {sum(p.numel() for p in mqa.parameters())}")
# MHA parameters: 1,050,624  (4 x (512*512 + 512))
# MQA parameters: 590,976    (~44% savings; the original note "820,224 / 22%"
# does not match this MQA implementation, whose W_K/W_V map d_model -> d_k)
graph LR
Pro[" 优点"] --> P1["显存占用大幅降低"]
Pro --> P2["推理速度显著提升"]
Con[" 缺点"] --> C1["表达能力下降"]
Con --> C2["精度略有损失"]
Con --> C3["多头冗余度太低"]
style Pro fill:#a5d6a7
style Con fill:#ffcdd2
实验数据(PaLM论文):
核心思想:将多个Query头分组,每组共享一对KV。
graph TB
subgraph MHA[Multi-Head: h个独立KV]
MHA_Heads["Head1 Head2 ... Head-h<br/>K1,V1 K2,V2 ... Kh,Vh"]
end
subgraph GQA[Grouped-Query: g组共享KV]
GQA_Group1["组1: Head1,2,3,4<br/>共享 K1,V1"]
GQA_Group2["组2: Head5,6,7,8<br/>共享 K2,V2"]
end
subgraph MQA[Multi-Query: 1组共享KV]
MQA_All["所有Head<br/>共享 K,V"]
end
MHA -.折中方案.-> GQA
GQA -.极端情况.-> MQA
style MHA fill:#ffccbc
style GQA fill:#fff9c4
style MQA fill:#a5d6a7
数学关系:$1 \le g \le h$;$g = h$ 时退化为标准 MHA,$g = 1$ 时退化为 MQA。
常见配置:
| 模型 | Query头数 | KV组数 | 每组头数 | 显存节省 |
|---|---|---|---|---|
| LLaMA2-7B | 32 | 8 | 4 | 75% |
| LLaMA2-13B | 40 | 5 | 8 | 87.5% |
| LLaMA2-70B | 64 | 8 | 8 | 87.5% |
| Mixtral 8x7B | 32 | 8 | 4 | 75% |
graph TB
X["输入 X"] --> Linear["线性变换"]
Linear --> Q["Query<br/>[h个头]"]
Linear --> K["Key<br/>[g组]"]
Linear --> V["Value<br/>[g组]"]
subgraph 组1
Q1["Q头1-4"] --> Attn1["注意力计算"]
K1["K1"] --> Attn1
V1["V1"] --> Attn1
end
subgraph 组2
Q2["Q头5-8"] --> Attn2["注意力计算"]
K2["K2"] --> Attn2
V2["V2"] --> Attn2
end
Attn1 --> Concat["拼接"]
Attn2 --> Concat
Concat --> Output["输出"]
style Q fill:#fff9c4
style K fill:#a5d6a7
style V fill:#81c784
style Output fill:#c5e1a5
class GroupedQueryAttention(nn.Module):
    """Grouped-Query Attention (GQA).

    Query heads are partitioned into groups and each group shares one
    key/value head: n_kv_groups == n_heads reproduces standard MHA,
    n_kv_groups == 1 reproduces MQA.

    This version also supports an optional KV cache (backward compatible:
    `forward(x)` behaves exactly as before), which `OptimizedTransformerBlock`
    in this file already expects via `cache=`/`use_cache=` keyword arguments.

    Args:
        d_model: model width (e.g. 4096)
        n_heads: number of query heads (e.g. 32)
        n_kv_groups: number of shared KV groups (e.g. 8)
    """

    def __init__(self, d_model, n_heads, n_kv_groups):
        super().__init__()
        assert n_heads % n_kv_groups == 0, "n_heads必须能被n_kv_groups整除"
        self.d_model = d_model
        self.n_heads = n_heads
        self.n_kv_groups = n_kv_groups
        self.n_heads_per_group = n_heads // n_kv_groups
        self.d_k = d_model // n_heads
        # Query: one projection per head.
        self.W_Q = nn.Linear(d_model, d_model)
        # Key & Value: one projection per group (reduced width).
        self.W_K = nn.Linear(d_model, n_kv_groups * self.d_k)
        self.W_V = nn.Linear(d_model, n_kv_groups * self.d_k)
        self.W_O = nn.Linear(d_model, d_model)

    def forward(self, x, cache=None, use_cache=False):
        """Compute GQA over `x`, optionally reusing/returning a KV cache.

        Args:
            x: [batch, seq_len, d_model] — during cached decoding this is
               usually just the newest token.
            cache: optional {'key','value'} tensors of shape
               [batch, n_kv_groups, past_len, d_k] from previous steps.
            use_cache: when True, also return the updated cache.

        Returns:
            output [batch, seq_len, d_model]; plus the new cache dict when
            use_cache is True.
        """
        batch_size, seq_len, _ = x.size()
        # 1. Queries for all heads: [batch, n_heads, seq_len, d_k].
        Q = self.W_Q(x).view(batch_size, seq_len, self.n_heads, self.d_k)
        Q = Q.transpose(1, 2)
        # 2. K, V per group: [batch, n_kv_groups, seq_len, d_k].
        K = self.W_K(x).view(batch_size, seq_len, self.n_kv_groups, self.d_k)
        V = self.W_V(x).view(batch_size, seq_len, self.n_kv_groups, self.d_k)
        K = K.transpose(1, 2)
        V = V.transpose(1, 2)
        # Prepend cached per-group KV (stored un-expanded to keep it small —
        # this is exactly the memory saving GQA provides).
        if cache is not None:
            K = torch.cat([cache['key'], K], dim=2)
            V = torch.cat([cache['value'], V], dim=2)
        new_cache = {'key': K, 'value': V} if use_cache else None
        # 3. Broadcast each group's KV to every head inside the group.
        K = K.repeat_interleave(self.n_heads_per_group, dim=1)
        V = V.repeat_interleave(self.n_heads_per_group, dim=1)
        # 4. Scaled dot-product attention.
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        attn_weights = F.softmax(scores, dim=-1)
        attn_output = torch.matmul(attn_weights, V)
        # 5. Merge heads and project out.
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, seq_len, self.d_model)
        output = self.W_O(attn_output)
        if use_cache:
            return output, new_cache
        return output
# Usage example
d_model = 4096
n_heads = 32
n_kv_groups = 8  # LLaMA2-7B-style GQA-8 configuration (per this document)
gqa = GroupedQueryAttention(d_model, n_heads, n_kv_groups)
x = torch.randn(2, 10, d_model)
output = gqa(x)
print(f"输入: {x.shape}")  # torch.Size([2, 10, 4096])
print(f"输出: {output.shape}")  # torch.Size([2, 10, 4096])
graph TB
subgraph 性能对比
Quality["模型质量<br/>(困惑度 Perplexity)"]
Speed["推理速度<br/>(tokens/sec)"]
Memory["显存占用<br/>(GB)"]
end
subgraph MHA评分
Q_MHA["最好 ⭐⭐⭐⭐⭐"]
S_MHA["最慢 ⭐⭐"]
M_MHA["最高 ⭐"]
end
subgraph GQA评分
Q_GQA["接近MHA ⭐⭐⭐⭐"]
S_GQA["较快 ⭐⭐⭐⭐"]
M_GQA["适中 ⭐⭐⭐"]
end
subgraph MQA评分
Q_MQA["略低 ⭐⭐⭐"]
S_MQA["最快 ⭐⭐⭐⭐⭐"]
M_MQA["最低 ⭐⭐⭐⭐⭐"]
end
Quality --> Q_MHA
Quality --> Q_GQA
Quality --> Q_MQA
Speed --> S_MHA
Speed --> S_GQA
Speed --> S_MQA
Memory --> M_MHA
Memory --> M_GQA
Memory --> M_MQA
style Q_GQA fill:#fff59d
style S_GQA fill:#fff59d
style M_GQA fill:#fff59d
实验数据(LLaMA2论文):
标准Attention的瓶颈:计算复杂度为 $O(n^2 \cdot d)$,注意力矩阵显存为 $O(n^2)$。
其中 $n$ 是序列长度,$d$ 是维度。
graph LR
Seq["序列长度"] --> Comp["计算复杂度"]
L1["1K tokens"] --> C1["O(1M)"]
L2["10K tokens"] --> C2["O(100M)"]
L3["100K tokens"] --> C3["O(10B)"]
style L1 fill:#a5d6a7
style L2 fill:#fff9c4
style L3 fill:#ffcdd2
Claude 3处理200K上下文需要什么?
核心思想:不是所有token都需要关注所有其他token。
graph TB
subgraph Full[全注意力 O(n²)]
F["每个token<br/>关注所有token"]
end
subgraph Sparse[稀疏注意力]
S1["局部注意力<br/>Sliding Window"]
S2["全局注意力<br/>Global Tokens"]
S3["随机注意力<br/>Random Sampling"]
S4["分块注意力<br/>Blocked"]
end
Full -.优化.-> Sparse
style Full fill:#ffcdd2
style S1 fill:#a5d6a7
style S2 fill:#81c784
style S3 fill:#66bb6a
style S4 fill:#4caf50
思想:每个token只关注前后固定窗口内的token。
graph LR
subgraph 注意力矩阵
T1["Token 1"] -.-> W1["窗口1-3"]
T2["Token 2"] -.-> W2["窗口1-4"]
T3["Token 3"] -.-> W3["窗口1-5"]
T4["Token 4"] -.-> W4["窗口2-6"]
end
style T1 fill:#fff9c4
style W1 fill:#a5d6a7
复杂度:$O(n \times w)$,其中 $w$ 是窗口大小(如512)
实现:
def sliding_window_mask(seq_len, window_size):
    """Build a [seq_len, seq_len] 0/1 float mask for local attention.

    Entry (i, j) is 1.0 exactly when |i - j| <= window_size, i.e. each
    position may attend to the window of `window_size` neighbours on each
    side (clamped at the sequence boundaries).
    """
    positions = torch.arange(seq_len)
    # Pairwise offsets between every query position and every key position.
    offsets = positions.unsqueeze(1) - positions.unsqueeze(0)
    return (offsets.abs() <= window_size).float()
# Example
mask = sliding_window_mask(10, window_size=2)
print(mask)
# tensor([[1., 1., 1., 0., 0., ...],
#         [1., 1., 1., 1., 0., ...],
#         [1., 1., 1., 1., 1., ...],
#         ...])
思想:少数全局token关注所有,大部分token只做局部关注。
graph TB
subgraph 全局Token
G["CLS, SEP<br/>关注所有token"]
end
subgraph 局部Token
L["普通token<br/>只关注窗口内"]
end
G -.全注意力.-> All["全部序列"]
L -.局部.-> Window["小窗口"]
style G fill:#ffeb3b
style L fill:#90caf9
实现:
def longformer_mask(seq_len, window_size, global_indices):
    """Longformer-style attention mask: sliding-window locality plus a few
    global positions that attend to — and are attended by — every token.

    global_indices: positions of the global tokens (e.g. [0, 1]).
    """
    # Start from plain local (sliding-window) attention.
    mask = sliding_window_mask(seq_len, window_size)
    # Open the full row and full column of every global token at once.
    globals_ = list(global_indices)
    if globals_:
        mask[globals_, :] = 1
        mask[:, globals_] = 1
    return mask
思想:将序列分块,块内全注意力,块间稀疏连接。
graph TB
subgraph Block1[块1]
B1_T1["Token 1-8"]
end
subgraph Block2[块2]
B2_T1["Token 9-16"]
end
subgraph Block3[块3]
B3_T1["Token 17-24"]
end
Block1 <-.块内全连接.-> Block1
Block2 <-.块内全连接.-> Block2
Block3 <-.块内全连接.-> Block3
Block1 -.稀疏连接.-> Block2
Block2 -.稀疏连接.-> Block3
style Block1 fill:#e3f2fd
style Block2 fill:#fff9c4
style Block3 fill:#f3e5f5
特殊说明:FlashAttention不改变注意力模式,而是优化GPU内存访问。
graph LR
subgraph 标准Attention[标准实现]
Step1["1. 计算QK^T<br/>写入HBM"]
Step2["2. 读取,Softmax<br/>写回HBM"]
Step3["3. 读取,乘V<br/>写回HBM"]
end
subgraph FlashAttn[FlashAttention]
Fused["分块计算<br/>全程在SRAM<br/>减少HBM访问"]
end
Step1 --> Step2 --> Step3
style Step1 fill:#ffccbc
style Step2 fill:#ffccbc
style Step3 fill:#ffccbc
style Fused fill:#a5d6a7
加速效果:
问题:大模型参数多,但每次前向传播只需要激活部分参数。
graph TB
Input["输入Token"] --> Router["路由网络<br/>决策选择专家"]
Router -->|20%概率| E1["专家1<br/>数学推理"]
Router -->|5%概率| E2["专家2<br/>代码生成"]
Router -->|60%概率| E3["专家3<br/>通用知识"]
Router -->|10%概率| E4["专家4<br/>创意写作"]
Router -->|5%概率| En["专家N<br/>..."]
E1 --> Combine["加权组合"]
E2 --> Combine
E3 --> Combine
E4 --> Combine
En --> Combine
Combine --> Output["输出"]
style Router fill:#fff59d
style E3 fill:#a5d6a7
style Combine fill:#90caf9
关键特点:
graph TB
X["输入 X"] --> SelfAttn["自注意力"]
SelfAttn --> Norm1["LayerNorm"]
Norm1 --> Router["路由网络<br/>Gating"]
subgraph MoE层
Router -->|权重w1| Expert1["FFN 专家1"]
Router -->|权重w2| Expert2["FFN 专家2"]
Router -->|权重0| Expert3["FFN 专家3<br/>未激活"]
Router -->|权重0| ExpertN["FFN 专家N<br/>未激活"]
end
Expert1 --> Sum["加权求和<br/>w1·E1 + w2·E2"]
Expert2 --> Sum
Sum --> Norm2["LayerNorm"]
Norm2 --> Output["输出"]
style Router fill:#fff59d
style Expert1 fill:#a5d6a7
style Expert2 fill:#81c784
style Expert3 fill:#e0e0e0
style ExpertN fill:#e0e0e0
Softmax路由:$G(x) = \mathrm{softmax}(x W_g)$,为每个 token 给出各专家的权重。
Top-K选择:只保留权重最高的 $K$ 个专家参与计算,其余专家的权重视为 0。
PyTorch实现:
class MoELayer(nn.Module):
    """Sparsely-gated Mixture-of-Experts feed-forward layer.

    Each token is routed to its top_k experts (out of num_experts FFNs) and
    the expert outputs are combined with softmax-normalised gate weights.

    Args:
        d_model: input/output width.
        d_ff: hidden width of each expert FFN.
        num_experts: number of expert networks.
        top_k: number of experts activated per token.
    """

    def __init__(self, d_model, d_ff, num_experts, top_k=2):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        # Routing ("gating") network: scores every expert for every token.
        self.gate = nn.Linear(d_model, num_experts)
        # Expert networks (plain FFNs).
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(d_model, d_ff),
                nn.ReLU(),
                nn.Linear(d_ff, d_model)
            )
            for _ in range(num_experts)
        ])

    def forward(self, x):
        """
        x: [batch_size, seq_len, d_model]
        returns: [batch_size, seq_len, d_model]
        """
        # 1. Route: score every expert for every token.
        gate_logits = self.gate(x)  # [batch, seq_len, num_experts]
        # 2. Keep only the top-k experts per token.
        top_k_logits, top_k_indices = torch.topk(gate_logits, self.top_k, dim=-1)
        # 3. Softmax over the selected logits only.
        top_k_gates = F.softmax(top_k_logits, dim=-1)  # [batch, seq_len, top_k]
        # 4. Fold the k gate weights into one weight per expert, then run each
        #    selected expert at most ONCE over the batch. (Perf fix: the
        #    original ran every expert up to top_k times; the result is
        #    mathematically identical. Dense evaluation is kept for
        #    simplicity — production code gathers only the routed tokens.)
        combined = torch.zeros(*x.shape[:-1], self.num_experts,
                               device=x.device, dtype=x.dtype)
        combined.scatter_add_(-1, top_k_indices, top_k_gates)
        output = torch.zeros_like(x)
        for i in range(self.num_experts):
            weight = combined[..., i]  # [batch, seq_len]; 0 where not routed
            if weight.any():
                output = output + self.experts[i](x) * weight.unsqueeze(-1)
        return output
# Usage example
d_model = 512
d_ff = 2048
num_experts = 8
top_k = 2  # two experts activated per token
moe = MoELayer(d_model, d_ff, num_experts, top_k)
x = torch.randn(2, 10, d_model)
output = moe(x)
print(f"输入: {x.shape}")  # torch.Size([2, 10, 512])
print(f"输出: {output.shape}")  # torch.Size([2, 10, 512])
架构特点:
graph TB
Model["Mixtral 8x7B"] --> Params["总参数: 47B"]
Model --> Active["激活参数: 13B"]
Model --> Speed["推理速度 ≈ 13B模型"]
Model --> Quality["性能接近 70B模型"]
style Model fill:#fff59d
style Speed fill:#a5d6a7
style Quality fill:#81c784
性能数据:
| 挑战 | 说明 | 解决方案 |
|---|---|---|
| 负载均衡 | 某些专家被过度使用 | 添加辅助损失函数 |
| 通信开销 | 分布式训练时专家在不同GPU | 专家并行策略 |
| 泛化性 | 专家过度专业化 | 正则化技术 |
负载均衡损失:
其中 CV 是变异系数,鼓励专家使用均匀。
| 技术 | 加速比 | 显存节省 | 质量损失 | 实现难度 | 适用场景 |
|---|---|---|---|---|---|
| KV Cache | 50x+ | 0% | 0% | ⭐ | 所有自回归模型(必备) |
| MQA | 2x | 96% | 3-5% | ⭐⭐ | 极致推理速度场景 |
| GQA | 1.3x | 75% | <1% | ⭐⭐ | 推荐,平衡方案 |
| Sparse Attn | 10x+ | 50%+ | 0-5% | ⭐⭐⭐⭐ | 超长文本(100K+) |
| MoE | 5x | 70% | 0% | ⭐⭐⭐⭐⭐ | 超大模型,计算受限 |
graph TD
Start{需求是什么?} --> Q1{序列长度?}
Q1 -->|<4K| Short[标准场景]
Q1 -->|4K-32K| Medium[中长文本]
Q1 -->|>32K| Long[超长文本]
Short --> Q2{显存限制?}
Q2 -->|宽松| Use_MHA[使用标准MHA<br/>+ KV Cache]
Q2 -->|紧张| Use_GQA[使用GQA<br/>+ KV Cache]
Medium --> Q3{质量要求?}
Q3 -->|最高| MHA_Long[MHA + KV Cache]
Q3 -->|平衡| GQA_Long[GQA + Sliding Window]
Long --> Sparse[Sparse Attention<br/>必选方案]
Start --> Q4{是否超大模型?}
Q4 -->|>100B| Consider_MoE[考虑MoE架构]
style Use_GQA fill:#fff59d
style GQA_Long fill:#fff59d
style Sparse fill:#a5d6a7
style Consider_MoE fill:#81c784
OpenAI GPT系列:
Meta LLaMA系列:
Google PaLM/Gemini:
Anthropic Claude:
class OptimizedTransformerBlock(nn.Module):
    """Transformer block combining GQA attention with optional KV caching.

    Post-norm residual layout: LayerNorm(x + dropout(sublayer(x))).
    """

    def __init__(self, d_model, n_heads, n_kv_groups, d_ff, dropout=0.1):
        super().__init__()
        # Grouped-query attention sub-layer.
        self.gqa = GroupedQueryAttention(d_model, n_heads, n_kv_groups)
        # Position-wise feed-forward sub-layer.
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, cache=None, use_cache=False):
        """
        x: [batch, seq_len, d_model]
        cache: optional KV cache dict from the previous decoding step.
        use_cache: when True, also return the updated cache.
        """
        # Self-attention (with KV cache when decoding).
        # Fix: the original unconditionally unpacked two return values from
        # self.gqa(...), which fails whenever the attention module returns a
        # single tensor (the non-cached path).
        # NOTE(review): the cached branch requires GroupedQueryAttention.forward
        # to accept cache/use_cache keyword arguments — confirm the attention
        # implementation supports them.
        if use_cache or cache is not None:
            attn_out, new_cache = self.gqa(x, cache=cache, use_cache=True)
        else:
            attn_out = self.gqa(x)
            new_cache = None
        x = self.norm1(x + self.dropout(attn_out))
        # Feed-forward sub-layer.
        ffn_out = self.ffn(x)
        x = self.norm2(x + self.dropout(ffn_out))
        if use_cache:
            return x, new_cache
        return x
# LLaMA2-7B-style configuration (per this document)
d_model = 4096
n_heads = 32
n_kv_groups = 8  # GQA-8
d_ff = 11008
n_layers = 32
# Build the complete model
class OptimizedLLM(nn.Module):
    """Decoder-only language model: embedding -> N optimized blocks -> LM head."""

    def __init__(self, vocab_size, d_model, n_heads, n_kv_groups, d_ff, n_layers):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        # One KV-cache-aware transformer block per layer.
        self.layers = nn.ModuleList(
            OptimizedTransformerBlock(d_model, n_heads, n_kv_groups, d_ff)
            for _ in range(n_layers)
        )
        self.lm_head = nn.Linear(d_model, vocab_size)

    def forward(self, input_ids, caches=None, use_cache=False):
        """Return logits; with use_cache=True also return per-layer caches."""
        hidden = self.embedding(input_ids)
        new_caches = []
        for idx, block in enumerate(self.layers):
            past = caches[idx] if caches else None
            if use_cache:
                hidden, updated = block(hidden, cache=past, use_cache=True)
                new_caches.append(updated)
            else:
                hidden = block(hidden)
        logits = self.lm_head(hidden)
        return (logits, new_caches) if use_cache else logits
# Usage example
vocab_size = 32000
model = OptimizedLLM(vocab_size, d_model, n_heads, n_kv_groups, d_ff, n_layers)
print(f"模型参数量: {sum(p.numel() for p in model.parameters())/1e9:.2f}B")
# Output: parameter count 6.74B (close to LLaMA2-7B)
mindmap
root((现代LLM优化))
推理加速
KV Cache
缓存历史KV
O(n²)→O(n)
Flash Attention
IO优化
SRAM计算
显存优化
MQA
共享KV
节省96%
GQA
分组共享
节省75%
长文本
Sparse Attention
滑动窗口
全局+局部
RoPE
相对位置编码
超大模型
MoE
稀疏激活
专家路由
模型并行
专家并行
张量并行
1. 更长的上下文
2. 更高效的架构
3. 动态计算
4. 硬件协同优化
1. 计算KV Cache节省
# 给定LLaMA2-13B配置,计算生成1000个token的KV Cache大小
# n_layers=40, n_heads=40, d_k=128, seq_len=1000
2. 实现Sliding Window Mask
def create_sliding_window_mask(seq_len, window_size):
    # TODO: implement and visualise (reader exercise — sliding_window_mask
    # earlier in this document can serve as the reference implementation)
    pass
3. 对比GQA不同配置
# 实验GQA-4 vs GQA-8 vs MHA的性能和显存
论文:
代码: