跳过正文

Paged Attention V1(vLLM)

··4705 字·10 分钟· loading · loading ·
NLP Transformer LLM VLLM Paged Attention
目录
Attention and Optimization - 这篇文章属于一个选集。
§ [4]: 本文

vLLM
#

vLLM是吞吐性能卓越的大模型推理框架,PagedAttention是vLLM最大的创新点: Efficient Memory Management for Large Language Model Serving with PagedAttention

vLLM中的attention计算,在推理的prefill阶段, 使用第三方库xformers的优化实现,decoding阶段使用 CUDA kernel 实现(csrc/attention/attention_kernels.cu,大约800多行)。

Attention计算时使用页式管理 KV Cache 来提高内存利用率,进而提高吞吐量。

Paged Attention(PA)
#

vLLM中有两个版本的 PA,其中:

  • V1 源于 FasterTransformers 的 MHA,适用于 len(seq) < 8192 或 num_seqs * num_heads > 512 的情况。
  • V2 参考了 Flash Decoding方式,对 sequence 的维度进行切分来增加并行粒度

Paged Attention V1
#

Block table in PA

Example generation process for a request with PagedAttention

一个 req 中包含多个 seq 时,可以共享blocks

Example generation process for a request that samples multiple outputs

Paged Attention V1 CUDA Kernel(vLLM)
#

csrc/attention/attention_kernels.cu

single_query attention 函数

Dispatch逻辑:

  • CALL_KERNEL_LAUNCHER_BLOCK_SIZE 根据存储的kv blocksize进行派发,分别是 8, 16, 32
  • LAUNCH_ATTENTION_KERNEL 根据注意力头大小HEADSIZE静态派发

并行任务的划分:

  • dim3 grid(num_heads, num_seqs, 1)
  • dim3 block(NUM_THREADS), 线程数是128,每个 block 负责完成 output 矩阵一行(head_size个元素)结果的 attention 计算
  • block 的线程划分为若干个 Warp, 每个 Warp 的32个线程划分为 blk_size 个 thread group

Kernel 输入参数

out[num_seqs, num_heads, head_size]
q[num_seqs, num_heads, head_size]
k_cache[num_blocks, num_kv_heads, head_size/x, block_size, x] # x表示一个向量化的大小,如float16 -> 16 / sizeof(float16) = 8
v_cache[num_blocks, num_kv_heads, head_size, block_size]
head_mapping[num_heads] # 使用MQA, GQA时的kv_head
block_tables[num_seqs, max_num_blocks_per_seq] # 维护各个Q对应KVCache的哪些block
context_lens[num_seqs] # 用于变长

num_head: Q 的 head 数 num_kv_heads:K, V 的 head 数,MHA 的 num_kv_heads = num_head,GQA、MQA 的 num_kv_heads < num_head blk_size # block_size,每个page block存储的元素数量,每个page存(blk_size, num_head,head_size)个K、V的元素

Kernel 的常量定义:

  • THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1) 通过WARPSIZE / BLOCKSIZE 得到一个thread_group大小。注意这里的BLOCKSIZE不是cuda blocksize,而是一个kv block的大小(默认值16)
  • NUM_TOKENS_PER_THREAD_GROUP = (BLOCK_SIZE + WARP_SIZE - 1) / - WARP_SIZE 表示每个thread_group处理多少个token
  • NUM_WARPS 表示一个threadblock有多少个warp
  • VEC_SIZE 表示向量化大小,保证每个thread_group一次性获取16bytes,MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1)
  • NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE 表示每个thread要负责多少个数据计算
  • NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE 表示每个thread负责的数据经过向量化后,一共有多少个vec
  • V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE) 每个thread一次性读取16bytes
  • NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE。对于v_cache[head_size, block_size],表示一行需要几个V_VEC
  • NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW 表示一个warp可以处理多少行
  • NUM_ROWS_PER_THREAD 表示每个thread需要负责多少行

Kernel 代码逻辑:

(1)循环从显存读取\(Q\)到 shared memory:

迭代读取,每 CUDA block 负责读取\(Q\)的一行(head_size 个元素)存入 shared memory。其中,block 的每个 Warp 负责读取 16blk_size 字节的 Q,即每个 thread group 会读取16字节的 Q,16blk_size 字节的 Q 对应 sequence 的一个 head。

const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE;
const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE;

// Load the query to registers.
// Each thread in a thread group has a different part of the query.
// For example, if the the thread group size is 4, then the first thread in
// the group has 0, 4, 8, ... th vectors of the query, and the second thread
// has 1, 5, 9, ... th vectors of the query, and so on. NOTE(woosuk): Because
// q is split from a qkv tensor, it may not be contiguous.
const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
__shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
#pragma unroll
for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD;
    i += NUM_THREAD_GROUPS) {
const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
q_vecs[thread_group_offset][i] =
    *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
}
__syncthreads();

