年份与会议
2022 · arXiv
把注意力优化重点从 FLOPs 转向 IO,把“Exact Attention 也能大幅提速”变成现实,是现代训练和推理系统的关键基石之一。
年份与会议
2022 · arXiv
作者
Tri Dao、Daniel Y. Fu、Stefano Ermon、Atri Rudra、Christopher Re
主题
FlashAttention
阅读时长
约 3 分钟
收录时间
2022/05/27
很多人第一次听到 FlashAttention,会把它理解成一个工程优化包,似乎只是把标准注意力 kernel 写得更快一些。实际上,这篇论文真正重要的地方在于它改变了大家看待注意力瓶颈的方式。
过去谈注意力效率时,人们更容易盯着理论复杂度:
FlashAttention 则指出,一个被严重低估的问题是 IO。也就是说,注意力慢并不只是因为算得多,还因为数据在不同层级内存之间搬得太多。
这件事一旦被说清楚,后续大量大模型系统优化都出现了非常明确的新方向。
标准 self-attention 在实现时,通常会显式构造一些中间矩阵,例如注意力分数矩阵和 softmax 后的权重矩阵。理论上公式很简单,但在 GPU 上会出现两个大问题:
对于长序列场景,这种开销会非常可怕。也就是说,哪怕算法是“精确注意力”,真正慢你的可能不是乘法本身,而是内存读写。
FlashAttention 的核心洞察,就是把注意力看成一个 memory movement problem,而不只是一个 arithmetic problem。
如果把标准注意力和 FlashAttention 放在一起比较,差别会更直观:
<g>
<rect x="58" y="100" width="84" height="54" rx="14" fill="#e8f1ff" stroke="#98b7e1" />
<text x="100" y="132" text-anchor="middle" font-size="18" font-weight="700">Q</text>
<rect x="170" y="100" width="84" height="54" rx="14" fill="#eef6e8" stroke="#a8c48e" />
<text x="212" y="132" text-anchor="middle" font-size="18" font-weight="700">K</text>
<rect x="282" y="100" width="84" height="54" rx="14" fill="#fff4dc" stroke="#e2c36f" />
<text x="324" y="132" text-anchor="middle" font-size="18" font-weight="700">V</text>
</g>
<rect x="102" y="188" width="220" height="48" rx="14" fill="#fce7ef" stroke="#e2a8bd" />
<text x="212" y="217" text-anchor="middle" font-size="17" font-weight="700">显式构造 Score 矩阵 QK^T</text>
<rect x="102" y="252" width="220" height="48" rx="14" fill="#ece8ff" stroke="#b5abef" />
<text x="212" y="281" text-anchor="middle" font-size="17" font-weight="700">Softmax 概率矩阵再写回 HBM</text>
<text x="245" y="318" text-anchor="middle" font-size="14" fill="#8b4b4b">瓶颈:中间矩阵大,HBM 读写频繁</text>
<line x1="100" y1="154" x2="212" y2="188" stroke="#5b6b7f" stroke-width="3" marker-end="url(#flash-arrow)" />
<line x1="212" y1="154" x2="212" y2="188" stroke="#5b6b7f" stroke-width="3" marker-end="url(#flash-arrow)" />
<line x1="324" y1="154" x2="212" y2="252" stroke="#5b6b7f" stroke-width="3" marker-end="url(#flash-arrow)" />
<line x1="212" y1="236" x2="212" y2="252" stroke="#5b6b7f" stroke-width="3" marker-end="url(#flash-arrow)" />
<g>
<rect x="548" y="100" width="84" height="54" rx="14" fill="#e8f1ff" stroke="#98b7e1" />
<text x="590" y="132" text-anchor="middle" font-size="18" font-weight="700">Q 块</text>
<rect x="660" y="100" width="84" height="54" rx="14" fill="#eef6e8" stroke="#a8c48e" />
<text x="702" y="132" text-anchor="middle" font-size="18" font-weight="700">K 块</text>
<rect x="772" y="100" width="84" height="54" rx="14" fill="#fff4dc" stroke="#e2c36f" />
<text x="814" y="132" text-anchor="middle" font-size="18" font-weight="700">V 块</text>
</g>
<rect x="588" y="184" width="228" height="54" rx="14" fill="#dff4f0" stroke="#8dc7bd" />
<text x="702" y="208" text-anchor="middle" font-size="17" font-weight="700">SRAM 中分块计算</text>
<text x="702" y="228" text-anchor="middle" font-size="13" fill="#4b5563">tile-by-tile attention</text>
<rect x="588" y="252" width="228" height="48" rx="14" fill="#fef3c7" stroke="#e2c36f" />
<text x="702" y="281" text-anchor="middle" font-size="17" font-weight="700">在线 softmax + 直接聚合输出</text>
<text x="735" y="318" text-anchor="middle" font-size="14" fill="#355c52">收益:减少中间矩阵落地,IO 更低</text>
<line x1="590" y1="154" x2="702" y2="184" stroke="#5b6b7f" stroke-width="3" marker-end="url(#flash-arrow)" />
<line x1="702" y1="154" x2="702" y2="184" stroke="#5b6b7f" stroke-width="3" marker-end="url(#flash-arrow)" />
<line x1="814" y1="154" x2="702" y2="252" stroke="#5b6b7f" stroke-width="3" marker-end="url(#flash-arrow)" />
<line x1="702" y1="238" x2="702" y2="252" stroke="#5b6b7f" stroke-width="3" marker-end="url(#flash-arrow)" />
</g>
FlashAttention 采用 tiling 的思路,把 Q、K、V 和中间计算按块处理,而不是先把完整注意力矩阵全摊开。
这样做的好处是:
从直觉上说,它不像标准实现那样“先把所有关系都写下来,再统一处理”,而是更像一边分块读数据、一边当场完成必要聚合。
如果只是分块处理,一个看似棘手的问题是:softmax 通常需要看整行分数,才能做稳定归一化。FlashAttention 的关键设计之一,就是通过在线方式维护必要统计量,让 softmax 可以在块级处理下仍然保持数值稳定。
这使得算法同时具备两个非常重要的性质:
这也是 FlashAttention 特别漂亮的地方:它并没有牺牲正确性来换速度,而是通过更聪明的执行方式把两者尽量都保住了。
IO-aware 这个词是这篇论文的灵魂。它真正想告诉大家的是:
GPU 程序性能,不只由算术复杂度决定,还强烈受限于数据搬运路径。
很多近似注意力方法从理论上减少了计算量,但在真实硬件上未必更快,因为:
FlashAttention 之所以影响力这么大,是因为它把“硬件现实”正式纳入了注意力算法设计本身。
论文最有说服力的地方,是它不仅在微观 kernel 指标上更快,还在真实训练任务上体现出端到端收益,例如:
对读者而言,最值得记住的结论不是具体倍数,而是:
这三点使 FlashAttention 成为现代大模型基础设施的里程碑工作。
FlashAttention 并没有从根本上把注意力的理论二次复杂度变成线性,但它大幅缓解了注意力在真实系统中的内存和速度压力。因此,它和长上下文的关系更像是:
这很关键,因为现实系统里很多突破,并不一定是数学上把复杂度完全改写,而是让原来不可用的方案变得可用。
FlashAttention 之后,你会明显看到一类系统论文和框架设计越来越重视:
这意味着大模型优化的叙事变了。过去更多是“新模型结构”;现在大量关键进展其实是“旧结构在硬件上怎么更聪明地跑”。
从这个意义上说,FlashAttention 对今天的系统工程影响,不亚于某些架构论文对模型设计的影响。
这几条路线解决的不是同一个问题,但彼此非常互补:
如果把大模型系统看成一条链路:
把这三层分清楚,会比单纯记术语更有帮助。
虽然 FlashAttention 非常重要,但它也有边界:
所以它是关键基建,但不是长上下文和推理成本的唯一答案。
到 2026 年再看,FlashAttention 留下的最大遗产也许是这一条方法论:
在大模型时代,算法设计必须把硬件现实一并考虑进去。
这条原则后来几乎渗透到了所有重要系统工作中。也正因如此,FlashAttention 不只是“一个 attention 优化”,而是一种新的系统研究范式代表。
如果你只抓三点,请记住:
理解这三点,再去看训练优化、长上下文和 kernel 级系统论文时会轻松很多。
沿着相近主题继续阅读,加深对方法边界与实践场景的理解。
提出 Transformer 架构,以纯注意力机制替代 RNN/CNN,重写了序列建模的工程范式与研究方向。
通过“草稿模型先猜、大模型再批量验证”的两阶段解码,把自回归推理中的串行瓶颈部分摊薄,是现代低延迟推理的重要方向之一。
理解 KV Cache 如何减少自回归解码中的重复计算,并系统掌握延迟、吞吐、显存与服务调度之间的权衡。
从本地原型到线上服务,理解主流推理框架的定位差异、部署方式、监控指标与生产化注意事项。