Nano-vLLM代码手撕笔记

28 minute read

Published:

Nano-vLLM 为vLLM(v0)的简化版实现,其支持单节点多卡(TP)offline推理,并实现了prefix caching、paged attention、cuda graph优化等核心技术,非常适合入门大模型推理框架。以引擎的启动与初始化和请求的输入、处理、输出过程为脉络,这里记录了本人对于Nano-vLLM仓库中几乎全部代码的逐行手撕笔记,涵盖了其涉及的所有知识点。

目录

安装踩坑记录

直接按照官方仓库的说明进行安装可能无法成功,主要是flash attn的编译和版本兼容性等带来的问题。安装之后直接运行demo也有可能跑不通,需要修改某些源码。经历多次踩坑后将遇到的问题以及解决方案记录于此

目录结构概览

.
├── LICENSE
├── README.md
├── assets
│   └── logo.png
├── bench.py
├── example.py
├── nanovllm
│   ├── __init__.py
│   ├── config.py
│   ├── engine
│   │   ├── block_manager.py
│   │   ├── llm_engine.py
│   │   ├── model_runner.py
│   │   ├── scheduler.py
│   │   └── sequence.py
│   ├── layers
│   │   ├── activation.py
│   │   ├── attention.py
│   │   ├── embed_head.py
│   │   ├── layernorm.py
│   │   ├── linear.py
│   │   ├── rotary_embedding.py
│   │   └── sampler.py
│   ├── llm.py
│   ├── models
│   │   └── qwen3.py
│   ├── sampling_params.py
│   └── utils
│       ├── context.py
│       └── loader.py
└── pyproject.toml

其中:

  • nanovllm/engine:引擎模块

    负责模型调度、推理执行和资源管理等,是nanovllm的核心

    • llm_engine.py

      定义了LLMEngine类,是整个推理引擎的入口,也是协调各组件工作的主类。初始化该类时,其init函数中会创建scheduler、启动各个子进程并创建对应的model_runner等。其还包含add_requeststep等用于系统全局执行的函数。

    • block_manager.py:管理KV blocks

    • model_runner.py:负责一个rank上的模型执行,在tp模式下其持有模型参数的一部分。其init函数中包括torch进程组的创建、对应部分的模型加载、kv cache资源分配等。

    • scheduler.py:实现请求调度逻辑,管理推理任务队列

    • sequence.py:定义了推理序列数据结构,用于跟踪生成过程

  • nanovllm/models:定义了支持的大模型总体架构

  • nanovllm/layers:定义了各层的具体实现

  • nanovllm/utils:辅助功能支持,loader.py用于加载模型,context.py为上下文管理工具

实验代码:

from nanovllm import LLM, SamplingParams
import torch
import os

model = "Qwen3-0.6B"
model_path = os.path.join(os.path.expanduser("~"), "huggingface", model)
llm = LLM(model_path, enforce_eager=False, tensor_parallel_size=1, dtype=torch.float16)
sampling_params = SamplingParams(temperature=0.1, max_tokens=512)
prompts = [
    "In the introduction, explain ", 
    "您好!"]
outputs = llm.generate(prompts, sampling_params)
print(outputs[0]["text"])
print("="*20)
print(outputs[1]["text"])

推理引擎初始化与停止

LLMEngine

初始化LLMEngine对象(LLM直接继承了LLMEngine):

class LLMEngine:

    def __init__(self, model, **kwargs):
        config_fields = {field.name for field in fields(Config)}
        config_kwargs = {k: v for k, v in kwargs.items() if k in config_fields}
        config = Config(model, **config_kwargs)
        self.ps = []
        self.events = []
        ctx = mp.get_context("spawn")
        for i in range(1, config.tensor_parallel_size):
            event = ctx.Event()
            process = ctx.Process(target=ModelRunner, args=(config, i, event))
            process.start()
            self.ps.append(process)
            self.events.append(event)
        self.model_runner = ModelRunner(config, 0, self.events)
        self.tokenizer = AutoTokenizer.from_pretrained(config.model, use_fast=True)
        config.eos = self.tokenizer.eos_token_id
        self.scheduler = Scheduler(config)
        atexit.register(self.exit)

其中:

  • tp_size=1,则只有一个主进程(rank=0),在主进程上创建一个ModelRunner对象用来承载整个模型

  • tp_size>1,则除了主进程(rank=0)以外,还会创建tp_size-1个子进程,主进程和这些子进程分别创建自己的ModelRunner对象并承载一部分模型参数。另外,还会创建一系列的Event,用于进程之间的通知操作,例如等待模型加载完成等。

    tp_size=4为例,则会额外创建3个子进程,主进程(rank=0)既负责作为主进程全局调度,同时也有一个负责1/4的模型参数的ModelRunner,其余3个子进程分别具有一个负责1/4参数的ModelRunner

主进程还会创建tokenizer、scheduler等组件。相当于主进程既负责tokenizer、scheduler等,还会负责一部分模型的ModelRunner。

在vllm v1中为了避免cpu端的tokenizer阻塞了gpu端的模型执行,因此进行了解耦优化,避免主进程既负责cpu操作又负责部分gpu操作:主进程只负责tokenizer、scheduler等全局组件,所有ModelRunner都会在子进程上创建,如tp_size=1时总共会有一个主进程和一个子进程,其中子进程的ModelRunner承载全部的模型参数。

主进程还注册了一个atexit函数,表示进程结束时会执行自定义的self.exit函数:

atexit.register(self.exit)

其使得engine停止时给自己的model runner发出exit信号,然后删除model runner,并终止所有子进程。详见后文引擎停止部分。

Scheduler与Block Manager的初始化

Scheduler定义在engine/scheduler.py中:

class Scheduler:

    def __init__(self, config: Config):
        self.max_num_seqs = config.max_num_seqs
        self.max_num_batched_tokens = config.max_num_batched_tokens
        self.eos = config.eos
        self.block_manager = BlockManager(config.num_kvcache_blocks, config.kvcache_block_size)
        self.waiting: deque[Sequence] = deque()
        self.running: deque[Sequence] = deque()

其初始化时定义了一系列用于运行时调度的数据结构:

  • self.max_num_seqs:单次调度时允许的最大并行序列数(最大bsz)

  • self.max_num_batched_tokens:一个batch中的token总数上限

  • self.waiting:一个双端队列,用于存放等待被调度的序列

  • self.running:一个双端序列,用于存放已分配资源并正在运行的序列

多个请求会在这两个队列之间轮转。调度逻辑默认采用prefill优先策略,即当收到新的prefill请求时,可打断当前正在进行的decode执行,被中断的请求将被移入 waiting队列。抢占逻辑为:当请求执行资源不足时,按照入队顺序将running队列中后进入的请求转移至 waiting队列。

scheduler中还定义了self.block_manager用于管理kv blocks,BlockManager定义在engine/block_manager.py中:

class BlockManager:

    def __init__(self, num_blocks: int, block_size: int):
        self.block_size = block_size
        self.blocks: list[Block] = [Block(i) for i in range(num_blocks)]
        self.hash_to_block_id: dict[int, int] = dict()
        self.free_block_ids: deque[int] = deque(range(num_blocks))
        self.used_block_ids: set[int] = set()

