FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness

把注意力优化重点从 FLOPs 转向 IO,把“Exact Attention 也能大幅提速”变成现实,是现代训练和推理系统的关键基石之一。

年份与会议

2022 · arXiv

作者

Tri Dao、Daniel Y. Fu、Stefano Ermon、Atri Rudra、Christopher Re

主题

FlashAttention

阅读时长

约 3 分钟

收录时间

2022/05/27

标签

原文链接

https://arxiv.org/abs/2205.14135

为什么 FlashAttention 不只是“更快的 CUDA 技巧”

很多人第一次听到 FlashAttention,会把它理解成一个工程优化包,似乎只是把标准注意力 kernel 写得更快一些。实际上,这篇论文真正重要的地方在于它改变了大家看待注意力瓶颈的方式。

过去谈注意力效率时,人们更容易盯着理论复杂度:

  • 注意力是二次复杂度
  • 能不能做近似
  • 能不能改成稀疏

FlashAttention 则指出,一个被严重低估的问题是 IO。也就是说,注意力慢并不只是因为算得多,还因为数据在不同层级内存之间搬得太多。

这件事一旦被说清楚,后续大量大模型系统优化都出现了非常明确的新方向。

背景:为什么标准注意力在 GPU 上并不“天然高效”

标准 self-attention 在实现时,通常会显式构造一些中间矩阵,例如注意力分数矩阵和 softmax 后的权重矩阵。理论上公式很简单,但在 GPU 上会出现两个大问题:

  1. 中间结果很大,占显存。
  2. 数据在 HBM 和片上 SRAM 之间频繁读写,IO 成本高。

对于长序列场景,这种开销会非常可怕。也就是说,哪怕算法是“精确注意力”,真正慢你的可能不是乘法本身,而是内存读写。

FlashAttention 的核心洞察,就是把注意力看成一个 memory movement problem,而不只是一个 arithmetic problem

如果把标准注意力和 FlashAttention 放在一起比较,差别会更直观:

