Workflow
硬件的非对称扩展
icon
搜索文档
FlashAttention-4正式发布:算法流水线大改,矩阵乘法级速度
机器之心· 2026-03-06 12:31
文章核心观点 - FlashAttention-4 作为深度学习底层优化技术的重要更新,通过算法与内核的协同设计,针对新一代 Blackwell GPU 架构进行了优化,显著提升了注意力机制的计算效率 [1] - 在 Blackwell B200 GPU 上,FlashAttention-4 使注意力机制的执行速度几乎与矩阵乘法一样快,前向传播最高可达 1605 TFLOPs/s,利用率为 71% [1][10] - 该技术解决了由硬件非对称扩展带来的新瓶颈,并通过利用 Blackwell 的新硬件特性、新型流水线设计和调度优化实现了性能突破 [5][11] - FlashAttention-4 的发布被视为一个里程碑,其性能提升将直接惠及所有前沿大模型,带来更长的有效上下文窗口、更低的推理成本和更强的规模化推理能力 [28] FlashAttention-4 的技术背景与挑战 - **硬件趋势与瓶颈转移**: AI 行业正迅速转向部署 Blackwell 架构系统,现代加速器延续了“硬件非对称扩展”趋势,即张量核心吞吐量增长远快于共享内存带宽、特殊函数单元等其他资源 [5][6] - 从 Hopper H100 到 Blackwell B200,BF16 张量核心吞吐量增加了 2.25倍 (从 1 到 2.25 PFLOPs),但 SFU 数量和共享内存带宽基本保持不变 [6] - 这种扩展不对称性对像注意力这样的复杂内核优化产生了深远影响,性能瓶颈已从张量核心转移至其他部分 [7][10] - **注意力机制的复杂性**: 注意力机制的核心包含两个通用矩阵乘法,中间夹着 softmax,但在实践中还涉及大量辅助工作,如数据搬运、同步、布局转换等 [8][9] - 传统观点认为注意力性能由 GEMM 速度决定,但在 B200 上分析显示,主要瓶颈在于前向传播中的 SFU 单元和反向传播中的共享内存流量 [10][14] FlashAttention-4 的核心设计与优化 - **协同设计思路**: 通过最大化矩阵乘法与其他瓶颈资源之间的重叠来提升性能 [10] - **利用 Blackwell 新硬件特性**: - **张量内存**: 每个 SM 配备 256 KB 的 TMEM,与张量核心直接连接,用于存储中间结果 [12] - **完全异步的第五代张量核心**: 支持异步执行并将结果存储在 TMEM 中,单个 CTA 可使用的最大 UMMA tile 约为 Hopper 架构的 2 倍,减轻了寄存器压力并支持更深流水线 [12] - **2-CTA MMA**: 支持一对 CTA 共同执行一个 UMMA 运算,可将 MMA 的 tile 尺寸扩展到 256×256×16,减少冗余数据传输并降低每个 CTA 的资源占用 [13] - **新型流水线设计**: - **前向传播**: 在 FMA 单元上通过多项式近似实现指数函数的软件仿真以提升吞吐量;引入条件式 softmax 重缩放,跳过 90% 不必要的重缩放操作,缓解 SFU 瓶颈 [1][14] - **反向传播**: 利用 TMEM 存储中间结果以缓解共享内存流量压力;结合 2-CTA MMA 模式进一步降低共享内存访问,并将 atomic reduction 次数减少一半;支持确定性执行模式 [14] - **调度优化**: 引入新的 tile 调度器,解决因果掩码和变长序列导致的负载不均衡问题 [14] 性能表现与行业影响 - **性能基准测试**: 在 B200 上的测试显示,FlashAttention-4 性能显著优于其他实现 [19] - **前向传播**: 比 cuDNN 9.13 快 1.1–1.3 倍,比 Triton 实现快 2.1–2.7 倍 [19] - **反向传播**: 在长序列长度场景下,表现始终优于其他基准模型 [19] - 相比 FlashAttention-3,性能提升了 2–3 倍 [28] - **框架集成与行业反响**: - PyTorch 官方宣布其 FlexAttention 现已支持 FlashAttention-4 后端,使研究人员无需在“灵活性”和“高性能”之间做选择 [24][27] - 在算力受限的工作负载下,相比 Triton,FlexAttention 使用 FlashAttention-4 后端仍可实现 1.2 倍到 3.2 倍的性能提升 [27] - 该技术被认为将直接惠及所有前沿大模型,因为更快的注意力意味着更长的有效上下文窗口、更低的推理成本和更强的规模化推理能力 [28] 实现与工具 - **编程语言与框架**: FlashAttention-4 完全使用 CuTe-DSL 实现,这是 CUTLASS 提供的 Python 内核 DSL,可将编译时间缩短约 20–30 倍,使安装/编译只需几秒钟而非几分钟/几小时 [17]