block manager用于管理一组block的分配、释放、缓存逻辑。其中:

  • num_blocks:当前device能容纳的block总数(config.num_kvcache_blocks),是在model runner初始化过程中通过假输入测出来的

  • self.block_size:每个block的大小(容纳的token数),是一个预定义值,在Qwen3-0.6B中取256

  • self.blocks:一个由Block对象构成的列表,列表长度就是block总数num_blocks。它保存的是“逻辑kv blocks”。

  • self.hash_to_block_id:一个字典,为哈希值到block id到映射,用于快速查找

  • self.free_block_ids:一个由空闲的block id构成的双端队列,可以看成“未激活”的block。初始时为当前所有blocks(range(num_blocks)

  • self.used_block_ids:一个由正在被使用的block的id组成的集合,也即当前至少有1个尚未生成完毕的序列在使用它,可以看成“active”的block。初始时所有block都没被使用,因此为空。

再具体到每个Block对象:

class Block:

    def __init__(self, block_id):
        self.block_id = block_id
        self.ref_count = 0
        self.hash = -1
        self.token_ids = []

    def update(self, hash: int, token_ids: list[int]):
        self.hash = hash
        self.token_ids = token_ids

    def reset(self):
        self.ref_count = 1
        self.hash = -1
        self.token_ids = []

每个block具有:

  • self.block_id:该block的id,是它的唯一标识

  • self.ref_count:该block被引用的计数,也即有多少个序列正在使用这个block(包括首次产生这个block的序列以及后续命中该block缓存从而也在利用它的序列)

  • self.hash:block内容的哈希值,用于判断是否缓存命中。只有满的block才有哈希值(并参与缓存记录与匹配),未满的block的哈希值设为-1,不参与缓存记录与匹配

  • self.token_ids:该block存储的token id列表

其含有的方法包括:

  • update:更新该block的哈希值和保存的tokens的id(当该block被满填充时才会触发,因为未满的block不参与计算哈希值,其hashtoken_ids都为空,即使其已经含有了一部分tokens)

  • reset:将其“激活”(可以是第一次被利用时激活,也可以是曾经被用过但用它的序列全都已经生成完毕了,现在由于其kv缓存被命中所以重新激活)

Block对象其实就是”逻辑kv block“,其只保存tokens的id信息,并不保存具体的kv数据。真正的kv数据保存在物理块中(Attention.kv_cache)。

ModelRunner

注:model runner是每个rank上都有的,因此接下来的所有内容都指的是一个rank上的

各个rank上的ModelRunner的初始化如下:

class ModelRunner:

    def __init__(self, config: Config, rank: int, event: Event \| list[Event]):
        self.config = config
        hf_config = config.hf_config
        self.block_size = config.kvcache_block_size
        self.enforce_eager = config.enforce_eager
        self.world_size = config.tensor_parallel_size
        self.rank = rank
        self.event = event

        dist.init_process_group("nccl", "tcp://localhost:2333", world_size=self.world_size, rank=rank)
        torch.cuda.set_device(rank)
        default_dtype = torch.get_default_dtype()
        torch.set_default_dtype(hf_config.torch_dtype)
        torch.set_default_device("cuda")
        self.model = Qwen3ForCausalLM(hf_config)
        load_model(self.model, config.model)
        self.sampler = Sampler()
        self.warmup_model()
        self.allocate_kv_cache()
        if not self.enforce_eager:
            self.capture_cudagraph()
        torch.set_default_device("cpu")
        torch.set_default_dtype(default_dtype)

        if self.world_size > 1:
            if rank == 0:
                self.shm = SharedMemory(name="nanovllm", create=True, size=2**20)
                dist.barrier()
            else:
                dist.barrier()
                self.shm = SharedMemory(name="nanovllm")
                self.loop()

可见,每个rank上的ModelRunner初始化时包含了初始化torch默认进程组、设置当前进程的默认cuda device等,并加载该进程对应模型的权重片段。还会进行模型的warmup、捕捉cuda graph、分配kv cache空间等。如果tp_size>1则意味着该进程需要和其他进程进行通信,创建一块大小为1MB的SharedMemory来进行进程间的信号传递(例如终止进程等)。

分片模型加载与Embedding/lm_head/Linear层

模型加载属于当前rank的ModelRunner初始化过程中的一部分,先定义模型,然后调用load weight函数来加载该rank对应的模型参数:

self.model = Qwen3ForCausalLM(hf_config)
load_model(self.model, config.model)

其中load_model函数其实就是读取了模型的safetensors权重文件,逐层拆出各层参数,然后使用这些层自己的weight_loader来按照各自的切分方式加载一部分参数(具体的load方法定义在各层中,详见下文):

from safetensors import safe_open

def load_model(model: nn.Module, path: str):
    packed_modules_mapping = getattr(model, "packed_modules_mapping", {})
    for file in glob(os.path.join(path, "*.safetensors")):
        with safe_open(file, "pt", "cpu") as f:
            for weight_name in f.keys():
                for k in packed_modules_mapping:
                    if k in weight_name:
                        v, shard_id = packed_modules_mapping[k]
                        param_name = weight_name.replace(k, v)
                        param = model.get_parameter(param_name)
                        
                        weight_loader = getattr(param, "weight_loader")
                        weight_loader(param, f.get_tensor(weight_name), shard_id)
                        break
                else:
                    param = model.get_parameter(weight_name)
                    weight_loader = getattr(param, "weight_loader", default_weight_loader)
                    weight_loader(param, f.get_tensor(weight_name))

Embedding层的切分:

每个rank上的模型的embedding层和lm_head层使用了VocabParallelEmbedding和其子类ParallelLMHead,定义在layers/embed_head.py中:

  • VocabParallelEmbedding

    class VocabParallelEmbedding(nn.Module):
      
        def __init__(
            self,
            num_embeddings: int,
            embedding_dim: int,
        ):
            super().__init__()
            self.tp_rank = dist.get_rank()
            self.tp_size = dist.get_world_size()
            assert num_embeddings % self.tp_size == 0
            self.num_embeddings = num_embeddings
            self.num_embeddings_per_partition = self.num_embeddings // self.tp_size
            self.vocab_start_idx = self.num_embeddings_per_partition * self.tp_rank
            self.vocab_end_idx = self.vocab_start_idx + self.num_embeddings_per_partition
            self.weight = nn.Parameter(torch.empty(self.num_embeddings_per_partition, embedding_dim))
            self.weight.weight_loader = self.weight_loader
      
        def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
            param_data = param.data
            shard_size = param_data.size(0)
            start_idx = self.tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
            param_data.copy_(loaded_weight)
      
        def forward(self, x: torch.Tensor):
            if self.tp_size > 1:
                mask = (x >= self.vocab_start_idx) & (x < self.vocab_end_idx)
                x = mask * (x - self.vocab_start_idx)
            y = F.embedding(x, self.weight)
            if self.tp_size > 1:
                y = mask.unsqueeze(1) * y
                dist.all_reduce(y)
            return y
    

    词表并行是对embedding权重矩阵(embedding查找表)按行切分,也即每个device上放一部分tokens的embedding。其中:

    • self.num_embeddings:总词表大小

    • self.num_embeddings_per_partition=self.num_embeddings // self.tp_size:每个rank上平均承担的embedding行数

    • self.vocab_start_idx = self.num_embeddings_per_partition * self.tp_rank
      self.vocab_end_idx = self.vocab_start_idx + self.num_embeddings_per_partition
      

      本rank负责的词表的起始位置和结束位置

    • self.weight = nn.Parameter(torch.empty(self.num_embeddings_per_partition, embedding_dim)):预分配本rank的embedding层的权重张量

    加载权重时,将读取的总embedding权重数据中的指定部分通过narrow函数切出来,然后copy到embedding层的权重张量中。

    推理时,对于给到该rank上的完整输入序列x,首先筛选出序列中属于本rank负责的部分tokens(将其他token mask掉),并使用本rank上的这部分embedding权重映射得到这些tokens的embedding,而其他token的位置则暂时以0占位。当所有rank上的embedding都映射完毕后,最后再通过all reduce将各个rank上的结果合并,最终使得所有rank上都得到了输入序列x的完整embedding。

  • ParallelLMHead

    class ParallelLMHead(VocabParallelEmbedding):
      
        def __init__(
            self,
            num_embeddings: int,
            embedding_dim: int,
            bias: bool = False,
        ):
            assert not bias
            super().__init__(num_embeddings, embedding_dim)
      
        def forward(self, x: torch.Tensor):
            context = get_context()
            if context.is_prefill:
                last_indices = context.cu_seqlens_q[1:] - 1
                x = x[last_indices].contiguous()
            logits = F.linear(x, self.weight)
            if self.tp_size > 1:
                all_logits = [torch.empty_like(logits) for _ in range(self.tp_size)] if self.tp_rank == 0 else None
                dist.gather(logits, all_logits, 0)
                logits = torch.cat(all_logits, -1) if self.tp_rank == 0 else None
            return logits
    

    lm head把模型最后一层输出的hidden states向量通过一个线性层映射成维数为词表大小V的logits向量,各个元素即为体现各个token被生成概率的logit。

    其权重形状和embedding层一样,也是(num_embeddings, embedding_dim),本质上可以看成embedding的逆过程(甚至可以和embedding层共享权重,详见下文模型部分),因此权重初始化和其父类VocabParallelEmbedding保持一致即可。

    在forward时:

    • 若处于prefill阶段,则输入x是由batch中各个子序列连成的一个大序列,形状为(seq_len1+seq_len2+...+seq_lenn, hidden_dim)cu_seqlens_q中保存的是每个子序列在x中的起始位置,因此cu_seqlens_q[1:] - 1就是每个子序列的last token的位置,该位置的hidden state通过lm head映射得到的就是该子序列的下一个token预测logit。

      因此,如下两行的逻辑就是得到各个子序列的last token位置,并将它们的hidden state向量分别提取出来,然后再拼凑成一个连续的形状为(bsz, hidden_dim)的新张量x,其通过lm head后就可得到每个子序列的next token的logit:

      last_indices = context.cu_seqlens_q[1:] - 1
      x = x[last_indices].contiguous()
      
    • 若处于decode阶段,则输入的x形状为(bsz, hidden_dim),也即由batch中每个子序列的last token连成的序列,它们经过lm head后得到的就是各个子序列的next token的logit,因此无需额外处理。

    然后即可将形状为(bsz, hidden_dim)x送入本rank持有的lm head线性层中,该层权重形状为(num_embeddings_per_partition, embedding_dim),执行$xW^T$后即可得到本rank上负责的这些embeddings对应的各个tokens(总共num_embeddings_per_partition个)在这些子序列的next token生成概率logits,形状为(bsz, num_embeddings_per_partition)

    tp_size>1的情况下,rank0住进程负责创建一个buffer all_logits用于汇总各个rank上的部分logits。all_logits是一个长度为tp_size的列表,每个列表元素是一个形状为(bsz, num_embeddings_per_partition)的空张量,用于承载来自各个rank的部分logits张量。然后,使用gather操作将所有rank上的部分logits都汇总到rank0的这个all_logits中,最后再将它们拼接成完整的logits后由rank0返回。

    可见,最后只有rank0会返回真正的logits,其他rank返回的都是None。因为之后这些logits会被送到rank0上的tokenizer中进行decode,而tokenizer只存在于rank0中,因此其他rank将logits送给rank0后使命就结束了。

Linear层切分

Attention和FFN这两个模块中涉及多种linear层的切分策略,定义在layers/linear.py中,linear层的基类为:

class LinearBase(nn.Module):

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = False,
        tp_dim: int \| None = None,
    ):
        super().__init__()
        self.tp_dim = tp_dim
        self.tp_rank = dist.get_rank()
        self.tp_size = dist.get_world_size()
        self.weight = nn.Parameter(torch.empty(output_size, input_size))
        self.weight.weight_loader = self.weight_loader
        if bias:
            self.bias = nn.Parameter(torch.empty(output_size))
            self.bias.weight_loader = self.weight_loader
        else:
            self.register_parameter("bias", None)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError

可见,其考虑了pytorch中对于线性层权重转置的问题:self.weight的形状为(output_size, input_size),forward调用F.linear(x, weight)计算时执行$xW^T$,相当于(seq_len, input_size)*(input_size, output_size)=(seq_len, output_size)。因此本项目中所有自定义linear层中的self.weight都是真实计算时的专置版本,若计算时想要按列切分则应对self.weight按行切分,反之亦然。下文中“计算时权重”指self.weight.T

其有如下几种子类,定义了不同切分方式:

  • ReplicatedLinear

    不进行权重切分,将同一份完整权重复制到所有rank上,相当于dp的做法。因此当前rank的模型加载的就是输入的完整参数的一个副本(在nano-vllm tp中没有用到):

    class ReplicatedLinear(LinearBase):
      
        def __init__(
            self,
            input_size: int,
            output_size: int,
            bias: bool = False,
        ):
            super().__init__(input_size, output_size, bias)
      
        def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
            param.data.copy_(loaded_weight)
      
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return F.linear(x, self.weight, self.bias)
    
  • ColumnParallelLinear

    class ColumnParallelLinear(LinearBase):
      
        def __init__(
            self,
            input_size: int,
            output_size: int,
            bias: bool = False,
        ):
            tp_size = dist.get_world_size()
            super().__init__(input_size, divide(output_size, tp_size), bias, 0)
      
        def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
            param_data = param.data
            shard_size = param_data.size(self.tp_dim)
            start_idx = self.tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(self.tp_dim, start_idx, shard_size)
            param_data.copy_(loaded_weight)
      
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return F.linear(x, self.weight, self.bias)
    

    其用于承载按列切分的线性层,用于按列(output dimension)切分的tp。

    在原始transformer ffn中可用于$W_{up}$,在attention中可用于$W_Q,W_K,W_V$。而由于这里的Qwen3采用略有变化的SwiGLU作为ffn,因此采用其子类MergedColumnParallelLinear,而QKV也采用的是其另一个子类QKVParallelLinear

    其初始化计算时权重形状为(input_size, output_size//tp_size)的线性层(实际self.weight形状为(output_size//tp_size, input_size)),也即其output size(计算时权重的列数,self.weight的行数)是完整权重的1/tp_size。例如A=[A1\|A2\|A3\|A4]中的A1(计算时权重)

    加载权重时:

    • param_data是本层定义的Parameter(self.weight),形状为(output_size//tp_size, input_size)。初始时其为空张量,接下来会将真实权重值加载进来

    • self.tp_dim=0表示按self.weight的行进行切分(相当于计算时权重按列切分),因此shard_size=output_size//tp_size,也即每个分片的行数。

      start_idx = self.tp_rank * shard_size表示相对于完整权重来说,本分片对应的行数起点,接下来本分片应该获得完整权重从该行开始的shard_size

    • ` loaded_weight = loaded_weight.narrow(self.tp_dim, start_idx, shard_size)loaded_weight为从模型权重文件中读取的本层完整权重(形状(output_size, input_size))。narrow操作表示沿dim=0(行)维度,从完整权重中切出从第start_idx行开始的shard_size行,这也就是本分片被分配的权重部分(A[start_idx: start_idx+shard_size, :]`)

      最终将这个切片复制到本层参数中,这就完成了本层的加载:param_data.copy_(loaded_weight)

  • MergedColumnParallelLinear

    class MergedColumnParallelLinear(ColumnParallelLinear):
      
        def __init__(
            self,
            input_size: int,
            output_sizes: list[int],
            bias: bool = False,
        ):
            self.output_sizes = output_sizes
            super().__init__(input_size, sum(output_sizes), bias)
      
        def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: int):
            param_data = param.data
            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size
            shard_size = self.output_sizes[loaded_shard_id] // self.tp_size
            param_data = param_data.narrow(self.tp_dim, shard_offset, shard_size)
            loaded_weight = loaded_weight.chunk(self.tp_size, self.tp_dim)[self.tp_rank]
            param_data.copy_(loaded_weight)
    

    是普通ColumnParallelLinear的一个子类,可以将多个ColumnParallelLinear合并为一个矩阵乘,从而将多个矩阵乘看成一个矩阵乘来计算,减小多个kernel启动的开销。其用于SwiGLU中的$W_{up},W_{gate}$。

    在SwiGLU中计算$x\cdot W_{up}$和$x\cdot W_{gate}$是相互独立的,因此是可以并行的,因此将两个权重矩阵按列连接成一个矩阵$[W_{up}|W_{gate}]$,然后执行一次矩阵乘$x\cdot[W_{up}|W_{gate}]=[xW_{up}|xW_{gate}]$,将得到的结果$[xW_{up}|xW_{gate}]$按列切成两份即可分别得到$xW_{up}$和$xW_{gate}$的结果了。

    其初始化的output_sizes是一个列表,里边保存了要merge的各个矩阵乘的output_size(也即计算时权重的列数),如output_sizes=[3072,3072],则它们按列合并成一个大权重矩阵后的output_size=sum(output_sizes)=3072+3072=6144。此时相当于将问题转化为和权重$[W_{gate}|W_{up}]$做ColumnParallelLinear矩阵乘,接下来即可模仿ColumnParallelLinear的分片逻辑来处理tp。

    这里和普通ColumnParallelLinear的权重加载逻辑主要区别在于,多了一个loaded_shard_id标识符,其取值为0,1,分别表示当前读的是$W_{gate}$权重还是$W_{up}$权重,从而确定其在self.weight中的正确放置位置。

    tp_size>1时,融合后的权重为部分$W_{up1}$+部分$W_{gate}$,如一个rank具有$[xW_{up1}|xW_{gate1}]$,另一个rank具有$[xW_{up2}|xW_{gate2}]$,而不是一个rank具有$xW_{up}$而另一个rank具有$xW_{gate}$。这样方便后续将两部分分离,从而进行后续计算。

  • QKVParallelLinear

    class QKVParallelLinear(ColumnParallelLinear):
      
        def __init__(
            self,
            hidden_size: int,
            head_size: int,
            total_num_heads: int,
            total_num_kv_heads: int \| None = None,
            bias: bool = False,
        ):
            tp_size = dist.get_world_size()
            total_num_kv_heads = total_num_kv_heads or total_num_heads
            self.head_size = head_size
            self.num_heads = divide(total_num_heads, tp_size)
            self.num_kv_heads = divide(total_num_kv_heads, tp_size)
            output_size = (total_num_heads + 2 * total_num_kv_heads) * self.head_size
            super().__init__(hidden_size, output_size, bias)
      
        def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: str):
            param_data = param.data
            assert loaded_shard_id in ["q", "k", "v"]
            if loaded_shard_id == "q":
                shard_size = self.num_heads * self.head_size
                shard_offset = 0
            elif loaded_shard_id == "k":
                shard_size = self.num_kv_heads * self.head_size
                shard_offset = self.num_heads * self.head_size
            else:
                shard_size = self.num_kv_heads * self.head_size
                shard_offset = self.num_heads * self.head_size + self.num_kv_heads * self.head_size
            param_data = param_data.narrow(self.tp_dim, shard_offset, shard_size)
            loaded_weight = loaded_weight.chunk(self.tp_size, self.tp_dim)[self.tp_rank]
            param_data.copy_(loaded_weight)
    

    它同样是ColumnParallelLinear的一个子类,可以将$xW_Q,xW_K,xW_V$融合成一个矩阵乘$x[W_Q|W_K|W_V]$

    其处理逻辑和MergedColumnParallelLinear类似,只不过前者处理的是两个矩阵$W_{up},W_{gate}$的融合,这里处理的是3个矩阵$W_Q,W_K,W_V$的融合。标识符loaded_shard_id'q','k','v'用于标识当前读的是QKV哪部分的权重。

    tp_size>1的情况下,每个rank获得的都是按列切的部分$W_Q,W_K,W_V$横向排列成的大矩阵,例如$[W_{Q1}|W_{K1}|W_{V1}]$。

    由于multi head也是输入和按列切分的权重分别计算得到的,如$xW_Q=[xW_{Qh1}|\cdots|xW_{Qhn}]$,因此其天然适合按列做tp切分。例如rank1上的$[W_{Q1}|W_{K1}|W_{V1}]$包含了head1~head4所需的QKV权重,rank2上的$[W_{Q2}|W_{K2}|W_{V2}]$包含了head5~head8所需的QKV权重,接下来各个rank上可以进一步独立地进行multi head attention计算(不同head互不影响),然后再分别独立地和$W_o$的各个行切片做矩阵乘后,最后all reduce即可获得最终输出结果。

  • RowParallelLinear

    class RowParallelLinear(LinearBase):
      
        def __init__(
            self,
            input_size: int,
            output_size: int,
            bias: bool = False,
        ):
            tp_size = dist.get_world_size()
            super().__init__(divide(input_size, tp_size), output_size, bias, 1)
      
        def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
            param_data = param.data
            shard_size = param_data.size(self.tp_dim)
            start_idx = self.tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(self.tp_dim, start_idx, shard_size)
            param_data.copy_(loaded_weight)
      
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            y = F.linear(x, self.weight, self.bias if self.tp_rank == 0 else None)
            if self.tp_size > 1:
                dist.all_reduce(y)
            return y
    

    其用于承载按行切分的线性层,用于按行(input dimension)切分的tp,通常位于ColumnParallelLinear层的后边。其用于ffn中的$W_{down}$和attention中的$W_o$。

    其初始化计算时权重形状为(input_size//tp_size, output_size)的线性层(实际self.weight形状为(output_size, input_size//tp_size)),例如$B=\begin{bmatrix}B_1\B_2\end{bmatrix}$中的$B_1$(计算时权重)

    加载权重时,self.tp_dim=1表示按self.weight的列进行切分(相当于计算时权重按行切分)

    在forward中,其输入的x一般是上一个ColumnParallelLinear的输出,也即一个分片的计算结果(如$xA_1$),它可以独立地和$B_1$做矩阵乘,得到和最终完整结果形状一样的部分和结果。最后它需要和其他rank的计算结果(如$xA_2\cdot B_2$)做all reduce得到最终完整结果。这样就保证了每个rank上该层的输出都是完整的结果。

上述分片策略在模型中的应用

每个rank上的模型最外层被包装为一个Qwen3ForCausalLM对象,定义在models/qwen3.py中:

class Qwen3ForCausalLM(nn.Module):
    packed_modules_mapping = {
        "q_proj": ("qkv_proj", "q"),
        "k_proj": ("qkv_proj", "k"),
        "v_proj": ("qkv_proj", "v"),
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
    }

    def __init__(
        self,
        config: Qwen3Config
    ) -> None:
        super().__init__()
        self.model = Qwen3Model(config)
        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
        if config.tie_word_embeddings:
            self.lm_head.weight.data = self.model.embed_tokens.weight.data

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
    ) -> torch.Tensor:
        return self.model(input_ids, positions)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        return self.lm_head(hidden_states)

