LLM KV Cache: A Simple Implementation
Introduction¶
在看很多大语言模型的推理代码时,发现有一个非常重要的概念,就是 KV Cache。 这里我们简要介绍一下 KV Cache 的核心原理并给出基于 GPT-2 的代码实现以便于本地复现。 相关的实验和测试代码同样开源在toyllm.
What¶
- 全称是 Key-Value Cache
- 这里的 Key 和 Value 是 Transformer/Attention 中的 Key 和 Value
- 一种空间换时间的优化策略,主要是为了加速大语言模型推理速度
- 作用是缓存模型在推理过程中计算出的中间结果,以便在后续的推理中复用这些结果,从而减少计算量
Why & How¶
目前 LLM 的核心架构都是 Decoder Only 结构 (只用了原始 Transformer 的 Decoder 部分),其核心结构是基于 Attention 的, 而 Attention 的核心计算是基于query
,key
和value
的矩阵运算。而 KV Cache 就是通过加速 Attention 的矩阵计算来加速整个模型的推理速度。具体来说,是通过缓存key
和value
的值来避免重复计算,从而减少计算量。那么一个核心的问题就是原始的 Attention 计算中存在哪些重复计算?下面我们将通过数学推导凸显 LLM 推理阶段原始 Attention 的重复计算问题。 在重复计算的问题被凸显出来之后,KV Cache 的实现原理也就显而易见了。 下面的推导主要参考了 Lei Mao 的博客: Transformer Autoregressive Inference Optimization.
我们先简要回顾一下原始的 Attention 计算过程:

在 LLM 推理阶段 (自回归生成,Auto-regressive Generation),在第 \(n+1\) 个 token 的生成过程中会做如下计算:
其中\(W^Q \in \mathbb{R}^{d_{\text{model}} \times d_k}\), \(W^K \in \mathbb{R}^{d_{\text{model}} \times d_k}\), \(W^V \in \mathbb{R}^{d_{\text{model}} \times d_v}\)分别是 query,key 和 value 的权重矩阵。
我们将此时的 Attention 结果记为\(Y_n\),则有:
在生成第 \(n+1\) 个 token 后,其作为新的输入 \(x_{n+1} \in \mathbb{R}^{1 \times d_{\text{model}}}\)进入模型 (自回归生成), 此时输入张量变为\(X_{n+1} \in \mathbb{R}^{(n+1) \times d_{\text{model}}}\):
此时为了计算下一个生成的 token,我们需要计算 Attention 结果 \(Y_{n+1}\):
其中,
\(Y_{n+1}\)的计算过程如下:
从上面的推导我们可以看到,\(Y_{n+1}\)可以分解为两部分:
- 历史 token 的 attention 结果\(Y_n\),这部分在之前已经计算过
- 新 token 的 attention 结果\(y_{n+1}\),这部分需要重新计算
可以看到,在拿到第\(n+1\)个 token 后,我们只需要计算这个新 token 的 attention 结果\(y_{n+1}\), 因为\(Y_n\)已经计算过了,不需要重新计算。也就是说,在不使用 KV Cache 时,每次生成一个新 token 时,我们都需要:
- 计算当前输入序列\(X_{n+1}\)完整的 attention 矩阵\(Q_{n+1}K_{n+1}^T\),计算复杂度为\(O(n^2)\)
- 对 \(n\) 个 token 重复这个过程,总计算复杂度为\(O(n^3)\)
KV Cache 避免重复计算的方式是缓存数据和变更计算流程:缓存数据就是指缓存\(K_n\)和\(V_n\),变更计算流程就是指在生成新 token 时,只需要计算新 token 的 query 与所有 key 的点积。前者避免了\(K_n\)和\(V_n\)的重复计算,后者避免了\(Y_n\)的重复计算。

