红标笔趣阁
31.0MB · 2026-03-24
在前面的章节中,我们学习了大模型的核心原理:给定前面的Token序列,预测下一个Token。但是,当我们实际使用大模型进行文本生成时,会遇到一个严重的性能问题。
假设我们要让模型生成一句话:"今天天气真好"(5个Token)
第1步:输入提示词"今天"
第2步:继续生成
第3步:继续生成
注意到问题了吗?每次生成新Token时,模型都要重新计算前面所有Token的注意力!
让我们用数学来量化这个问题。
在注意力机制中,对于每个Token,我们需要计算:
假设我们要生成长度为100的文本,每个生成步骤的计算量:
| 步骤 | 序列长度 | 需要计算的Token数 | 累计计算量 |
|---|---|---|---|
| 1 | 1 | 1 | 1 |
| 2 | 2 | 2 | 1+2=3 |
| 3 | 3 | 3 | 1+2+3=6 |
| ... | ... | ... | ... |
| 100 | 100 | 100 | 1+2+...+100=5050 |
总计算量: 次Token的注意力计算
但实际上,真正需要的计算量只有100次!因为:
问题的根源:前面Token的K和V在每一步都被重新计算,但它们的值根本不会改变!
KV Cache的思想非常简单:
具体来说:
第1步:生成第1个Token
第2步:生成第2个Token
第3步:生成第3个Token
使用KV Cache后,生成100个Token的计算量:
| 步骤 | 需要新计算的KV | 从缓存读取的KV | 总计算量 |
|---|---|---|---|
| 1 | 1 | 0 | 1 |
| 2 | 1 | 1 | 2 |
| 3 | 1 | 2 | 3 |
| ... | ... | ... | ... |
| 100 | 1 | 99 | 100 |
总计算量:100次(从5050次降到100次,加速50倍!)
对于一个多头注意力层:
每一层的KV Cache形状:
全模型的KV Cache形状(所有层):
假设使用FP16精度(每个数2字节),模型参数:
单个样本的KV Cache内存:
Batch推理的内存(batch_size=32):
这就是为什么大模型推理需要大显存的原因之一!
| 模型 | 层数 | d_model | 头数 | 序列长度 | 单样本KV Cache | Batch=32 |
|---|---|---|---|---|---|---|
| GPT-2 Small | 12 | 768 | 12 | 1024 | 36 MB | 1.1 GB |
| LLaMA-7B | 32 | 4096 | 32 | 2048 | 1 GB | 32 GB |
| LLaMA-13B | 40 | 5120 | 40 | 2048 | 1.6 GB | 51 GB |
| LLaMA-65B | 80 | 8192 | 64 | 2048 | 5.1 GB | 163 GB |
| GPT-3 175B | 96 | 12288 | 96 | 2048 | 9.2 GB | 294 GB |
可以看到,对于超大模型,KV Cache可能比模型权重本身还要占用更多显存!
class MultiHeadAttentionWithKVCache:
def __init__(self, d_model, num_heads):
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads
# 权重矩阵
self.W_Q = Parameter(torch.randn(d_model, d_model))
self.W_K = Parameter(torch.randn(d_model, d_model))
self.W_V = Parameter(torch.randn(d_model, d_model))
self.W_O = Parameter(torch.randn(d_model, d_model))
# KV Cache(初始为空)
self.k_cache = [] # List of cached K tensors
self.v_cache = [] # List of cached V tensors
def forward(self, x, use_cache=True):
"""
x: 输入Token的embedding,形状 (batch_size, 1, d_model)
注意:推理时每次只输入1个新Token
"""
batch_size = x.shape[0]
# 计算新Token的Q、K、V
Q_new = x @ self.W_Q # (batch_size, 1, d_model)
K_new = x @ self.W_K # (batch_size, 1, d_model)
V_new = x @ self.W_V # (batch_size, 1, d_model)
# 重塑为多头形状
Q_new = Q_new.view(batch_size, 1, self.num_heads, self.d_k)
K_new = K_new.view(batch_size, 1, self.num_heads, self.d_k)
V_new = V_new.view(batch_size, 1, self.num_heads, self.d_k)
if use_cache:
# 将新的K、V添加到缓存
self.k_cache.append(K_new)
self.v_cache.append(V_new)
# 拼接所有历史K、V
K = torch.cat(self.k_cache, dim=1) # (batch, seq_len, heads, d_k)
V = torch.cat(self.v_cache, dim=1)
else:
K = K_new
V = V_new
# 计算注意力
# Q: (batch, 1, heads, d_k) - 只有新Token的Query
# K: (batch, seq_len, heads, d_k) - 所有Token的Key(包括历史)
scores = torch.einsum('bqhd,bkhd->bhqk', Q_new, K) / math.sqrt(self.d_k)
# scores: (batch, heads, 1, seq_len)
attn_weights = F.softmax(scores, dim=-1)
# 加权求和
output = torch.einsum('bhqk,bkhd->bqhd', attn_weights, V)
# output: (batch, 1, heads, d_k)
# 重塑并投影
output = output.reshape(batch_size, 1, self.d_model)
output = output @ self.W_O
return output
def clear_cache(self):
"""清空KV Cache,开始新的生成任务"""
self.k_cache = []
self.v_cache = []
只计算新Token的K和V:
K_new = x @ self.W_K # x的形状是(batch, 1, d_model),只有1个Token
从缓存读取历史K、V:
K = torch.cat(self.k_cache, dim=1) # 拼接所有历史Token的K
注意力计算使用完整的K、V:
# Q只有1个Token(新Token)
# K、V有n个Token(所有历史Token + 新Token)
scores = torch.einsum('bqhd,bkhd->bhqk', Q_new, K)
这里有一个非常重要的问题:当我们使用KV Cache时,位置编码怎么办?
回顾一下绝对位置编码(Sinusoidal或Learned):
其中 是第 个位置的位置编码。
问题:当使用KV Cache时,每次只输入1个新Token,但这个Token的绝对位置在不断变化!
举例:
看起来没问题?但实际上有个隐藏的问题:
缓存的K、V已经包含了位置编码信息:
所以,绝对位置编码在KV Cache场景下是兼容的,但需要注意:
在实际实现中,位置编码通常预先计算并存储在一个位置编码表中:
class PositionalEncoding:
def __init__(self, d_model, max_seq_len=5000):
# 预先计算所有位置的编码
self.pe_table = torch.zeros(max_seq_len, d_model)
position = torch.arange(0, max_seq_len).unsqueeze(1) # (max_seq_len, 1)
div_term = torch.exp(
torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model)
)
# 偶数维度用sin
self.pe_table[:, 0::2] = torch.sin(position * div_term)
# 奇数维度用cos
self.pe_table[:, 1::2] = torch.cos(position * div_term)
def get_position_encoding(self, position):
"""
position: 当前Token的位置索引(标量)
返回: 该位置的位置编码向量 (d_model,)
"""
return self.pe_table[position]
使用KV Cache时的流程:
# 第1步:生成第1个Token(位置0)
x_0 = token_embedding(token_0) + pe_table[0] # 加上位置0的编码
output_0 = attention(x_0)
# 第2步:生成第2个Token(位置1)
x_1 = token_embedding(token_1) + pe_table[1] # 加上位置1的编码
output_1 = attention(x_1) # 使用KV Cache,读取位置0的K、V
# 第3步:生成第3个Token(位置2)
x_2 = token_embedding(token_2) + pe_table[2] # 加上位置2的编码
output_2 = attention(x_2) # 使用KV Cache,读取位置0、1的K、V
RoPE(Rotary Position Embedding)是一种更现代的位置编码方式,它在计算注意力时动态地将位置信息旋转到Q和K中。
RoPE的优势:
RoPE在KV Cache中的应用:
def apply_rotary_pos_emb(q, k, position):
"""
应用旋转位置编码
q, k: (batch, seq_len, heads, d_k)
position: 当前Token的绝对位置
"""
# 计算旋转角度
theta = position / (10000 ** (torch.arange(0, d_k, 2) / d_k))
# 构造旋转矩阵
cos = torch.cos(theta)
sin = torch.sin(theta)
# 旋转Q和K
q_rot = apply_rotation(q, cos, sin)
k_rot = apply_rotation(k, cos, sin)
return q_rot, k_rot
# 使用KV Cache时
Q_new = apply_rotary_pos_emb(Q_new, position=current_position)
K_new = apply_rotary_pos_emb(K_new, position=current_position)
# 缓存已旋转的K
k_cache.append(K_new)
关键点:
| 位置编码类型 | 与KV Cache的兼容性 | 注意事项 |
|---|---|---|
| 绝对位置编码(Sinusoidal) | 兼容 | 需要预先计算位置编码表,传入正确的位置索引 |
| 绝对位置编码(Learned) | 兼容 | 同上,位置编码表是可学习参数 |
| RoPE | 完美兼容 | 缓存的K已包含位置信息,无需额外处理 |
| ALiBi | 完美兼容 | 位置偏置在计算注意力时动态添加 |
问题:标准多头注意力中,每个头都有自己的K和V,导致KV Cache很大。
解决方案:所有头共享一组K和V。
优势:
劣势:
折中方案:将头分成若干组,每组共享K和V。
例如:32个头分成4组,每组8个头共享一组K、V。
优势:
实际应用:
问题:KV Cache是连续内存块,当序列很长时,可能无法分配足够大的连续内存。
解决方案:将KV Cache分成固定大小的"页",类似操作系统的虚拟内存。
# 传统KV Cache:连续内存
k_cache = torch.zeros(batch, seq_len, heads, d_k) # 需要连续的seq_len空间
# Paged Attention:分页存储
page_size = 16 # 每页存储16个Token的K/V
num_pages = seq_len // page_size
k_cache_pages = [
torch.zeros(batch, page_size, heads, d_k) for _ in range(num_pages)
]
优势:
实际应用:
让我们通过一个完整的例子来理解KV Cache的工作流程。
初始状态:
Step 1:生成"今天"
输入:<BOS>(开始标记)
位置编码:PE[0]
计算:Q_0, K_0, V_0
KV Cache:K_0, V_0
输出:"今天"
Step 2:生成"天气"
输入:"今天"
位置编码:PE[1]
计算:Q_1, K_1, V_1(只计算新Token)
KV Cache:[K_0, K_1], [V_0, V_1](添加新的K、V)
注意力:Q_1 attend to [K_0, K_1]
输出:"天气"
Step 3:生成"真"
输入:"天气"
位置编码:PE[2]
计算:Q_2, K_2, V_2
KV Cache:[K_0, K_1, K_2], [V_0, V_1, V_2]
注意力:Q_2 attend to [K_0, K_1, K_2]
输出:"真"
Step 4:生成"好"
输入:"真"
位置编码:PE[3]
计算:Q_3, K_3, V_3
KV Cache:[K_0, K_1, K_2, K_3], [V_0, V_1, V_2, V_3]
注意力:Q_3 attend to [K_0, K_1, K_2, K_3]
输出:"好"
不使用KV Cache(每步重新计算):
使用KV Cache:
对于更长的序列(例如2048个Token),加速比接近1024倍!
当序列超过最大长度时,丢弃最早的Token:
max_cache_len = 2048
if len(k_cache) >= max_cache_len:
# 移除最早的Token
k_cache.pop(0)
v_cache.pop(0)
# 添加新Token
k_cache.append(K_new)
v_cache.append(V_new)
优势:简单,内存可控 劣势:可能丢失重要的历史信息
只保留最近的N个Token:
window_size = 512
if len(k_cache) >= window_size:
k_cache = k_cache[-window_size:]
v_cache = v_cache[-window_size:]
优势:专注于局部上下文 劣势:无法建模长距离依赖
根据注意力权重,保留重要的Token:
def prune_cache_by_attention(k_cache, v_cache, attention_weights, keep_ratio=0.5):
# 计算每个Token的平均注意力分数
importance = attention_weights.mean(dim=(0, 1)) # (seq_len,)
# 选择重要性最高的Token
num_keep = int(len(k_cache) * keep_ratio)
keep_indices = torch.topk(importance, num_keep).indices
# 只保留重要的Token
k_cache = [k_cache[i] for i in keep_indices]
v_cache = [v_cache[i] for i in keep_indices]
return k_cache, v_cache
优势:保留关键信息 劣势:计算复杂度高
最新的研究表明,大多数注意力权重集中在少数"重要"Token上:
策略:只缓存Heavy Hitters + Recent Tokens
def h2o_cache_management(k_cache, v_cache, attention_weights,
heavy_ratio=0.1, recent_ratio=0.1):
seq_len = len(k_cache)
# 计算累积注意力分数
cumulative_attention = attention_weights.sum(dim=(0, 1, 2)) # (seq_len,)
# 选择Heavy Hitters
num_heavy = int(seq_len * heavy_ratio)
heavy_indices = torch.topk(cumulative_attention, num_heavy).indices
# 选择Recent Tokens
num_recent = int(seq_len * recent_ratio)
recent_indices = torch.arange(seq_len - num_recent, seq_len)
# 合并索引
keep_indices = torch.cat([heavy_indices, recent_indices]).unique()
# 只保留选中的Token
k_cache = [k_cache[i] for i in keep_indices]
v_cache = [v_cache[i] for i in keep_indices]
return k_cache, v_cache
优势:
KV Cache是大模型推理加速的基石技术,几乎所有现代推理系统都依赖它来实现实时交互。理解KV Cache的原理,对于优化大模型部署和推理性能至关重要。