其中定义了模型主干self.model和输出头self.lm_head。输出头lm_head使用ParallelLMHead,用于把最后一层输出的hidden states(设形状为$h_t\in\mathbb R^{1\times d}$)通过一个$W_{out}\in\mathbb R^{V\times d}$的lm head linear层映射,得到的$logits\in\mathbb R^{1\times V}$的向量就是词表中每个token的logit概率:

\[logits=h_tW_{out}^T\]

如果指定了config.tie_word_embeddings=True的话,则会使得embedding层(也即embedding查找表,形状也是$\mathbb R^{V\times d}$)和lm head层共享一套权重值,从而使的二者的总共显存占用减半。这样做相当于让每个token的logit值等于hidden state和它的embedding的内积,也即hidden state和该token embedding越接近那么该token被生成的概率就越大:$logits_i=h_t\cdot E_i$。这种做法不仅省显存,还有利于embedding的语义一致性,使得输入token空间和输出token空间一致。

每个rank上的self.model对象为模型主干:

class Qwen3Model(nn.Module):

    def __init__(
        self,
        config: Qwen3Config,
    ) -> None:
        super().__init__()
        self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([Qwen3DecoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
    ) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)
        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(positions, hidden_states, residual)
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

其中定义了embedding层和各个decoder layer。embedding层使用的是VocalParallelEmbedding

每个rank上初始化的Qwen3模型中的一个decoder Layer定义如下:

dclass Qwen3DecoderLayer(nn.Module):

    def __init__(
        self,
        config: Qwen3Config,
    ) -> None:
        super().__init__()
        self.self_attn = Qwen3Attention(
            hidden_size=config.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            max_position=config.max_position_embeddings,
            rms_norm_eps=config.rms_norm_eps,
            qkv_bias=getattr(config, 'attention_bias', True),
            head_dim=getattr(config, 'head_dim', None),
            rope_theta=getattr(config, "rope_theta", 1000000),
            rope_scaling=getattr(config, "rope_scaling", None),
        )
        self.mlp = Qwen3MLP(
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
        )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor \| None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        if residual is None:
            hidden_states, residual = self.input_layernorm(hidden_states), hidden_states
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        hidden_states = self.self_attn(positions, hidden_states)
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual

其中主要包含了Attention和FFN两部分:

  • Qwen3Attention

    class Qwen3Attention(nn.Module):
      
        def __init__(
            self,
            hidden_size: int,
            num_heads: int,
            num_kv_heads: int,
            max_position: int = 4096 * 32,
            head_dim: int \| None = None,
            rms_norm_eps: float = 1e-06,
            qkv_bias: bool = False,
            rope_theta: float = 10000,
            rope_scaling: tuple \| None = None,
        ) -> None:
            super().__init__()
      				
            # 一些config参数的设置
            # ...
      
            self.qkv_proj = QKVParallelLinear(
                hidden_size,
                self.head_dim,
                self.total_num_heads,
                self.total_num_kv_heads,
                bias=qkv_bias,
            )
            self.o_proj = RowParallelLinear(
                self.total_num_heads * self.head_dim,
                hidden_size,
                bias=False,
            )
            self.rotary_emb = get_rope(
                self.head_dim,
                rotary_dim=self.head_dim,
                max_position=max_position,
                base=rope_theta,
                rope_scaling=rope_scaling,
            )
            self.attn = Attention(
                self.num_heads,
                self.head_dim,
                self.scaling,
                self.num_kv_heads,
            )
            if not self.qkv_bias:
                self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
                self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
      
        def forward(
            self,
            positions: torch.Tensor,
            hidden_states: torch.Tensor,
        ) -> torch.Tensor:
            qkv = self.qkv_proj(hidden_states)
            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
            q = q.view(-1, self.num_heads, self.head_dim)
            k = k.view(-1, self.num_kv_heads, self.head_dim)
            v = v.view(-1, self.num_kv_heads, self.head_dim)
            if not self.qkv_bias:
                q = self.q_norm(q)
                k = self.k_norm(k)
            q, k = self.rotary_emb(positions, q, k)
            o = self.attn(q, k, v)
            output = self.o_proj(o.flatten(1, -1))
            return output
    

    tp_size>1时,这里的$W_Q,W_K,W_V$都是按列切分的,$W_o$按行切分。每个rank上分别具有一部分heads的$W_Q,W_K,W_V$,它们独立地做完attention后,再分别独立地和$W_o$的各个行切片做矩阵乘,得到部分和结果后通过all reduce使得每个rank上都得到一份最终结果(all reduce操作在self.o_projRowParallelLinear)中进行)。

  • Qwen3MLP

    class Qwen3MLP(nn.Module):
      
        def __init__(
            self,
            hidden_size: int,
            intermediate_size: int,
            hidden_act: str,
        ) -> None:
            super().__init__()
            self.gate_up_proj = MergedColumnParallelLinear(
                hidden_size,
                [intermediate_size] * 2,
                bias=False,
            )
            self.down_proj = RowParallelLinear(
                intermediate_size,
                hidden_size,
                bias=False,
            )
            assert hidden_act == "silu"
            self.act_fn = SiluAndMul()
      
        def forward(self, x):
            gate_up = self.gate_up_proj(x)
            x = self.act_fn(gate_up)
            x = self.down_proj(x)
            return x
    

    MLP部分的数学表达式为LLaMA SwiGLU:

    \[\begin{align} \text{FFN}(x)&=\text{down\_proj}(\text{up\_proj}(x)\otimes \text{Activate}(\text{gate\_proj}(x)))\]
\[&=((x\cdot W_{up})\odot\text{SiLU}(x\cdot W_{gate}))W_{down} \end{align}\]

其中:

  • gate_up=self.gate_up_proj(x):输入完整的hidden state $x$,融合执行$x\cdot W_{up}$和$x\cdot W_{gate}$,返回结果为二者横向的拼接$[xW_{up}|xW_{gate}]$

    这里的self.gate_up_proj是一个MergedColumnParallelLinear对象,在tp_size>1的情况下返回的是本rank负责的部分切片结果,例如rank1为$[xW_{up1}|xW_{gate1}]$,rank2为$[xW_{up2}|xW_{gate2}]$。

  • x = self.act_fn(gate_up):执行$(x\cdot W_{up})\odot\text{SiLU}(x\cdot W_{gate})$

    self.act_fn=SiluAndMul()的forward定义为:

    @torch.compile  
    def forward(self, x: torch.Tensor) -> torch.Tensor:
          x, y = x.chunk(2, -1)
          return F.silu(x) * y
    

    其输入的gate_up为$x\cdot W_{up}$和$x\cdot W_{gate}$的结果横向并列连接成的张量\([xW_{up}\|xW_{gate}]\),x.chunk(2, -1)就是沿hidden dim维度(按列)将其切分为两块,切分之后的x,y分别就是$x\cdot W_{up}$和$x\cdot W_{gate}$,因此return的结果是$(x\cdot W_{up})\odot\text{SiLU}(x\cdot W_{gate})$。

    tp_size>1,则此时的结果仍然是按列切分的部分切片结果,例如rank1的结果为$(x\cdot W_{up1})\odot\text{SiLU}(x\cdot W_{gate1})$。

  • x = self.down_proj(x):执行$((x\cdot W_{up})\odot\text{SiLU}(x\cdot W_{gate}))W_{down}$,返回FFN的最终输出

    这里的self.down_projRowParallelLinear,当tp_size>1时其接收的输入是按列切分的部分切片结果,如rank1的输入为$(x\cdot W_{up1})\odot\text{SiLU}(x\cdot W_{gate1})$,其权重为$W_{down}$按行切分的部分结果,如rank1为$W_{down1}$,因此若直接给两部分做矩阵乘得到的是一个部分和结果。但由于RowParallelLinear的forward中最后进行了all reduce,因此最终返回的是完整的结果。

综上可见,在tp模式下,各个rank上的数据流都是完整数据。所有rank都会收到完整的hidden state来输入attention模块或FFN模块,而这两个模块内部只含有该rank对应的一部分权重,因此内部执行的是hidden state和部分权重的计算,计算得到的也是部分结果。但因为这两个模块都是由ColumnParallelLinear+RowParallelLinear构成,其中RowParallelLinear最终会执行all reduce操作,因此各个rank上的attention和ffn的输出均为完整hidden state,从而给下一层提供完整的hidden state。

warmup与预分配kv cache空间

加载完模型后,进行一次模拟forward,触发模型运行时所有可能的开销后,通过实验得到模型的显存占用,然后再根据剩余显存来为KV blocks预分配空间(也即物理kv blocks):

        self.warmup_model()
        self.allocate_kv_cache()