标准 Attention 与 FlashAttention 对比图 左侧展示标准 attention 需要显式构造 score 和概率矩阵并频繁写回高带宽显存;右侧展示 FlashAttention 通过分块和在线 softmax 在片上更紧凑地完成计算,减少 IO。 标准 Attention FlashAttention
  <g>
    <rect x="58" y="100" width="84" height="54" rx="14" fill="#e8f1ff" stroke="#98b7e1" />
    <text x="100" y="132" text-anchor="middle" font-size="18" font-weight="700">Q</text>
    <rect x="170" y="100" width="84" height="54" rx="14" fill="#eef6e8" stroke="#a8c48e" />
    <text x="212" y="132" text-anchor="middle" font-size="18" font-weight="700">K</text>
    <rect x="282" y="100" width="84" height="54" rx="14" fill="#fff4dc" stroke="#e2c36f" />
    <text x="324" y="132" text-anchor="middle" font-size="18" font-weight="700">V</text>
  </g>
  <rect x="102" y="188" width="220" height="48" rx="14" fill="#fce7ef" stroke="#e2a8bd" />
  <text x="212" y="217" text-anchor="middle" font-size="17" font-weight="700">显式构造 Score 矩阵 QK^T</text>
  <rect x="102" y="252" width="220" height="48" rx="14" fill="#ece8ff" stroke="#b5abef" />
  <text x="212" y="281" text-anchor="middle" font-size="17" font-weight="700">Softmax 概率矩阵再写回 HBM</text>
  <text x="245" y="318" text-anchor="middle" font-size="14" fill="#8b4b4b">瓶颈:中间矩阵大,HBM 读写频繁</text>

  <line x1="100" y1="154" x2="212" y2="188" stroke="#5b6b7f" stroke-width="3" marker-end="url(#flash-arrow)" />
  <line x1="212" y1="154" x2="212" y2="188" stroke="#5b6b7f" stroke-width="3" marker-end="url(#flash-arrow)" />
  <line x1="324" y1="154" x2="212" y2="252" stroke="#5b6b7f" stroke-width="3" marker-end="url(#flash-arrow)" />
  <line x1="212" y1="236" x2="212" y2="252" stroke="#5b6b7f" stroke-width="3" marker-end="url(#flash-arrow)" />

  <g>
    <rect x="548" y="100" width="84" height="54" rx="14" fill="#e8f1ff" stroke="#98b7e1" />
    <text x="590" y="132" text-anchor="middle" font-size="18" font-weight="700">Q 块</text>
    <rect x="660" y="100" width="84" height="54" rx="14" fill="#eef6e8" stroke="#a8c48e" />
    <text x="702" y="132" text-anchor="middle" font-size="18" font-weight="700">K 块</text>
    <rect x="772" y="100" width="84" height="54" rx="14" fill="#fff4dc" stroke="#e2c36f" />
    <text x="814" y="132" text-anchor="middle" font-size="18" font-weight="700">V 块</text>
  </g>
  <rect x="588" y="184" width="228" height="54" rx="14" fill="#dff4f0" stroke="#8dc7bd" />
  <text x="702" y="208" text-anchor="middle" font-size="17" font-weight="700">SRAM 中分块计算</text>
  <text x="702" y="228" text-anchor="middle" font-size="13" fill="#4b5563">tile-by-tile attention</text>
  <rect x="588" y="252" width="228" height="48" rx="14" fill="#fef3c7" stroke="#e2c36f" />
  <text x="702" y="281" text-anchor="middle" font-size="17" font-weight="700">在线 softmax + 直接聚合输出</text>
  <text x="735" y="318" text-anchor="middle" font-size="14" fill="#355c52">收益:减少中间矩阵落地,IO 更低</text>

  <line x1="590" y1="154" x2="702" y2="184" stroke="#5b6b7f" stroke-width="3" marker-end="url(#flash-arrow)" />
  <line x1="702" y1="154" x2="702" y2="184" stroke="#5b6b7f" stroke-width="3" marker-end="url(#flash-arrow)" />
  <line x1="814" y1="154" x2="702" y2="252" stroke="#5b6b7f" stroke-width="3" marker-end="url(#flash-arrow)" />
  <line x1="702" y1="238" x2="702" y2="252" stroke="#5b6b7f" stroke-width="3" marker-end="url(#flash-arrow)" />
</g>
FlashAttention 的关键并不是“换了一个近似算法”,而是把 exact attention 的执行路径重新设计成更符合 GPU 内存层级的形态。

核心方法一:分块计算,而不是一次性展开完整矩阵

FlashAttention 采用 tiling 的思路,把 Q、K、V 和中间计算按块处理,而不是先把完整注意力矩阵全摊开。

这样做的好处是:

  • 中间结果不必完整落到高带宽显存里。
  • 更多计算可以在更快的片上 SRAM 中完成。
  • 显著减少不必要的读写次数。

从直觉上说,它不像标准实现那样“先把所有关系都写下来,再统一处理”,而是更像一边分块读数据、一边当场完成必要聚合。

核心方法二:在线 softmax

如果只是分块处理,一个看似棘手的问题是:softmax 通常需要看整行分数,才能做稳定归一化。FlashAttention 的关键设计之一,就是通过在线方式维护必要统计量,让 softmax 可以在块级处理下仍然保持数值稳定。

这使得算法同时具备两个非常重要的性质:

  • 仍然是 exact attention,而不是近似值。
  • 不必为 exactness 付出传统实现那样巨大的中间存储代价。

这也是 FlashAttention 特别漂亮的地方:它并没有牺牲正确性来换速度,而是通过更聪明的执行方式把两者尽量都保住了。

为什么论文强调“IO-aware”

IO-aware 这个词是这篇论文的灵魂。它真正想告诉大家的是:

GPU 程序性能,不只由算术复杂度决定,还强烈受限于数据搬运路径。

