Antkillerfarm Hacking V8.0

DL acceleration » 并行 & 框架 & 优化(七)——LLM Inference

2024-06-27 :: 5998 Words

快速Transformer(续)

FlashAttention

代码:

https://github.com/Dao-AILab/flash-attention

对于self-attention块,除了两个大矩阵乘法是计算受限的,其他都是内存受限的逐点运算。

FlashAttention主要使用分块矩阵的思想,对矩阵乘法进行优化。

分块之后,参与运算的小块可以放入速度更快的存储单元,例如原来的运算需要反复读取HBM,而现在只需要读取一次HBM,然后就可以在SRAM或者Cache中完成所有的运算。

但Self-Attention中有softmax计算,而softmax的分母包含了所有元素相关的求和项,所以对Self-Attention进行分块计算的真正难点在于对softmax的分块计算。

FlashAttention提出了一种叫做Online Softmax的增量算法:我们首先计算一个分块的局部softmax值,然后存储起来。当处理完下一个分块时,可以根据此时的新的全局最大值和全局EXP求和项来更新旧的softmax值。接着再处理下一个分块,然后再更新。当处理完所有分块后,此时的所有分块的softmax值都是“全局的”。

具体计算公式:

for i = 1 to #tiles do: $$m_i^{(local)}\leftarrow max_{j=1}^b(x_i[j])$$ $$m_i \leftarrow max(m_{i-1},m_i^{(local)})$$ $$d_i^{'} \leftarrow d_{i-1}^{'}e^{m_{i-1}-m_i}+\sum_{j=1}^b e^{x_i[j]-m_i}$$ $$o_i^{'}=o_{i-1}^{'}\frac{d_{i-1}^{'}}{d_i^{'}}e^{m_{i-1}-m_i}+\sum_{j=1}^{b}\frac{e^{x_i[j]-m_i}}{d_i^{'}}V[j+(i-1)b,:]$$ end $$O[k,:]\leftarrow o_{N/b}^{'}$$

其他优化点:

  • softmax recomputing。前向保存logsumexp(LSE),反向利用LSE,重新计算softmax。

  • Dropout,前向阶段只保存dropout seed与offset,反向宁愿再算一遍dropout,放弃保存dropout mask。


FlashAttention V1里Q在内层循环,而V2里K在内层循环。V1对于计算softmax不友好,因为每次计算的中间结果只是部分和,只有全算完才能释放相关存储。

\(B_r\)和\(B_c\)是FlashAttention分块处理时的分块size。


// vllm-project
test_flash_attn_with_paged_kv
flash_attn_with_kvcache
torch.ops._vllm_fa3_C.fwd
mha_fwd
run_mha_fwd
run_mha_fwd_
run_mha_fwd_hdim128
run_flash_fwd
flash_fwd_kernel

// Dao-AILab
test_flash_attn_output
flash_attn_func
FlashAttnFunc.apply
_wrapped_flash_attn_forward
torch.ops.flash_attn._flash_attn_forward
flash_attn_gpu.fwd
import flash_attn_2_cuda as flash_attn_gpu
mha_fwd

Composable Kernel(ck)是AMD ROCm平台的Flash-Attention实现。


varlen即变长序列,产生的背景是”数据拼接“,即LLM使用的训练数据集中,长度较短的序列占大多数,这些短序列为了能够符合Transformer固定长度的输入,就要进行padding,序列越短,padding越多,而我们不太想要padding,padding只是无奈之举。此时,我们可以使用varlen特性,简单来说就是将多个短序列拼接成一个长序列,但是还是每个短序列自己内部计算注意力,短序列之间是隔离的,这样减少了padding,节省计算量和显存。


https://www.zhihu.com/question/611236756

FlashAttention的速度优化原理是怎样的?

https://blog.csdn.net/v_JULY_v/article/details/133619540

通透理解FlashAttention与FlashAttention2:全面降低显存读写、加快计算速度

https://zhuanlan.zhihu.com/p/668888063

原理篇: 从Online-Softmax到FlashAttention V1/V2/V3

https://zhuanlan.zhihu.com/p/665170554

FlashAttention核心逻辑以及V1 V2差异总结

https://www.cnblogs.com/sasasatori/p/18474946

FlashAttention逐代解析与公式推导

https://courses.cs.washington.edu/courses/cse599m/23sp/notes/flashattn.pdf

From Online Softmax to FlashAttention

https://zhuanlan.zhihu.com/p/631106302

FlashAttention反向传播运算推导

FlashDecoding

FlashDecoding是FlashAttention项目的一部分,但由于优化方向有所不同,特别单列出来。

Prefill阶段在Q的seqlen维度以及batch_size维度做并行。

但是在Decoding阶段,是逐token生成,在利用KV Cache的情况下,每次推理实际的queries token数为1,已经无法通过queries进行并行了。

既然,Q和BS无法进一步并行了,那么对K,V进行并行是不是就可以了呢?

  • 首先,将K/V切分成更小的块,比如5块;
  • 然后在这些K/V块上,使用标准FlashAttention进行计算,得到所有小块的局部结果。
  • 最后,使用一个额外的kernel做全局的reduce,得到正确输出。

https://crfm.stanford.edu/2023/10/12/flashdecoding.html

Flash-Decoding for long-context inference