这是因为模型初始化、cuda kernel、pytorch runtime等都会在第一次运行时额外申请显存,剩下的显存可以安全地都留给kv cache。先用self.warmup_model()统计一下模型运行时还能剩多少显存,self.allocate_kv_cache()则根据统计结果来创建一个用于保存kv blocks的巨大空张量(相当于预分配一个空间上限)。

  • self.warmup_model()

        def warmup_model(self):
            torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats()
            max_num_batched_tokens, max_model_len = self.config.max_num_batched_tokens, self.config.max_model_len
            num_seqs = min(max_num_batched_tokens // max_model_len, self.config.max_num_seqs)
            seqs = [Sequence([0] * max_model_len) for _ in range(num_seqs)]
            self.run(seqs, True)
            torch.cuda.empty_cache()
    

    释放PyTorch占用的未分配显存并重置峰值显存占用纪录后,构造一个batch大小和序列长度均达到最大的假输入张量,然后输入模型做一次prefill来记录负载拉满的情况下最高的总显存占用是多少。

    其中,max_num_batched_tokens表示模型一次forward中处理的token总数上限(也即bsz*seq_len的上限),max_model_len为单个序列最大长度。因此构建的假输入中相当于把每个序列都拉到了长度上限,然后在这个batch中塞入了上限数量的序列,从而模拟模型能够见到的最大负载。

  • self.allocate_kv_cache()

        def allocate_kv_cache(self):
            config = self.config
            hf_config = config.hf_config
            free, total = torch.cuda.mem_get_info()
            used = total - free
            peak = torch.cuda.memory_stats()["allocated_bytes.all.peak"]
            current = torch.cuda.memory_stats()["allocated_bytes.all.current"]
            num_kv_heads = hf_config.num_key_value_heads // self.world_size
            head_dim = getattr(hf_config, "head_dim", hf_config.hidden_size // hf_config.num_attention_heads)
            block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * head_dim * hf_config.torch_dtype.itemsize
            config.num_kvcache_blocks = int(total * config.gpu_memory_utilization - used - peak + current) // block_bytes
            assert config.num_kvcache_blocks > 0
            self.kv_cache = torch.empty(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, num_kv_heads, head_dim)
            layer_id = 0
            for module in self.model.modules():
                if hasattr(module, "k_cache") and hasattr(module, "v_cache"):
                    module.k_cache = self.kv_cache[0, layer_id]
                    module.v_cache = self.kv_cache[1, layer_id]
                    layer_id += 1
    
    • 经过warmup后,统计了模型运行时的峰值显存占用,获得显存信息如下:

              free, total = torch.cuda.mem_get_info()
              used = total - free
              peak = torch.cuda.memory_stats()["allocated_bytes.all.peak"]
              current = torch.cuda.memory_stats()["allocated_bytes.all.current"]
      
      • free:当前device上空闲的显存量

      • total:当前device的总显存

      • usedtotal - free即为当前device上pytorch持有的显存总量(模型权重+cuda runtime+pytorch runtime等静态上下文占用的总显存量)

      • peak:warmup期间在负载拉满的情况下”allocated_bytes”占用的峰值显存,也即“由PyTorch内存分配器管理的、已实际分配给Tensors使用的内存峰值”,相当于模型权重+最大activation占用的内存量。其不包括pytorch context等静态上下文占用。

      • current:当前存活的PyTorch tensor/计算图占用的显存

      由此可以算出能够留给KV cache的显存量为:

      total * config.gpu_memory_utilization - peak - (used - current)
      

      其中total * config.gpu_memory_utilization为当前device上允许vllm使用的总空间(例如0.9表示当前device有90%的显存允许被vllm拿去用来存储kv cache和运行模型等),其减去模型权重+activation的最大值(peak),然后再减去其他的静态上下文占用(used-current),剩下的就是可以放心地预分配给kv cache的显存容量了。显存量以byte计。

    • 接下来,根据预定义的kv block大小计算出每个kv block占据的显存量,然后根据上文算出的kv cache最大容量来得出kv block的总数量,最后再为self.kv_cache预分配一个空张量来占满最大容量空间:

              block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * head_dim * hf_config.torch_dtype.itemsize
              config.num_kvcache_blocks = int(total * config.gpu_memory_utilization - used - peak + current) // block_bytes
              assert config.num_kvcache_blocks > 0
              self.kv_cache = torch.empty(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, num_kv_heads, head_dim)
      

      其中:

      • block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * head_dim * hf_config.torch_dtype.itemsize:每个kv block占据的显存量(以byte计)。

        • 2 * self.block_size * num_kv_heads * head_dim为一个kv block保存的元素数量。其中的self.block_size为预定义的block容量(Qwen3-0.6B默认取256),表示一个block负责多少个token的kv;num_kv_heads= hf_config.num_key_value_heads // self.world_size表示该device负责的head数量,num_kv_heads*head_dim表示1个token的k或v的(该device负责的那几个head的)元素数量;2*表示每个token都具有k和v,因此占用要乘以2。

        • torch_dtype.itemsize表示一个元素占据的byte数,当数据类型为fp16或bf16时一个元素为16bit(2byte),因此该值取2。

        • num_hidden_layes为模型总层数

        综上可见,一个kv block负责保存若干个token在所有层的kv cache,block_bytes表示一个block占据的显存byte数。

      • config.num_kvcache_blocks = int(total * config.gpu_memory_utilization - used - peak + current) // block_bytes:block总数。

        也即用上边算出来的能预留给kv cache的最大显存量除以每个block的显存量,可得当前device最多能容纳多少个blocks。

      • 然后,为kv cache预分配一个超大的tensor空间,用来保存未来所有可能出现的kv cache。

        self.kv_cache = torch.empty(
          2, 
          hf_config.num_hidden_layers, 
          config.num_kvcache_blocks, 
          self.block_size, 	
          num_kv_heads, 
          head_dim
        )
        

        正常来讲,一个长度为seq_len的序列(包含所有层)的kv的形状应为:

        KV: [2, num_hidden_layers, seq_len, num_kv_heads, head_dim]
        

        在paged attention机制下,一个序列中的tokens被分组到若干个blocks负责,相当于把seq_len切成了多个block:seq_len = num_kvcache_blocks * block_size。由此也就得到了self.kv_cache的张量形状,其体现了block table的格式。

      • 最后,将上述self.kv_cache分发给模型各layer的attention模块的k_cachev_cache,方便它们直接访问到总的self.kv_cache的对应位置:

        layer_id = 0
        for module in self.model.modules():
            if hasattr(module, "k_cache") and hasattr(module, "v_cache"):
                module.k_cache = self.kv_cache[0, layer_id]
                module.v_cache = self.kv_cache[1, layer_id]
                layer_id += 1
        

      这里得到的self.kv_cache其实就是“物理kv blocks”,其用于实际存储各个block中的token的kv数据

cuda graph捕捉与context设置

接下来进行cuda graph捕捉,其通过模拟多种batch size取值下的decode过程来捕捉这些情况的输入下的cuda graph,从而在真正推理decode时遇到这些bsz时可以直接利用对应的cuda graph进行replay,从而将多个kernel launch融合成一次launch。

self.enforce_eager=True表示不捕捉cuda graph,推理时动态建图,否则执行建图。

        if not self.enforce_eager:
            self.capture_cudagraph()

cuda graph要求计算图完全静态,也即图中各个算子的输入张量形状在捕捉图和后边执行图(replay)时必须保持一致。

由于prefill的序列长度非常不固定,导致该阶段输入张量形状千变万化,不符合cuda graph的要求,因此prefill阶段不捕捉图(并且prefill阶段是计算瓶颈的,主要延时会花在attention计算上,kernel launch的开销占比相对较小,即使建图也没什么明显收益)。

decode阶段序列长度恒为1,因此其形状相对稳定。当然推理时可能取多种batch size,因此这里在捕捉图时预先尝试多种最常见的bsz(例如1,2,4,8,16,32,48,64,...,512),并且将它们对应的cuda graph全都记录下来。推理时如果恰好碰上了这些bsz中的一个那就自然可以直接replay它对应的图,就算没有恰好遇到也可以就近padding到就近的bsz然后调用对应的图(例如bsz=13时通过补3个空序列可以变成bsz=16)

事实上,在普通的attention实现下,decode阶段通常也不能入图,因为attention算子中虽然q的形状固定为[bsz, num_heads, 1, head_dim],但k,v的形状[bsz, num_heads, seq_len, head_dim]中的seq_len是动态变化的。而vllm paged attention由于预分配了一个完全静态的大self.kv_cache张量,不管真实的历史序列长度是什么它的形状永远是[2, num_hidden_layers, num_blocks, block_size, num_kv_heads, head_dim],无非是可能有些blocks的元素是未填充的空值,但张量本身的完全静态的。由此可见,paged attention除了减少内存碎片化以外,还通过预分配静态的kv cache张量来支持了decoding阶段的入图。

本质来看,graph capture时会记录:1)kernel类型(如GEMM、Layernorm、softmax),2)kernel参数(如gridDim, blockDim, shared memory),3)kernel的输入输出地址。

其中,kernel参数的记录决定了图中所有参与计算的张量形状都不能变化,否则底层的blockDim等可能被改变,和入图时不符。kernel的输入输出地址不能变是因为graph记录的是各个张量的gpu指针的地址,而不是tensor对象,因此在入图和后续replay时同一个张量的地址必须保持一致。假设入图时Q.data_ptr()=0x1000,graoh中记录kernel(Q=0x1000),若replay时Q.data_ptr()=0x2000,但kernel依然访问0x1000地址来读取Q的数据,这就导致无法读到正确的数据。

因此,如果想将一个执行流入图,好的做法是预分配所有输入输出的tensor对象,如:

```python

static_input = torch.empty(shape, device=”cuda”)

```

然后执行时对于不同批次的数据只是将这些数据copy到这个静态tensor中,来填充新数据,而不是new一个新的tensor对象:

```python

static_input.copy_(new_input)

```

这样就可以保证static_input.shapestatic_input.data_ptr()一直不变,无论具体数据是什么。vllm预分配self.kv_cache就体现了这一思想。