很多近似注意力方法从理论上减少了计算量,但在真实硬件上未必更快,因为:

  • 数据仍然要频繁读写
  • kernel 调度和访存模式不理想
  • 近似算法本身又引入额外开销

FlashAttention 之所以影响力这么大,是因为它把“硬件现实”正式纳入了注意力算法设计本身。

实验结果说明了什么

论文最有说服力的地方,是它不仅在微观 kernel 指标上更快,还在真实训练任务上体现出端到端收益,例如:

  • 更快的 BERT 训练
  • 更快的 GPT 风格模型训练
  • 更长上下文下更好的可扩展性

对读者而言,最值得记住的结论不是具体倍数,而是:

  1. Exact attention 依然有很大优化空间。
  2. 不是所有效率提升都必须靠近似。
  3. 真正的 wall-clock speedup 往往来自算法与硬件共同设计。

这三点使 FlashAttention 成为现代大模型基础设施的里程碑工作。

它和长上下文问题是什么关系

FlashAttention 并没有从根本上把注意力的理论二次复杂度变成线性,但它大幅缓解了注意力在真实系统中的内存和速度压力。因此,它和长上下文的关系更像是:

  • 不是彻底消灭问题
  • 而是把一个原本难以承受的问题推迟、压缩、优化到更可用区间

这很关键,因为现实系统里很多突破,并不一定是数学上把复杂度完全改写,而是让原来不可用的方案变得可用。

为什么说它重塑了后续系统论文的出发点

FlashAttention 之后,你会明显看到一类系统论文和框架设计越来越重视:

  • kernel fusion
  • memory hierarchy
  • IO complexity
  • 数据布局与调度

这意味着大模型优化的叙事变了。过去更多是“新模型结构”;现在大量关键进展其实是“旧结构在硬件上怎么更聪明地跑”。

从这个意义上说,FlashAttention 对今天的系统工程影响,不亚于某些架构论文对模型设计的影响。

它和 KV Cache / PagedAttention 的关系

这几条路线解决的不是同一个问题,但彼此非常互补:

  • FlashAttention:主要优化注意力计算本身,尤其是训练和长上下文下的 kernel/IO 问题。
  • KV Cache:减少自回归生成中的重复计算。
  • PagedAttention / vLLM:优化服务场景中 KV Cache 的内存管理和调度。

如果把大模型系统看成一条链路:

  • FlashAttention 更接近算子级效率
  • KV Cache 更接近解码级重复利用
  • PagedAttention 更接近服务级资源管理

把这三层分清楚,会比单纯记术语更有帮助。

局限:FlashAttention 不是万能钥匙

虽然 FlashAttention 非常重要,但它也有边界:

  1. 它主要优化的是执行效率,不直接解决模型对长上下文的“利用质量”问题。
  2. 即便算得更快,超长上下文仍然会带来整体资源压力。
  3. 真正收益依赖硬件、实现版本、框架集成情况,不是任何环境都能无脑同倍数提升。
  4. 它让 exact attention 更实用,但不意味着稀疏、检索、压缩等其他长序列方法不再重要。

所以它是关键基建,但不是长上下文和推理成本的唯一答案。

从今天看,FlashAttention 最重要的遗产是什么

到 2026 年再看,FlashAttention 留下的最大遗产也许是这一条方法论:

在大模型时代,算法设计必须把硬件现实一并考虑进去。

这条原则后来几乎渗透到了所有重要系统工作中。也正因如此,FlashAttention 不只是“一个 attention 优化”,而是一种新的系统研究范式代表。

读这篇论文时最该抓住什么

如果你只抓三点,请记住:

  1. 注意力慢,很多时候慢在 IO,不只是慢在 FLOPs。
  2. FlashAttention 通过分块和在线 softmax 保持了 exactness。
  3. 它重新定义了大模型系统优化的出发点。

理解这三点,再去看训练优化、长上下文和 kernel 级系统论文时会轻松很多。

延伸阅读

相关内容

沿着相近主题继续阅读,加深对方法边界与实践场景的理解。