(2)循环从显存读取\(K\)到 register,并计算QK:

  • 每个 seq 包含 cxt_length * num_kv_heads * head_size 个元素
  • 每个 CUDA block 负责计算一个 seq 的一个 head 的 \(QK^T\), 只需要读取 ctx_length * head_size 个 K 的元素
  • 因为页式内存管理,K 在 ctx_length 维度的存储不连续,以 blk_size 个 token 为粒度分布在不同的内存地址,所以需要根据 Q 的 head_idx 和 seq_idx 访问 block_table 找到 K 的 physical_block_num
  • K Cache的布局为 [num_blocks, num_kv_heads, head_size/x, block_size, x], 目的是优化写入 shared memory。Q和K的同一行元素被读入寄存器并进行点乘运算后,结果要写入shared memory。如果一个 Warp 中所有线程都计算 Q、K 同一行数据,会导致写入 shared memory 的同一个位置,这将造成 warp 内不同线程顺序地写入。所以 warp 的线程最好计算 Q和K 的不同行数据。在设计 K 布局时,将 block_size 放在比 head_size 更低的维度。由于warp size大于block_size,我们需要将head_size拆分为head_size/x和x两个维度,借x到最低维度,以确保每个线程读入的数据量和计算量都足够大。最后,每个线程组派一个线程去写入shared memory,这样一个warp有blk_size个线程并行写入shared memory,从而增加了shared memory的访问带宽。这种设计策略是为了实现高效的并行计算和内存访问,以提高整体的计算性能。
  • 读取 K 需要一个循环,循环中每个CUDA block中的所有 warp 依次访问num_blocks 个 page block。每次迭代:
    • 每个 warp 负责访问连续的 blk_size 个 KCache 的行数据(blk_size * head_size个元素)。每个 thread group 负责访问 KCache 的一行,将head_size 个元素读入寄存器
    • 寄存器中的Q和K元素进行点乘,结果写入shared memory。一个 CUDA block 的 shared memory 存储了一行 QK^T 的结果,共 ctx_length 个元素
    • CUDA block 对 shared memory 中元素进行 max,sum 方式 reduction,然后计算得到 softmax 的结果

代码步骤:

  • group是由block大小决定的,当block>32时,每个warp实现了一个group,否则在一个warp中实现多个group

  • 每个warp负责计算一个block KCache,而每个block key shape为 [block_size, num_head, head_size]

  • 每个thread_group取一个key,即num_head个元素,计算QK dot

  • 只有thread_group的第一个thread负责将QK结果写入shared memory

  • head_idx标记GPU BLOCKs,也即每个GPU Blocks计算一个head

  • num_heads标记使用的GPU BLOCKs总数,也即head num

  • seq_idx标记的是第二维GPU BLOCKs, 也即seq的位置

分配red_smem[2*NUM_WARPS]为reduce所用,保留的是warp内的局部最大值。后面计算了qvec的dot结果保存为qk,先在group内reduce计算得到局部最大值,然后在每个warp内reduce计算得到全局最大值为qk_max。

// 每个warp负责 blocksize * headsize个元素
// block_idx是block cache中的序号(逻辑序号)
for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) {
    // TODO(Zhengzekang)
    // 定位物理块
    const int physical_block_number = block_table[block_idx];
    // ...
    K_vec k_vecs[NUM_VECS_PER_THREAD];
    
    // 遍历每个thread_group处理多少个token
    for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
        const int physical_block_offset =
            (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
        const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
        K_vec k_vecs[NUM_VECS_PER_THREAD];
        // 遍历每个thread需要处理多少个VEC
        for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
            //  vectorized取到key
            k_vecs[j] = xxxx;
        }
        // 计算QKdot,里面包含了一个thread_groupsize的WarpReduceSum,
        float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs, k_vecs);

        // 只有thread_group的第一个thread负责将QK结果写入shared memory
        // 并且维护一个qk_max,用于后续softmax
        if (thread_group_offset == 0) {
            // Store the partial reductions to shared memory.
            // NOTE(woosuk): It is required to zero out the masked logits.
            const bool mask = token_idx >= context_len;
            logits[token_idx] = mask ? 0.f : qk;
            // Update the max value.
            qk_max = mask ? qk_max : fmaxf(qk_max, qk);
        }
    }
}

此时各个thread_group已经完成了自己的qk_dot操作,并且都维护了qk_max。下面就需要和其他thread_group做warp shuffle操作,得到一个warp内的qk max值。

由于每个thread_group里的thread内维护的qk_max是一样的,所以warp shuffle只需到 thread_group_size即可停止。并由lane_id = 0的线程将warp里的qk_max存储到smem,最后再做一次warpreduce,得到一个block里的qkmax值,通过shfl_sync广播操作,让每个线程都拿到max

#pragma unroll
  for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
    qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
  }
  if (lane == 0) {
    red_smem[warp_idx] = qk_max;
  }
  __syncthreads();

  // TODO(woosuk): Refactor this part.
  // Get the max qk value for the sequence.
  qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