New hardware features on Hopper GPUs - WGMMA, TMA, FP8

https://www.zhihu.com/question/661395457

FlashAttention-3发布!有什么新优化点?

https://zhuanlan.zhihu.com/p/661478232

FlashAttenion-V3: Flash Decoding详解

PagedAttention

PagedAttention是UC Berkeley的作品。

PagedAttention使用分页管理的方式管理KV Cache,将每个序列的KV Cache划分为块,每个块包含固定数量token的K/V。

Parallel Sampling:我给模型发送一个请求,希望它对prompt做续写,并给出三种不同的回答。我们管这个场景叫parallel sampling。

显然这里的prompt部分的KV cache是完全重复。这时可以使用如上图所示的进阶版本,通过类似MMU的逻辑地址和物理地址的映射,来解决存储问题。

这个方法也可以推广到Beam Search、Shared prefix等场景。

https://blog.vllm.ai/2023/06/20/vllm.html

vLLM: Easy, Fast, and Cheap LLM Serving with PagedAttention

https://zhuanlan.zhihu.com/p/691038809

vLLM核心技术PagedAttention原理

参考

https://mp.weixin.qq.com/s/1R_plHqxTLE-Fw3TjYnlJQ

GPU BERT上线性能不合格,看看微信AI的PPoPP论文

https://mp.weixin.qq.com/s/OgTQ3O_6lvOG07U-tjpTDA

如何让Transformer在GPU上跑得更快?快手:需要GPU底层优化

https://zhuanlan.zhihu.com/p/638468472

从FlashAttention到PagedAttention, 如何进一步优化Attention性能

https://blog.csdn.net/v_JULY_v/article/details/144218958

一文通透vLLM与其核心技术PagedAttention:减少KV Cache碎片、提高GPU显存利用率(推理加速利器)

LLM Inference

A Survey on Efficient Inference for Large Language Models


https://zhuanlan.zhihu.com/p/653352979

LLM七种推理服务框架总结

https://zhuanlan.zhihu.com/p/671347964

大模型(LLM)推理框架汇总

https://zhuanlan.zhihu.com/p/642412124

LLM的推理优化技术纵览

https://github.com/DefTruth/Awesome-LLM-Inference

Awesome LLM Inference


一次用户请求,实际上既包含prefill,也包含decode。一个是计算密集型,一个是访存密集型。

prefill(用户输入)和decode(模型输出)的token量在不同场景下也是不一样的。如果是简单对话场景,通常模型的decode输出会更多一些,而如果是超长上下文场景,用户先上传一本几十万字的书再进行问答,这本书的prefill会直接起飞。在Agent场景下,大量预设的prompt也会占据非常多的prefill,不过prompt的prefill有不少机会可以提前算好KV而无需每个用户请求单独重复计算。

当整个推理系统服务几千万用户时,一个batch的几十个用户请求只是开胃菜。每个用户会不间断地和大模型进行交互,发出大量请求,但这些请求的间隔时间短则几秒,长则几分钟几小时。考虑人机交互的频率,一个用户请求结束后,对应的KV-cache继续常驻在高速内存中实际意义不大。


这个从今年年中开始,各家都陆续上了PD分离方案(如MoonCake)。

Prefilling阶段是计算密集型,少量Query就可以打满GPU,大量Query反而会增加首token延迟;Decoding阶段是访存密集型,必须依赖大量Query提高GPU计算利用率。因此可以通过多台机器处理Prefilling、单台机器处理Decoding的PD分离方案,实现综合效率最大化、首token延迟(TTFT)最低、DecodeSpeed(TPS)最高。


上图是一个通常的多batch的LLM的Inference过程。其中黄色表示用户的prompt,蓝色表示LLM生成的文本,红色表示结束符号。

这是改进后的continuous batching的示意图。

https://www.anyscale.com/blog/continuous-batching-llm-inference

How continuous batching enables 23x throughput in LLM inference while reducing p50 latency


https://www.zhihu.com/tardis/zm/art/647813179

大模型文本生成——解码策略(Top-k & Top-p & Temperature)

https://b23.tv/OfdfBnz

如何设置大模型推理参数,top_k,top_p, temperature, num_beams

https://blog.csdn.net/HUSTHY/article/details/125990877

Contrastive Search Decoding——一种对比搜索解码文本生成算法

https://zhuanlan.zhihu.com/p/656707466

DoLa:层对比解码改善LLM的真实性


投机采样使用两个模型:一个是原始目标模型,另一个是比原始模型小得多的近似模型。近似模型用于进行自回归串行采样,而大型模型则用于评估采样结果。

https://zhuanlan.zhihu.com/p/651359908

大模型推理妙招—投机采样(Speculative Decoding)


LLM Inference的性能评估主要有以下几个方面:

  • Time To First Token (TTFT)
  • Time Per Output Token (TPOT)
  • Latency:模型为用户生成完整响应所需的总时间。latency = (TTFT) + (TPOT) * (the number of tokens to be generated)
  • Throughput:一个推理服务器每秒可以为所有用户和请求生成的输出token数量。

https://www.databricks.com/blog/llm-inference-performance-engineering-best-practices

LLM Inference Performance Engineering: Best Practices

vLLM

https://docs.vllm.ai/en/latest/

Easy, fast, and cheap LLM serving for everyone

Fork me on GitHub