具体的入图代码如下:

    @torch.inference_mode()
    def capture_cudagraph(self):
        config = self.config
        hf_config = config.hf_config
        max_bs = min(self.config.max_num_seqs, 512)
        max_num_blocks = (config.max_model_len + self.block_size - 1) // self.block_size
        input_ids = torch.zeros(max_bs, dtype=torch.int64)
        positions = torch.zeros(max_bs, dtype=torch.int64)
        slot_mapping = torch.zeros(max_bs, dtype=torch.int32)
        context_lens = torch.zeros(max_bs, dtype=torch.int32)
        block_tables = torch.zeros(max_bs, max_num_blocks, dtype=torch.int32)
        outputs = torch.zeros(max_bs, hf_config.hidden_size)
        self.graph_bs = [1, 2, 4, 8] + list(range(16, max_bs + 1, 16))
        self.graphs = {}
        self.graph_pool = None

        for bs in reversed(self.graph_bs):
            graph = torch.cuda.CUDAGraph()
            set_context(False, slot_mapping=slot_mapping[:bs], context_lens=context_lens[:bs], block_tables=block_tables[:bs])
            outputs[:bs] = self.model(input_ids[:bs], positions[:bs])    # warmup
            with torch.cuda.graph(graph, self.graph_pool):
                outputs[:bs] = self.model(input_ids[:bs], positions[:bs])    # capture
            if self.graph_pool is None:
                self.graph_pool = graph.pool()
            self.graphs[bs] = graph
            torch.cuda.synchronize()
            reset_context()

        self.graph_vars = dict(
            input_ids=input_ids,
            positions=positions,
            slot_mapping=slot_mapping,
            context_lens=context_lens,
            block_tables=block_tables,
            outputs=outputs,
        )
  • 首先确定本系统能接收的batch size最大值,这里取max_bs=512,然后构建decoding阶段能见到的最大输入和输出,以及预分配的其他相关张量:

    input_ids = torch.zeros(max_bs, dtype=torch.int64)
    positions = torch.zeros(max_bs, dtype=torch.int64)
    slot_mapping = torch.zeros(max_bs, dtype=torch.int32)
    context_lens = torch.zeros(max_bs, dtype=torch.int32)
    block_tables = torch.zeros(max_bs, max_num_blocks, dtype=torch.int32)
    outputs = torch.zeros(max_bs, hf_config.hidden_size)
    

    这里input_ids=[t0,t1,...,t511], input_ids.shape=[512],这种写法会被视为512个seq_len=1的序列组成的batch(本质上相当于size=[512, 1],简化为[512]),而不是一个长度为512的序列(应写成size=[1, 512]),因此这些tokens会被看作属于不同序列,它们之间不会产生attention。相应地,positions=[0,0,...,0]表示每个token在其序列中的位置都是0。其他的block_tables等也都拉到了decoding阶段能看到的最大张量。

  • 进一步,列出所有要尝试的bsz:

    self.graph_bs = [1, 2, 4, 8] + list(range(16, max_bs + 1, 16))
    

    对于较小的bsz设置的区间比较精细:1,2,4,8,对于16以上的bsz则以16为间隔,直到最大bsz:16,32,48,64,...,512

  • self.graphs字典来保存各个bsz下的graph,self.graph_pool用于多个cuda graph的内存复用,否则每个graph都会单独申请显存

    self.graphs = {}
    self.graph_pool = None
    
  • 接下来开始从大到小依次捕捉所有bsz下的cuda graph:

    for bs in reversed(self.graph_bs):
        graph = torch.cuda.CUDAGraph()
        set_context(False, slot_mapping=slot_mapping[:bs], context_lens=context_lens[:bs], block_tables=block_tables[:bs])
        outputs[:bs] = self.model(input_ids[:bs], positions[:bs])    # warmup
        with torch.cuda.graph(graph, self.graph_pool):
            outputs[:bs] = self.model(input_ids[:bs], positions[:bs])    # capture
        if self.graph_pool is None:	# largest bsz
            self.graph_pool = graph.pool()
        self.graphs[bs] = graph
        torch.cuda.synchronize()
        reset_context()
    

    之所以bsz从大到小,是因为第一轮尝试最大bsz时可以申请到graph所需的最大内存,将其保存为graph pool后,后边尝试的所有bsz所需内存都不会超过它,因此后边都可以安心复用了。

    在每轮cuda graph捕捉时,其会处于一组set_context()reset_context()之间。context本质是一个全局的运行时状态,用于记录一些tensor和标志位等,用于给attention kernel和kv cache操作等提供运行时的隐式参数,同时保证捕捉graph时这些tensor的地址是稳定的。context定义在utils/context.py中:

    @dataclass
    class Context:
        is_prefill: bool = False
        cu_seqlens_q: torch.Tensor \| None = None
        cu_seqlens_k: torch.Tensor \| None = None
        max_seqlen_q: int = 0
        max_seqlen_k: int = 0
        slot_mapping: torch.Tensor \| None = None
        context_lens: torch.Tensor \| None = None
        block_tables: torch.Tensor \| None = None
      
    _CONTEXT = Context()
      
    def get_context():
        return _CONTEXT
      
    def set_context(is_prefill, cu_seqlens_q=None, cu_seqlens_k=None, max_seqlen_q=0, max_seqlen_k=0, slot_mapping=None, context_lens=None, block_tables=None):
        global _CONTEXT
        _CONTEXT = Context(is_prefill, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, context_lens, block_tables)
      
    def reset_context():
        global _CONTEXT
        _CONTEXT = Context()
    

    context中的主要信息包括:

    • is_prefill:当前是在运行prefill推理还是decode推理

    • cu_seqlens_q/k:处理batch生成的变长序列时所用的索引数组。为了提高FA等算子的gpu利用率,并避免padding带来的计算浪费,会将一个batch中的所有子序列pack成一个大的长序列。例如,一个batch中有3个prompt,其中len(seq1)=3, len(seq2)=5, len(seq3)=2,将它们连起来可得一个长度为10的大序列:Q=concat(seq1,seq2,seq3)

      cu_seqlens是一个prefix sum数组,其用来标识大序列中的每个子序列的边界,从而可以给attention算子提供各个子序列的位置信息,确保只在每个子序列内部做attention。在上例中:cu_seqlens_q=[0,3,8,10],也即seq1=Q[0:3], seq2=Q[3:8], seq3=Q[8:10]

    • slot_mapping:在paged attention中用于把token在当前batch中的位置映射到物理self.kv_cache中的实际存储位置(slot)。

      在每个物理kv block会负责若干个token的kv,每个token占据一个slot。物理块中的每个slot id都是全局唯一的,如果知道一个token的slot id的话即可将其映射到对应的物理块slot,从而获得其kv。

      slot_mapping用于建立物理块和逻辑块之间的映射,用于记录序列中每个token对应的slot id。一个(packed)序列中的各个token的kv在逻辑块中是连续的,但在物理块中并不一定连续。例如下图中,序列里token 0,1,2,3对应的物理slot为48,49,50,51,记录在slot_mapping中就是[48,49,50,51]。利用slot_mapping即可正确访问到每个token在物理block中的kv。

    每个step的推理进行时都会处于一个context下(每个step都会根据当前情况重新设置一套context),也即具有针对本次推理的一个全局_CONTEXT对象,从而让各个算子知道本step是prefill/decode、batch的pack情况等信息。

    设好context后,进一步以当前bsz的输入input_ids[:bs]进行两次模型forward,第一次是warmup确保所有kernel都正确启动了,第二次是在捕捉graph的环境下运行,从而捕捉到cuda graph并存入self.graphs字典中:

    outputs[:bs] = self.model(input_ids[:bs], positions[:bs])    # warmup
    with torch.cuda.graph(graph, self.graph_pool):
        outputs[:bs] = self.model(input_ids[:bs], positions[:bs])    # capture
    if self.graph_pool is None:	# largest bsz
        self.graph_pool = graph.pool()
    self.graphs[bs] = graph
    
  • 所有bsz的cuda graph都捕捉完毕并存到self.graphs中之后,最后再将预分配的input_ids,output_ids等输入输出张量进行记录保存:

    self.graph_vars = dict(
        input_ids=input_ids,
        positions=positions,
        slot_mapping=slot_mapping,
        context_lens=context_lens,
        block_tables=block_tables,
        outputs=outputs,
    )
    

    这样未来在推理时,每次的新输入都是把数据copy到这个input_ids张量里,输出数据也是被放到outputs张量里,而不是重新创建新的tensors。这就确保了graph捕捉时和真正推理时所用的这些tensor都是同一个,只不过包含的具体数据不同,确保地址不变。

共享内存IPC与引擎停止

tp_size>1的情况下,每个rank上的model runner初始化的最后,都会连接到一块共享内存,用于进程间的信号传递(如主进程的"run","exit"命令)以及每次调度后输入数据的分发

        if self.world_size > 1:
            if rank == 0:
                self.shm = SharedMemory(name="nanovllm", create=True, size=2**20)
                dist.barrier()
            else:
                dist.barrier()
                self.shm = SharedMemory(name="nanovllm")
                self.loop()

其中rank0(主进程)的model runner负责创建这个SharedMemory对象,大小为1MB。其他子进程的model runner等创建好后连接到这块共享内存,然后进入self.loop(),通过self.read_shm()不断从共享内存中试图读取来自主进程的信号(和输入数据),从而实时监听来自主进程的命令:

    def loop(self):
        while True:
            method_name, args = self.read_shm()
            self.call(method_name, *args)
            if method_name == "exit":
                break

当主进程接收到停止信息后(例如用户手动Ctrl+C),LLM engine会调用其自定义的self.exit()函数,给自己的(rank0)的model runner发出exit信号,然后删除model runner,并终止所有子进程:

    def exit(self):
        self.model_runner.call("exit")
        del self.model_runner
        for p in self.ps:
            p.join()

tp_size>1,则主进程的model runner还会负责通过self.write_shm来将exit信号写入所有进程共享的shared memory,来通知其他子进程的model runner现在程序要结束了:

    def call(self, method_name, *args):
        if self.world_size > 1 and self.rank == 0:
            self.write_shm(method_name, *args)
        method = getattr(self, method_name, None)
        return method(*args)

各个子进程的model runner从始至终都在self.loop()中不断读取信号。其读到exit信号后会执行self.exit()函数:断开其和shared memory的连接(rank0会进一步负责删除掉shared memory)、删除cuda graph、删除进程组:

    def exit(self):
        if self.world_size > 1:
            self.shm.close()
            dist.barrier()
            if self.rank == 0:
                self.shm.unlink()
        if not self.enforce_eager:
            del self.graphs, self.graph_pool
        torch.cuda.synchronize()
        dist.destroy_process_group()

从而完成所有资源的关停。

推理流程

参考:https://zhuanlan.zhihu.com/p/2010638958783131701

alt text

llm.generate()定义在LLMEngine中,用于对一批输入prompts进行offline推理生成:

    def generate(
        self,
        prompts: list[str] \| list[list[int]],
        sampling_params: SamplingParams \| list[SamplingParams],
        use_tqdm: bool = True,
    ) -> list[str]:
        # tqdm配置和sampling params设置
      	# ...
        for prompt, sp in zip(prompts, sampling_params):
            self.add_request(prompt, sp)
        outputs = {}
        prefill_throughput = decode_throughput = 0.
        while not self.is_finished():
          output, num_tokens = self.step()
           # tqdm记录
           # ...
          for seq_id, token_ids in output:
            outputs[seq_id] = token_ids
            
        outputs = [outputs[seq_id] for seq_id in sorted(outputs.keys())]
        outputs = [{"text": self.tokenizer.decode(token_ids), "token_ids": token_ids} for token_ids in outputs]
        if use_tqdm:
            pbar.close()
        return outputs

请求添加与Sequence构建

首先,依次将batch里的各个prompt送入self.add_request()函数:

def add_request(self, prompt: str \| list[int], sampling_params: SamplingParams):
    if isinstance(prompt, str):
        prompt = self.tokenizer.encode(prompt)
    seq = Sequence(prompt, sampling_params)
    self.scheduler.add(seq)

在这个函数中,首先将prompt进行tokenize,然后将其构造为一个Sequence对象(定义在engine/sequence.py中),这个Sequence对象中包含了这个序列的所有信息,包括其当前长度、调度状态(等待中/生成中/已结束等):

class Sequence:
    block_size = 256
    counter = count()

    def __init__(self, token_ids: list[int], sampling_params = SamplingParams()):
        self.seq_id = next(Sequence.counter)
        self.status = SequenceStatus.WAITING
        self.token_ids = copy(token_ids)
        self.last_token = token_ids[-1]
        self.num_tokens = len(self.token_ids)
        self.num_prompt_tokens = len(token_ids)
        self.num_cached_tokens = 0	# 该序列中前多少tokens成功命中缓存了,后续prefill时可以跳过它们
        self.block_table = []	# 该序列由哪些blocks负责
        self.temperature = sampling_params.temperature
        self.max_tokens = sampling_params.max_tokens
        self.ignore_eos = sampling_params.ignore_eos

然后将其添加到scheduler的waiting队列中:

    def add(self, seq: Sequence):
        self.waiting.append(seq)

请求调度与执行

将batch中的各个prompt都添加到scheduler的self.waiting队列中后,进入while not self.is_finished()循环,每轮循环都执行一步self.step()来执行一次调度,每次调度会根据当前的请求队列来拼凑出一个batch,然后送入模型进行一次forward并decode出下一个token,直到缓存队列(包括running和waiting序列)中的所有序列都生成完成:

def step(self):
    seqs, is_prefill = self.scheduler.schedule()
    token_ids = self.model_runner.call("run", seqs, is_prefill)
    self.scheduler.postprocess(seqs, token_ids)
    outputs = [(seq.seq_id, seq.completion_token_ids) for seq in seqs if seq.is_finished]
    num_tokens = sum(len(seq) for seq in seqs) if is_prefill else -len(seqs)
    return outputs, num_tokens

schedule操作概览

self.step()中首先会触发scheduler.schedule()函数,其负责调度prefill和decode请求,基于当前请求队列情况、资源情况、优先级等因素拼凑出一个batch(seq),并返回,从而进一步让model runner处理这个batch。

调度逻辑默认采用vllm v0的prefill优先策略,即当收到新的prefill请求时,可打断当前正在进行的decode执行,被中断的请求将被移入 waiting队列。抢占逻辑为:当请求执行资源不足时,按照入队顺序将running队列中后进入的请求转移至 waiting队列。

def schedule(self) -> tuple[list[Sequence], bool]:
    # prefill
    scheduled_seqs = []
    num_seqs = 0
    num_batched_tokens = 0
    while self.waiting and num_seqs < self.max_num_seqs:
        seq = self.waiting[0]
        if num_batched_tokens + len(seq) > self.max_num_batched_tokens or not self.block_manager.can_allocate(seq):
            break
        num_seqs += 1
        self.block_manager.allocate(seq)
        num_batched_tokens += len(seq) - seq.num_cached_tokens
        seq.status = SequenceStatus.RUNNING
        self.waiting.popleft()
        self.running.append(seq)
        scheduled_seqs.append(seq)
    if scheduled_seqs:
        return scheduled_seqs, True

    # decode
    while self.running and num_seqs < self.max_num_seqs:
        seq = self.running.popleft()
        while not self.block_manager.can_append(seq):
            if self.running:
                self.preempt(self.running.pop())
            else:
                self.preempt(seq)
                break
        else:
            num_seqs += 1
            self.block_manager.may_append(seq)
            scheduled_seqs.append(seq)
    assert scheduled_seqs
    self.running.extendleft(reversed(scheduled_seqs))
    return scheduled_seqs, False

其中,scheduled_seqs为本步调度后要进一步送到执行单元进行执行的序列列表(它们会被拼成一个batch送入模型),num_seqs为本步调度的序列数量,num_batched_tokens为本步调度的所有序列需要计算的token总数(除去缓存命中无序重复计算的那些prefix tokens)。