#pragma unroll
  for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
    qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
  }
  // Broadcast the max qk value to all threads.
  qk_max = __shfl_sync(uint32_t(-1), qk_max, 0);

接下来就是常规的softmax

执行exp(x-qk_max)并得到每个warp上的exp_sum,规约得全局(所有warp)的exp_sum,计算每个节点上的softmax

// Get the sum of the exp values.
float exp_sum = 0.f;
for (int i = thread_idx; i < context_len; i += NUM_THREADS) {
    float val = __expf(logits[i] - qk_max);
    logits[i] = val;
    exp_sum += val;
}
exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], exp_sum);

// Compute softmax.
const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
for (int i = thread_idx; i < context_len; i += NUM_THREADS) {
    logits[i] *= inv_sum;
}
__syncthreads();

(3)从显存读取\(V\)到 register, 计算 softmax(QK^T)V

和KCache一样,CUDA block 依次访问 num_blk 个 VCahce block 到寄存器,每个 warp 负责 1 个 VCache block,。不过这里不需要以 thread group 为单位访问16字节,而是每个 thread 读取16字节的元素到寄存器,然后与shared memory的 softmax(QK^T)中间结果 对应位置16字节的数据进行点乘,得到一个 float 结果,写到 output 的对应位置中。

为了读写连续,将V_cache转置,shape为:[num_blocks, num_kv_heads, head_size, block_size]

注意这里使用了fp32模式以防止累加过程中的精度损失

// 每个线程一次性读16bytes数据
  constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE);
  using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
  using L_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
  using Float_L_vec = typename FloatVec<L_vec>::Type;
  
  // 每一行有多少个V_VEC,假设BLOCK_SIZE=8,那么NUM_V_VECS_PER_ROW=1
  constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
  // 一个WARP一次处理多少行,按照上面假设,这里是32
  constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW;
  // 每个thread需要负责多少行,假设headsize=128,那么每个thread要处理4行
  constexpr int NUM_ROWS_PER_THREAD = (HEAD_SIZE + NUM_ROWS_PER_ITER - 1) / NUM_ROWS_PER_ITER;

  // 提前分配accumulate buffer,用float累加
  float accs[NUM_ROWS_PER_THREAD];
#pragma unroll
  for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
    accs[i] = 0.f;
  }

for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) {
    // ...
#pragma unroll
    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
      const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
      if (row_idx < HEAD_SIZE) {
        const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
        V_vec v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
        accs[i] += dot(logits_vec, v_vec);
      }
    }
  }

(4)更新最终的结果

将一个block分成上半部分warp和下半部分warp。上半部分warp(warp_id > mid)将自己累加的结果写到shared memory。下半部分warp将之前上半部分warp存到shared_memory 的结果取出,进行累加。这样重复,当warp_idx==0时,将所有结果写回到每一行中。

  // Perform reduction across warps.
  float* out_smem = reinterpret_cast<float*>(shared_mem);
#pragma unroll
  for (int i = NUM_WARPS; i > 1; i /= 2) {
    int mid = i / 2;
    // Upper warps write to shared memory.
    if (warp_idx >= mid && warp_idx < i) {
      float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE];
#pragma unroll
      for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
        const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
        if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
          dst[row_idx] = accs[i];
        }
      }
    }
  }
    __syncthreads();
    // Lower warps update the output.
    if (warp_idx < mid) {
      const float* src = &out_smem[warp_idx * HEAD_SIZE];
#pragma unroll
      for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
        const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
        if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
          accs[i] += src[row_idx];
        }
      }
    }
    __syncthreads();
    // Write the final output.
    if (warp_idx == 0) {
        scalar_t* out_ptr = out + seq_idx * num_heads * 
        max_num_partitions * HEAD_SIZE + head_idx * 
        max_num_partitions * HEAD_SIZE + partition_idx * 
        HEAD_SIZE;
#pragma unroll
    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
      const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
      if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
        from_float(*(out_ptr + row_idx), accs[i]);
      }
    }
  }

为什么 VCache 的 layout 是 [num_blocks, num_kv_heads, head_size, block_size],和 KCache layout 不一样? 因为 V 要去做点乘的对象在shared memory,只需要读,不涉及并行写。

PA V1 和 Flash Attention 的区别
#

并行任务的划分方式不同

  • FlashAttention 用了两层循环,每次写一个 Tile 的 output tensor,而 PA 只有一层循环,每次写一行 output tensor。因为每次迭代都有整行的 QK^T 中间结果,不需要online softmax
  • PA V1 设计的 KCache layout 充分利用了 shared memory 写带宽

PA V1 的缺陷
#

不足:

  • 不适合 seq 很长的情况,因为没有沿着 ctx_length 或者 batch 维度做切分
  • 和MHA相比,MQA和GAQ没有减少对KV Cache的读写次数。读K、V Cache时候只是做了一个head_idx的转换,会重复从显存读相同的head

未完待续…

Reference:

Attention and Optimization - 这篇文章属于一个选集。
§ [4]: 本文