所以在使用 KV Cache 后:
- 计算 attention 矩阵时,只需要计算新 token 的 query 向量与所有 key 向量的点积,再乘以 value 向量,计算复杂度为\(O(n)\)
- 对 \(n\) 个 token 重复这个过程,总计算复杂度为\(O(n^2)\)
如此一来,我们就将计算复杂度从\(O(n^3)\)降低到了\(O(n^2)\),这里的\(n\)是输入序列的长度,\(n\)越大,推理加速效果越明显 (当然显存占用也会增加)。
Code¶
代码的实现其实也很简单,我们先来如何变更计算流程来使用 KV Cache。 首先是 generate
方法的变更,这里我们只展示关键的变更部分。 可以看到,这里每次传入的 model_input_tokens
是当前的输入序列,除了首次传入的长度为 prompt_tokens.shape[1]
的序列外, 之后每次传入的序列长度均为cur_pos - prev_pos = 1
:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 |
|
之后我们再深入这里self.gpt_model
的实现,这里我们只展示关键的变更部分。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 |
|
最后我们来看 KVCache
的具体实现:
import torch
from torch import nn
class KVCache(nn.Module):
def __init__(
self,
batch_size: int,
max_seq_len: int,
num_kv_heads: int,
head_dim: int,
dtype: torch.dtype,
) -> None:
super().__init__()
cache_shape = (batch_size, num_kv_heads, max_seq_len, head_dim)
self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype), persistent=False)
self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype), persistent=False)
self.batch_size = batch_size
self.cache_pos = 0
def reset(self) -> None:
"""Reset the cache to zero."""
self.k_cache.zero_()
self.v_cache.zero_()
self.cache_pos = 0
@property
def size(self) -> int:
return self.cache_pos
def update(self, k_val: torch.Tensor, v_val: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""Update KV cache with the new ``k_val``, ``v_val`` and return the updated cache.
Args:
k_val (torch.Tensor): Current key tensor with shape [B, H, S, D]
v_val (torch.Tensor): Current value tensor with shape [B, H, S, D]
"""
bsz, _, seq_len, _ = k_val.shape
if bsz > self.k_cache.shape[0]:
raise ValueError( # noqa: TRY003
f"The current cache has been setup with a batch size of {self.k_cache.shape[0]}" # noqa: EM102
f", but found new key tensors with batch size {k_val.shape[0]}!"
)
assert (self.cache_pos + seq_len) <= self.k_cache.shape[2] # noqa: S101
k_out = self.k_cache
v_out = self.v_cache
k_out[:, :, self.cache_pos : self.cache_pos + seq_len] = k_val
v_out[:, :, self.cache_pos : self.cache_pos + seq_len] = v_val
# forward cache_pos seq_len positions along
self.cache_pos += seq_len
return k_out, v_out
如果你对这里的register_buffer
有疑惑
简单来说,register_buffer
可以使得我们在模型中注册一个持久化的buffer,这个buffer不会被视为模型的参数(自然也不会更新)。 这里存在一个问题,那就是为什么要用register_buffer
而不是直接用self.k_cache = ...
呢? 答案很简单,通过register_buffer
注册的buffer可以随着model.to(device)
的调用而自动转移到指定的设备上,而后一种方式则不会。
In essence, PyTorch buffers are tensor attributes associated with a PyTorch module or model similar to parameters, but unlike parameters, buffers are not updated during training.
Buffers in PyTorch are particularly useful when dealing with GPU computations, as they need to be transferred between devices (like from CPU to GPU) alongside the model's parameters. Unlike parameters, buffers do not require gradient computation, but they still need to be on the correct device to ensure that all computations are performed correctly.
更多的细节可以参考以下链接:
注意这里KVCache
内部cache
的维度:cache_shape = (batch_size, num_kv_heads, max_seq_len, head_dim)
. 其中,batch_size
是模型的 batch size,num_kv_heads
是模型的 kv head 的数量,max_seq_len
是模型的最大序列长度,head_dim
是每个 kv head 的维度。 也就是说,这里会缓存batch 中每个样本的每个 kv head 在每个位置上的 key 和 value(两者均为维度等于head_dim
的向量)。
KVCache
的update
方法会将当前的k_val
和v_val
更新到缓存中,并返回更新后的缓存。 更新时的操作也很简单,就是把当前的k_val
和v_val
更新到缓存中对应的位置上:
k_out[:, :, self.cache_pos : self.cache_pos + seq_len] = k_val
v_out[:, :, self.cache_pos : self.cache_pos + seq_len] = v_val
self.cache_pos += seq_len
这里的self.cache_pos
表示当前缓存的最后一个位置,seq_len
表示当前输入的序列长度。
和 torchtune 实现的些许差别
如果你之前看过torchtune的KV Cache实现,那么你会发现这里KV Cache的实现和torchtune基本是一样的——除了cache_pos
的实现。 在torchtune最早的实现中cache_pos
就是上面的这种形式,不过后续为了兼容torch.compile
将其实现为一个向量而不是一个整数。 具体参考对应issue: #2564, #1663.
The end¶
KV Cache 作为一项核心的 LLM 推理优化技术,已经在很多框架中应用,相关优化也在持续进行中。本文难以详尽介绍,故仅从其原始形态管窥一二。
最后,笔者花费了大量时间构思本文,力求简明易懂,但仍未达到理想状态。因拖延已久,决定先行发布,后续有时间再补充修改。