schedule – prefill&block manager的分配&prefix caching

waiting队列的主体是由新加入的prefill请求构成的,但也可能包含一些本来处于running队列中但由于资源不足而被逐出并放回waiting队列的请求,这些请求也可以看作prefill请求。prefill的优先级高于decode,在每一步调度时若发现有prefill请求且成功调度了其中至少一个,那么就直接完成构建本次调度的batch并返回,不再继续看decode的情况。

可见,若在某些序列decoding过程中新来了prefill请求,那么decode就会被打断,scheduler优先调度prefill请求,直到waiting队列被清空后才会调度decode请求。

self.waiting队列不为空时,尝试依次调度waiting队列中的每个请求,直到清空所有waiting状态的序列:

    # prefill
  	while self.waiting and num_seqs < self.max_num_seqs:
        seq = self.waiting[0]
        if num_batched_tokens + len(seq) > self.max_num_batched_tokens or not self.block_manager.can_allocate(seq):
            break
        num_seqs += 1
        self.block_manager.allocate(seq)
        num_batched_tokens += len(seq) - seq.num_cached_tokens
        seq.status = SequenceStatus.RUNNING
        self.waiting.popleft()
        self.running.append(seq)
        scheduled_seqs.append(seq)
    if scheduled_seqs:
        return scheduled_seqs, True

每次取waiting队列最左端的序列,然后看一下当前剩余的block资源是否支持处理它:

def can_allocate(self, seq: Sequence) -> bool:
    return len(self.free_block_ids) >= seq.num_blocks

若加上它后会导致当前构造的batch的总token量超过规定上限,或当前资源不够为这个序列分配kv blocks(体现为该序列要占用的kv blocks数量大于当前block manager剩余的free blocks数量)则直接退出对于waiting队列的处理循环,等待下一次调度时再尝试。

若可以处理这个序列,则将该序列添加到scheduled_seqs列表中,作为下一个送去执行的batch的一部分,同时也相应地更新num_seqsnum_batched_tokens等统计值,并将其从waiting队列弹出并加到running队列末尾。


allocate逻辑:

若该序列可以被处理,则进一步使用self.block_manager.allocate(seq)为其分配kv blocks(包括prefix cache匹配和记录逻辑),定义在engine/block_manager.py中:

def allocate(self, seq: Sequence):
    assert not seq.block_table
    h = -1
    cache_miss = False
    for i in range(seq.num_blocks):
        token_ids = seq.block(i)
        h = self.compute_hash(token_ids, h) if len(token_ids) == self.block_size else -1
        block_id = self.hash_to_block_id.get(h, -1)
        if block_id == -1 or self.blocks[block_id].token_ids != token_ids:
            cache_miss = True
        if cache_miss:
            block_id = self.free_block_ids[0]
            block = self._allocate_block(block_id)
        else:
            seq.num_cached_tokens += self.block_size
            if block_id in self.used_block_ids:
                block = self.blocks[block_id]
                block.ref_count += 1
            else:
                block = self._allocate_block(block_id)
        if h != -1:
            block.update(h, token_ids)
            self.hash_to_block_id[h] = block_id
        seq.block_table.append(block_id)
  • assert not seq.block_table:确保这个序列还没有被分配过block,体现为它的block_table为空

  • h:指前一个block的哈希值,用于链式哈希。初始化为-1表示该序列的第1个block前没有block了。

  • cache_miss:用于标记该序列是否可以命中已有的kv缓存,初始化为False表示当前还没有发现cache miss的情况。在接下来逐个生成block时,只要发现有一个block没匹配到可以复用的cache,那么就将其置为True,也即只要出现一次miss后边的所有blocks就都不可能复用已有cache了。

    prefix caching在vllm中是block层面的匹配,也即给一个新序列的一个子段分配一个block时,可以搜索一下之前是否存在一个状态完全一样的block,如果有的话即可拿它的kv cache来复用到这里,从而无需为这个block重新算一遍kv cache了。其要求从第一个token开始到当前block的所有tokens(也即prefix)必须都和cache里的完全一样,一旦有一个地方断了那么后边的序列无论如何都不可能和cache序列的prefix完全一致了,后边的blocks也就不可能使用已有block的kv cache了。

    例如,请求1的序列是A,B,C,D,E,F,请求2是A,B,C,Z,E,F,设每个字母代表一个block容纳的tokens,则对于请求2来说,A,B,C这三个block都完全可以复用请求1的相应block的kv cache,但是从Z开始就不匹配了,后边的E虽然本身和请求1中的E一样,但前缀分别是A,B,C,D,EA,B,C,Z,E,因此还是不能复用。

然后使用seq.num_blocks算得该序列prefill后所需的kv blocks数量:

@property
def num_blocks(self):
    return (self.num_tokens + self.block_size - 1) // self.block_size

然后依次分配每一个所需的block,每轮循环中分配一个block,用来承载该序列中的一部分tokens。分配其第i个block的过程:

  • token_ids = seq.block(i):得到当前序列分配给其第i个block的token子序列的token ids:

    def block(self, i):
        assert 0 <= i < self.num_blocks
        return self.token_ids[i*self.block_size: (i+1)*self.block_size]
    
  • h = self.compute_hash(token_ids, h) if len(token_ids) == self.block_size else -1:基于上一个block的哈希值和该block负责的token子序列,给该block计算一个64位的哈希值,用作该block状态的唯一标识。

    这里的判断条件意思是只有满block才参与缓存(参与匹配之前的cache/作为cache等待后续匹配),如果block未满的话则不参与,哈希值置为h=-1

    @classmethod
    def compute_hash(cls, token_ids: list[int], prefix: int = -1):
        h = xxhash.xxh64()
        if prefix != -1:
            h.update(prefix.to_bytes(8, "little"))
        h.update(np.array(token_ids).tobytes())
        return h.intdigest()
    

    该部分结合prefix(上一个block的哈希值,记录了当前token子序列的前缀序列信息)和当前的token子序列计算出了当前block的唯一哈希值,从而确保只有前缀+当前子序列和当前情况完全一致时才能碰上一样的哈希值(也即一个block的状态由其本身包含的token子序列以及前缀序列共同决定)。最终返回的是一串数字来表示哈希值。

    其可以用于prefix caching优化,一方面在后面可以用于检测目前已有的kv blocks是否有能和这个子序列匹配的,如果有的话可以直接拿来复用其kv cache,避免重复计算;另一方面也作为当前block的缓存记录,未来如果有新的block和它完全一样时可以服用它的kv cache。

  • block_id = self.hash_to_block_id.get(h, -1)

    从hash-block查找表中查找当前block的哈希值h是否能匹配到某个已有的block,如发现的话就返回那个被命中的block的全局id作为block_id,否则返回-1

  • if block_id == -1 or self.blocks[block_id].token_ids != token_ids:
        cache_miss = True
    

    如果没找到能够完全匹配前缀和当前子序列的block,或虽然匹配上了但实际上两个block的内容不同(由哈希碰撞导致,例如hash(ABCD)==hash(XYZW),因此保险起见还是要检查下两个block的具体token内容是不是一样),则认定在这个block处没有成功匹配,因此设置cache_miss=False,表示该block以及该序列中后边的所有blocks都不可能找到匹配的缓存了,只能老老实实地分配新的block空间了。

  • if cache_miss:
        block_id = self.free_block_ids[0]
        block = self._allocate_block(block_id)
          
    else:
        seq.num_cached_tokens += self.block_size
        if block_id in self.used_block_ids:
            block = self.blocks[block_id]
            block.ref_count += 1
        else:
            block = self._allocate_block(block_id)
    

    这部分分别描述cache miss或cache hit情况下对于block对象的特别处理逻辑

    • 如果cache miss了,说明当前这个block不可能找到匹配的cache了,因此从当前的空闲block队列左端拿一个空闲block对象的id。

      确保该block目前确实未被任何序列使用后,将其信息重置,然后将这个block对象从空闲队列移到已使用队列,并返回这个block对象(后续用于承载当前的token子序列),可以看成将这个block对象“激活”:

      def _allocate_block(self, block_id: int) -> Block:
          block = self.blocks[block_id]
          assert block.ref_count == 0
          block.reset()
          self.free_block_ids.remove(block_id)
          self.used_block_ids.add(block_id)
          return self.blocks[block_id]
      
    • 如果cache命中了:

      • seq.num_cached_tokens中增加这个token子段的token数量,这个值用来记录当前序列的前多少tokens成功命中了cache,从而在后续执行prefill计算时可以跳过这些tokens,无需重新计算它们的kv cache。

      • 然后看一下命中的这个block对象是否处于self.used_block_ids中,如果是的话则说明这个block正在被使用,只需直接取这个block对象并将其引用次数+1,否则就使用self._allocate_block()将其重新激活。

        该block“不再被使用”指的是之前产生/使用过它的序列都已经生成完毕,因此它被移出self.used_block_ids并放回self.free_block_ids中,表示它的状态由active转为inactive。但它对应的token子序列和的kv cache数据并未被删除,前者仍记录在block对象中,后者仍然存在于self.kv_cache中,还可以作为缓存被命中使用。

  • if h != -1:
        block.update(h, token_ids)
        self.hash_to_block_id[h] = block_id
    seq.block_table.append(block_id)
    

    如果当前token子序列的哈希值有效(能填满一个block),那么就用其哈希值和token ids来更新为其分配的block对象的信息,然后将这个哈希值-block对记录入self.hash_to_block_id查找表,用于被未来的blocks进行缓存匹配。

    最后,将这个block对象的id加入当前序列的block_table记录中。


由此完成了对于当前序列的所有blocks的分配。回到scheduler.schedule()中,继续对于该序列进行处理:

num_batched_tokens += len(seq) - seq.num_cached_tokens
seq.status = SequenceStatus.RUNNING
self.waiting.popleft()
self.running.append(seq)
scheduled_seqs.append(seq)

将该序列剔除掉缓存命中的tokens后,剩余的需要计算的token数量累加到该batch需要计算的token总数num_batched_tokens中,并将该序列对象的状态设为RUNNING。然后将其从waiting队列中移出并放入running队列,表示该序列正在进行生成。最后将该序列添加到本次调度的batch中。至此对于该序列的处理就结束了。

当所有waiting队列中的请求都被清空后,看一下是否成功调度了至少一个prefill请求,如果是的话则直接完成本步调度的batch构建并返回,不再进一步做decode的调度:

if scheduled_seqs:
    return scheduled_seqs, True

可见,prefill请求的优先级高于decode,其可以打断decode,直到某一步调度时发现所有prefill请求都处理完了才会进去处理decode请求。

schedule – decode

正在decode的序列都被储存在running队列中。若本步调度时waiting队列不存在等待调度的序列(或因为资源有限无法调度任何一个prefill请求),那么就来依次处理running队列中的各个decoding序列:

注意:当调度每一个seq序列时,严格来说并不是为它“接下来要生成的下一个token准备kv slot”,而是“为上一步生成的token分配kv slots,并将序列调入本轮调度构建的batch等待下一步生成”。

例如,序列上一步生成的token为:t1,t2,t3 -> t4,此时序列为seq=t1,t2,t3,t4len(seq)中也包含了t4,然后进入本步调度。但此时并没有t4的kv cache,因为它是上一轮刚刚decode出来的,还未经过过模型,因此它实际上还没被安排好kv slot位置,处于未finalize的状态。直到将seq=t1,t2,t3,t4(实际上在kv cache机制下就是输入t4)输入模型,也就是本轮调度batch构建好并执行后,才会产生t4的kv cache。

可见,本轮调度中的kv分配环节是为末尾token t4安排kv slot(例如放入已有未满的kv block或申请一个新的block),从而承载它接下来通过模型时产生的kv cache,而不是为当前尚未产生的next token t5预分配空间。

# decode
while self.running and num_seqs < self.max_num_seqs:
    seq = self.running.popleft()
    while not self.block_manager.can_append(seq):
        if self.running:
            self.preempt(self.running.pop())
        else:
            self.preempt(seq)
            break
    else:
        num_seqs += 1
        self.block_manager.may_append(seq)
        scheduled_seqs.append(seq)
assert scheduled_seqs
self.running.extendleft(reversed(scheduled_seqs))
return scheduled_seqs, False

每次取出running队列左端的序列后,这里使用block_manager.can_append()来检查一下:若这个decode请求需要申请新的block,那么当前是否还能拿出至少1个空闲的block:

def can_append(self, seq: Sequence) -> bool:
    return len(self.free_block_ids) >= (len(seq) % self.block_size == 1)

在decode过程中,若当前最近的几个tokens所用的block未满,则末尾token可以加到这个block中,无需申请开启新的block。例如,设block_size=4,则token0需要申请一个新block,然后token1,2,3都可以使用这个block,直到token4才需要申请下一个新的block。

(len(seq) % self.block_size == 1)是一个bool值,用于计算末尾token需要申请的新block数。若取False则等同于0,说明末尾token可以放入之前没填满的block,无需新block;若取True则等同于1,说明末尾token需要新申请一个block。因此,在False情况下没有任何资源限制,len(self.free_block_ids)>=0在任何情况下都成立,因此该decode请求可以被调度;若取True则需要len(self.free_block_ids)>=1,也即需要当前有至少1个空闲block来分配给下一个token。

  • 如果self.block_manager.can_append(seq)==False,则说明末尾token需要申请1个新block,但当前系统中已经没有任何1个新block了。因此进一步开启抢占循环来尝试通过牺牲其他低优先级序列来尽力确保当前序列可以被finalize并调度:

    while not self.block_manager.can_append(seq):
        if self.running:
            self.preempt(self.running.pop())
        else:
            self.preempt(seq)
            break
    
    • 若当前running队列中还有其他等待decode的序列,则逐出running队列最末尾的序列(FIFO下它的优先级最低),释放其占据的kv blocks资源,然后再查看能不能调度当前序列。如果不能的话就继续逐出running队列末尾的序列,直到拥有足够的空闲block(1个就可以)可以调度当前序列或把running队列逐空了为止。

    • 若在running队列逐空了之后,依然凑不出1个空闲的kv block,那么就确实没办法在本轮调度当前序列了,因此只能将当前序列也逐出。

    使用self.preempt()来逐出序列,使得它的资源被抢占。其将序列的状态值设为WAITING,并释放其占用的kv blocks对象,并将seq.block_tablesseq.num_cached_tokens等记录重置。最后再将其从running队列移回waiting队列,等待以后被调度:

    def preempt(self, seq: Sequence):
        seq.status = SequenceStatus.WAITING
        self.block_manager.deallocate(seq)
        self.waiting.appendleft(seq)
    

    block_manager.deallocate(seq)操作如下:

    def deallocate(self, seq: Sequence):
        for block_id in reversed(seq.block_table):
            block = self.blocks[block_id]
            block.ref_count -= 1
            if block.ref_count == 0:
              self._deallocate_block(block_id)
        seq.num_cached_tokens = 0
        seq.block_table.clear()
      
          
    def _deallocate_block(self, block_id: int) -> Block:
        assert self.blocks[block_id].ref_count == 0
        self.used_block_ids.remove(block_id)
        self.free_block_ids.append(block_id)
    

    其从后往前依次尝试释放该序列占用的各个block对象,将它们的ref_count-=1,表示引用它们的序列少了一个(但不能直接释放它们,因为有可能其他序列还在引用它们)。若发现某个block对象此时引用数归零了,则说明它已不再被任何一个active的序列引用了,因此将它从used_block_ids集合移入free_block_ids集合,表示它成为了空闲的可使用的block。最后还会将序列的num_cached_tokens归零,将其block_table清空。

  • 如果可以调度这个序列,则将本轮调度batch的序列数量num_seqs+=1,表示确认要调度这个序列。然后使用block_manager.may_append(seq)来为这个序列的末端token(即将产生的kv cache)安排kv block,最后将这个序列加入本轮调度构造的batch列表scheduled_seqs中:

    else:
        num_seqs += 1
        self.block_manager.may_append(seq)
        scheduled_seqs.append(seq)
    

    其中,block_manager.may_append()的逻辑如下,可以处理末端token需要新block、可以填充到未满的已有block等情况,从而决定针对这个末端token的block处理策略:

    def may_append(self, seq: Sequence):
        block_table = seq.block_table
        last_block = self.blocks[block_table[-1]]
        if len(seq) % self.block_size == 1:
            assert last_block.hash != -1	# 最近的block既然已经满了,那么应该有哈希值
            block_id = self.free_block_ids[0]
            self._allocate_block(block_id)
            block_table.append(block_id)
        elif len(seq) % self.block_size == 0:
            assert last_block.hash == -1	# 最近的block既然未满(添加末端token后才满),那么就不应该有哈希值
            token_ids = seq.block(seq.num_blocks-1)
            prefix = self.blocks[block_table[-2]].hash if len(block_table) > 1 else -1
            h = self.compute_hash(token_ids, prefix)
            last_block.update(h, token_ids)
            self.hash_to_block_id[h] = last_block.block_id
        else:
            assert last_block.hash == -1	# 最近的block既然未满,那就不应该有哈希值
    

    last_block表示当前序列目前占用的最后一个block的id

    • len(seq) % self.block_size == 1,则表示序列占用的最后一个block恰好已满,末端token需要一个新block来承载它的kv cache,因此从free_block_ids集合中取出一个空闲block对象,并将其激活,然后将这个block对象添加到本序列的block_table末尾。

      例如,block_size=4,当前序列已占用的block为 block_0=[t1,t2,t3,t4],则末端token t5需要申请1个新block:block_1=[t5]

    • len(seq) % self.block_size == 0,则表示末端token恰好可以填满序列中的最后一个block。将末端token加入后,这个block就变成了满block,因此可以为它计算哈希值了。取该block的token子序列的前缀和当前token子序列后,即可计算该block的哈希值,并通过update函数来将哈希和token子序列信息记录入block对象。这个block也会被记录到block_manager.hast_to_block_id映射表中,从而参与缓存匹配。

      注意,获取该block包含的token子序列时并不能直接从block对象中拿出来,而是通过seq.block(i)方法得到的。这是因为对于未满的block,其hashtoken_ids都为空。因此需要使用Sequence.block()方法来根据索引映射从sequence中得到这个block负责的部分tokens,从而跟末端token连接起来后构成该block负责的完整子序列,并用于计算哈希和放入block对象,最终填充block对象的hashtoken_ids字段。

      def block(self, i):
          assert 0 <= i < self.num_blocks
          return self.token_ids[i*self.block_size: (i+1)*self.block_size]
      

      例如当前序列已占用的block为block_0=[t1,t2,t3],末端token恰好可以填满这个block:block_0->[t1,t2,t3,t4]

    • 若非上述两种情况,则说明末端token可以加入该序列最后一个block,且加入后这个block还是未满状态。因此可以不做任何事,因为最后一个block仍不是满block所以还是不需要计算哈希,其hashtoken_ids字段仍保持为空

至此,本轮调度的batch scheduled_seq已经构建好:

assert scheduled_seqs  # 确保至少调度了1个序列
self.running.extendleft(reversed(scheduled_seqs))
return scheduled_seqs, False

其中self.running.extendleft(reversed(scheduled_seqs))是为了把本轮被调度的序列放回running队列的最前面,并保持原来的顺序,这样可以确保在下一轮调度时它们仍然处于队首,从而使得下一轮它们依然会被优先处理。

这样实际上实现了一种近似的round-robin调度,使得每轮调度开始时各个decode序列在running队列中的顺序都不变,从而保持优先级一直不变(例如靠前的请求可能是更早被用户发起的,那么它们理应较先被decode完毕)。这样并不会导致队尾序列一直没有机会,因为每轮调度时都会从头到尾将running队列扫一遍:

while self.running and num_seqs < self.max_num_seqs:
    seq = self.running.popleft()
    # ...

例如,本次调度时running=[A,B,C,D],将它们依次放入batch并生成下一个token后(假设都还没生成结束),下一轮调度时仍然是running=[A,B,C,D],因为上一次调度拼好batch后重新按原先的优先级恢复了running队列。

请求执行概览

step()函数中,schedule操作返回了本轮调度的batch后,相当于为本次模型执行准备好了输入,然后即可调用model runner执行本步请求:

def step(self):
    seqs, is_prefill = self.scheduler.schedule()
    token_ids = self.model_runner.call("run", seqs, is_prefill)

具体而言,在主进程的model_runner中发布"run"命令,其不仅会执行自己的run()函数,在tp_size>1时还会通过共享内存将"run"信号以及输入序列发布给其他rank的model runner(子进程的model runner初始化后会通过不断尝试读取共享内存来监听来自主进程的命令和数据):

def call(self, method_name, *args):
    if self.world_size > 1 and self.rank == 0:
        self.write_shm(method_name, *args)
    method = getattr(self, method_name, None)
    return method(*args)

这样所有rank上的model runner就都收到了输入数据,并开始执行run()函数:

def run(self, seqs: list[Sequence], is_prefill: bool) -> list[int]:
    input_ids, positions = self.prepare_prefill(seqs) if is_prefill else self.prepare_decode(seqs)
    temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
    logits = self.run_model(input_ids, positions, is_prefill)
    token_ids = self.sampler(logits, temperatures).tolist() if self.rank == 0 else None
    reset_context()
    return token_ids

请求执行 – 输入序列预处理与context构建

首先需要对于输入的seq列表(包含本batch的各个序列的Sequence对象)做预处理:

  • Prefill情况:

    input_ids, positions = self.prepare_prefill(seqs)
    

    self.prepare_prefill()的主要功能是将一批不等长序列pack成一个长序列,从而适于FA kernel的处理,同时也在cpu上准备好了kv cache的写入地址(slot_mapping)、attention边界信息(cu_seqlens)等元数据信息,并将其移动到device上后放入Context对象。

    注意,这里首先在cpu上以列表形式准备并构建input_ids, slot_mapping等元数据,构建完成后使用input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)来在cpu的pin memory上构造对应的tensor,然后以non blocking的方式来非阻塞地异步移动到device上。这是因为这些元数据在每个step都要重新构建并传递到device上,属于高频移动的小数据,因此适合将它们放到cpu的pin memory中以便device可以直接访问并搬运,同时不阻塞cpu。

    prepare_prefill()的定义如下:

    def prepare_block_tables(self, seqs: list[Sequence]):
        max_len = max(len(seq.block_table) for seq in seqs)
        block_tables = [seq.block_table + [-1] * (max_len - len(seq.block_table)) for seq in seqs]
        block_tables = torch.tensor(block_tables, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
        return block_tables
      
    def prepare_prefill(self, seqs: list[Sequence]):
        input_ids = []
        positions = []
        cu_seqlens_q = [0]
        cu_seqlens_k = [0]
        max_seqlen_q = 0
        max_seqlen_k = 0
        slot_mapping = []
        block_tables = None
        for seq in seqs:
            seqlen = len(seq)
            # 对于seqs列表中的每个prefill Sequence对象,去除其中命中prefix cache的num_cached_tokens个前缀tokens(因为它们的kv cache已经有了,不需要重复计算了),取剩下的tokens作为本轮要被计算的序列
            input_ids.extend(seq[seq.num_cached_tokens:])
            positions.extend(list(range(seq.num_cached_tokens, seqlen)))
            seqlen_q = seqlen - seq.num_cached_tokens
            seqlen_k = seqlen
            # ...
            for i in range(seq.num_cached_blocks, seq.num_blocks):
               	# ...
                slot_mapping.extend(list(range(start, end)))
         
      # 若发现一个prefill请求的K长度大于Q长度,则说明这个请求的输入序列中存在cached tokens,例如Q=[CDEF],K=[ABCDEF]表示前2个tokens命中了cache,因此它们仅存在于kv,不存在于q(q中都是本次通过模型时需要被计算的tokens)。这就需要一个block tables来查找kv cache,实现旧kv和新kv的混合,使得做attention时能找到这些命中的缓存prefix kv
      # 若Q=K则说明没有prefix cache,就是prefill一个全新的prompt,不需要额外的table来查找
        if cu_seqlens_k[-1] > cu_seqlens_q[-1]:    # prefix cache
            block_tables = self.prepare_block_tables(seqs)
              
        input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
            positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
            cu_seqlens_q = torch.tensor(cu_seqlens_q, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
            cu_seqlens_k = torch.tensor(cu_seqlens_k, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
            slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
        set_context(True, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, None, block_tables)
        return input_ids, positions
    

    示例:输入batch中含有两个seq,block_size=256,其中:

    • seqs[0].tokens_ids=[641,279,16800,11,10339,220]

    • seqs[1].token_ids=[111308,6313]

    prepare完毕后,放入context的信息(都是以gpu上的tensor存在,这里省略tensor标记):

    • cu_seqlens_q=[0,6,8]cu_seqlens_k=[0,6,8]:表示了两个序列的末尾token分别在6,8位置

    • max_len_q=6, max_len_k=6:本batch的最长长度是seq1,长度为6

    • seqs[0].block_table=[0], seqs[1].block_table=[1]:表示seq0中的逻辑block0对应了物理block0,seq1中的逻辑block1对应了物理block1(因为本例中block_size=256,大于二者输入长度,因此两个prefill序列分别只需一个block就可容纳)

    • slot_mapping=[0,1,2,3,4,5,256,257]:表示seq0中的6个token的kv cache存储在物理kv cache的slot 0,1,2,3,4,5(也即物理block0的前6个slot),seq1中的6个token的kv cache存储在物理kv cache的slot 256,257(也即物理block1的前2个slot)

    返回结果:

    • input_ids=[641,279,16800,11,10339,220, 111308,6313]:两个序列首尾连接成的长序列,作为后续模型的直接输入

    • positions=[0,1,2,3,4,5,0,1]:两个序列各自的position序列连接后得到的长序列的position信息,确保每个token对应的position id仍然是其在原序列中的那个。

  • Decode情况:

    input_ids, positions = self.prepare_decode(seqs)
    

    在decode请求中,相比于prefill需要考虑prefix cache hit以及不同序列长度不同,其每个序列都一定只计算1个query token,且所有历史token的kv cache也一定已经存在了kv cache中。

    def prepare_decode(self, seqs: list[Sequence]):
        input_ids = []
        positions = []
        slot_mapping = []
        context_lens = []
        for seq in seqs:
            input_ids.append(seq.last_token)
            positions.append(len(seq) - 1)
            context_lens.append(len(seq))
            slot_mapping.append(seq.block_table[-1] * self.block_size + seq.last_block_num_tokens  - 1)
        input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
        positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
        slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
        context_lens = torch.tensor(context_lens, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
        block_tables = self.prepare_block_tables(seqs)
        set_context(False, slot_mapping=slot_mapping, context_lens=context_lens, block_tables=block_tables)
        return input_ids, positions
    

    进一步沿用上述示例,其中的两个序列做完prefill后分别产生了一个新的token,二者分别接到两个序列后得到了第一轮decode的输入:

    • seqs[0].tokens_ids=[641,279,16800,11,10339,220,18]

    • seqs[1].token_ids=[111308,6313,35946]

    事实上decode step只关心两个序列的最后一个token,因此处理后可得:

    • slot_mapping=[6,258]:这两个token的kv cache所处的slot id,分别接在第0、1个block中其前序序列的位置([0,1,2,3,4,5], [256, 257])之后

    • context_lens=[7,3]:此时两个序列的长度,分别为7和3

    • block_tables=[[0], [1]]:两个token所在的block table序号

    返回结果:

    • input_ids=[18,35946]:可见就是把两个序列的最后一个token连起来,pack成了一个长度为2的序列作为本步模型的输入

    • positions=[6,2]:两个序列最后一个token各自的位置

请求执行 – run model

准备好本step的input_idspositions,并将相应的slot_mapping等信息构建成context后,即可开始运行模型,得到输出logits:

    logits = self.run_model(input_ids, positions, is_prefill)

model runner的run_model()方法如下

@torch.inference_mode()
def run_model(self, input_ids: torch.Tensor, positions: torch.Tensor, is_prefill: bool):
    if is_prefill or self.enforce_eager or input_ids.size(0) > 512:
        return self.model.compute_logits(self.model(input_ids, positions))
    else:
        bs = input_ids.size(0)
        context = get_context()
        graph = self.graphs[next(x for x in self.graph_bs if x >= bs)]
        graph_vars = self.graph_vars
        graph_vars["input_ids"][:bs] = input_ids
        graph_vars["positions"][:bs] = positions
        graph_vars["slot_mapping"].fill_(-1)
        graph_vars["slot_mapping"][:bs] = context.slot_mapping
        graph_vars["context_lens"].zero_()
        graph_vars["context_lens"][:bs] = context.context_lens
        graph_vars["block_tables"][:bs, :context.block_tables.size(1)] = context.block_tables
        graph.replay()
        return self.model.compute_logits(graph_vars["outputs"][:bs])

可见:

  • 当处于prefill阶段,或强制不启用cuda graph(enforce_eager=True),或decode但bsz过大时,不能使用已捕捉的graph,因此只能直接naive地输入模型来进行推理self.model(input_ids, positions),得到输出张量(last hidden states)后计算logits(通过lm_head层)并返回

  • 在其它情况下会启用cuda graph,通过graph replay来启动并执行计算。从graph_vars中取出相关的输入输出张量(和graph capture时的输入输出地址是相同的),然后将本step的具体数据复制到输入中,即准备好了本轮输入张量。然后即可根据本step输入的形状来取得对应的graph,并通过graph.replay()来对本轮的输入数据进行计算,并将结果放到输出张量中,最终再对输出张量计算logits并返回

模型中的具体结构和transformers中的没什么区别,这里只关心attention部分,attention模块的定义位于layers/attention.py中:

class Attention(nn.Module):

    def __init__(
        self,
        num_heads,
        head_dim,
        scale,
        num_kv_heads,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.scale = scale
        self.num_kv_heads = num_kv_heads
        self.k_cache = self.v_cache = torch.tensor([])

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        context = get_context()
        k_cache, v_cache = self.k_cache, self.v_cache
        if k_cache.numel() and v_cache.numel():
            store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
        if context.is_prefill:
            if context.block_tables is not None:    # prefix cache
                k, v = k_cache, v_cache
            o = flash_attn_varlen_func(q, k, v,
                                       max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
                                       max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
                                       softmax_scale=self.scale, causal=True, block_table=context.block_tables)
        else:    # decode
            o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache,
                                        cache_seqlens=context.context_lens, block_table=context.block_tables, 
                                        softmax_scale=self.scale, causal=True)
        return o

attention模块的forward逻辑如下,其输入本层计算好的q,k,v,然后将k,v存到物理kv cache的相应位置,最后调用flash attention的kernel计算得到attention结果o。主要逻辑如下:

  • k,v的保存:

    k_cache, v_cache = self.k_cache, self.v_cache
    if k_cache.numel() and v_cache.numel():
      store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
    

    如果已经预分配好本层的self.k_cacheself.v_cache,则将本层刚刚产生的新kv cache存入self.k_cache和self.v_cache中。

    事实上只有初始化过程中的warmup时self.k_cache, self.v_cache为空,从而不进入store_kvcache分支。warmup之后的allocate_kv_cache()操作会根据warmup中记录的峰值显存占用来最大限度地预分配好kv cache空间,并下发到每一层的self.k_cacheself.v_cache中(详见预分配kv cache部分),从而使得初始化完毕后每个attention层的self.k_cacheself.v_cache都是形状为(num_kvcache_blocks, block_size, num_kv_heads, head_dim)的全0空张量。因此,后续真正推理时无论是prefill请求还是decode请求都会进入store_kvcache分支

    store_kvcache()执行的是根据slot_mapping将本层刚刚产生的新kv cache存入物理kvcache(也即self.k_cache, self.v_cache)相应位置的操作,其python wrapper如下:

    import triton
    import triton.language as tl
      
    def store_kvcache(key: torch.Tensor, value: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor, slot_mapping: torch.Tensor):
        N, num_heads, head_dim = key.shape
        D = num_heads * head_dim
        assert key.stride(-1) == 1 and value.stride(-1) == 1
        assert key.stride(1) == head_dim and value.stride(1) == head_dim
        assert k_cache.stride(1) == D and v_cache.stride(1) == D
        assert slot_mapping.numel() == N
        store_kvcache_kernel[(N,)](key, key.stride(0), value, value.stride(0), k_cache, v_cache, slot_mapping, D)
    

    输入的k,v的形状为(N, num_heads, head_dim),这里的N是batch中所有序列pack后的长度,也即本轮输入的token数量。D为每个token的k,v的总隐含维数。其调用一个triton kernel来找到每个输入token的kv对应的slot位置并将其写入(启动N个并行的program,每个program负责1个token的kv写入):

    def store_kvcache_kernel(
        key_ptr,
        key_stride,
        value_ptr,
        value_stride,
        k_cache_ptr,
        v_cache_ptr,
        slot_mapping_ptr,
        D: tl.constexpr,
    ):
        idx = tl.program_id(0)
        slot = tl.load(slot_mapping_ptr + idx)
        if slot == -1: return
        key_offsets = idx * key_stride + tl.arange(0, D)
        value_offsets = idx * value_stride + tl.arange(0, D)
        key = tl.load(key_ptr + key_offsets)
        value = tl.load(value_ptr + value_offsets)
        cache_offsets = slot * D + tl.arange(0, D)
        tl.store(k_cache_ptr + cache_offsets, key)
        tl.store(v_cache_ptr + cache_offsets, value)
    

    可见,该kernel读取对应token的k,v以及该token对应的slot id后,可以通过slot id来找到这个token的kv在物理kv cache中对应的偏移量,然后将其写入即可。由于triton kernel中读写连续数据时只关心数据的起始位置指针以及覆盖范围,因此在确保k,v本身存储连续以及物理kv cache各个相邻slot存储连续的情况下,可以直接读写一段长度为D=num_heads*head_dim的数据,来完成该token的kv的写入,而无需再更细地分别处理各个head。

    slot==-1的情况为padding或prefix reuse,说明该token的kv无需被存入物理kv cache中。

  • attention计算:

    这里直接调用了flast attention用于prefill和decode得attention计算的官方接口:

    from flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
    

    prefill使用flash_attn_varlen_func(),其专门用于变长batch(中的各个序列pack得到的总序列)在prefill时的attn计算,其通过cu_seqlens来识别各个子序列在总序列中的界限从而使得attn操作局限在每个子序列内部:

    if context.is_prefill:
        if context.block_tables is not None:    # prefix cache
            k, v = k_cache, v_cache
        o = flash_attn_varlen_func(q, k, v,
                                   max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
                                   max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
                                   softmax_scale=self.scale, causal=True, block_table=context.block_tables)
    

    decode使用flash_attn_with_kvcache(),其用于decode阶段的attn计算。通过传入block table和kv cache,其内部会读取attn所需的kv并进行计算:

    else:    # decode
        o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache,
                                    cache_seqlens=context.context_lens, block_table=context.block_tables, 
                                    softmax_scale=self.scale, causal=True)
    

后处理

def step():
  	# schedule
    # ...
    token_ids = self.model_runner.call("run", seqs, is_prefill)
    self.scheduler.postprocess(seqs, token_ids)
    outputs = [(seq.seq_id, seq.completion_token_ids) for seq in seqs if seq.is_finished]
    num_tokens = sum(len(seq) for seq in seqs) if is_prefill else -len(seqs)
    return outputs, num_tokens

模型推理输出token_ids即为batch中的各个序列生成的下一个token。例如在上述prefill示例中,输入的seqs对象中包含两个序列:

  • seqs[0].tokens_ids=[641,279,16800,11,10339,220]

  • seqs[1].token_ids=[111308,6313]

prefill后得到两个序列的各自下一个token:token_ids=[18, 35946]

然后,将此时的seqstoken_ids输入scheduler.postprocess()中进行后处理,主要作用为判断该序列是否完成了生成,以及若完成了生成该做什么善后处理:

def postprocess(self, seqs: list[Sequence], token_ids: list[int]) -> list[bool]:
    for seq, token_id in zip(seqs, token_ids):
        seq.append_token(token_id)
        if (not seq.ignore_eos and token_id == self.eos) or seq.num_completion_tokens == seq.max_tokens:
            seq.status = SequenceStatus.FINISHED
            self.block_manager.deallocate(seq)
            self.running.remove(seq)

对于每个序列,首先将这个新生成的token加到其末尾,然后判断此时该序列是否完成了生成(例如达到长度上限或产生EOS token)。如果该序列此时已完成,则将其状态设为FINISHED,移出running队列,并对于其包含的kv cache blocks进行善后处理:

def _deallocate_block(self, block_id: int) -> Block:
    assert self.blocks[block_id].ref_count == 0
    self.used_block_ids.remove(block_id)
    self.free_block_ids.append(block_id)

def deallocate(self, seq: Sequence):
    for block_id in reversed(seq.block_table):
        block = self.blocks[block_id]
        block.ref_count -= 1
        if block.ref_count == 0:
            self._deallocate_block(block_id)
    seq.num_cached_tokens = 0
    seq.block_table.clear()

对于该序列使用的每个block,将其引用数-1,如果发现某block在该序列完成后引用数归零,则说明此时没有任何一个active的序列还在使用它,进一步触发其释放操作,将该block由used_block_ids集合移入free_block_ids集合,从而将其标记为闲置,可以在后续请求中被占用。

最终,将本step所有所有生成完成的序列(的seq id)汇总到outputs中,并作为本step的最终返回值进行返回。

推理结束

调用一次generate()对于一批输入prompts进行offline推理,在所有请求全部生成完毕之前不断执行推理step,每个step都构造一个batch并调用模型执行一步推理,并返回本步推理后完成生成的序列id。outputs用来承载本批输入prompt的最终输出,每蹦出一个生成完毕的序列就将其加入到outputs中,直到所有请求都生成完毕。最终将所有请求及其生成结果进行detokenize后返回,从而完成本次offline推理。

def generate(prompts):
  # add requests
  # ...
  outputs = {}
  # ...
  while not self.is_finished():
  	output, num_tokens = self.step()
   	# tqdm记录
   	# ...
   	for seq_id, token_ids in output:
    	outputs[seq_id] = token_ids
	outputs = [outputs[seq_id] for seq_id in sorted(outputs.keys())]
	outputs = [{"text": self.tokenizer.decode(token_ids), "token_ids": token_ids} for token_ids in outputs]

	return outputs