无损减少80%激活值内存,提升5倍训练序列长度,仅需两行代码
机器之心·2025-06-23 15:44
长序列训练内存优化技术 - 核心观点:StreamBP算法通过线性分解和分步计算链式法则,将大语言模型训练所需的激活值内存降低至梯度检查点方法的20%,同时实现序列长度提升2.8-5.5倍 [3][6] 技术原理 - 梯度检查点方法仅储存每层输入,但单层完整激活值仍占内存85%以上 [9][13] - StreamBP将单层反向传播过程分解为块计算,按输出分块累加Jacobian-vector product,仅需储存当前块输入和输出 [11][14] - 对Transformer层采用注意力掩码优化,对lmhead层根据目标函数特性分块处理(SFT/GRPO独立计算,DPO利用序列维度独立性) [16][20] 性能表现 - 峰值内存从标准BP的36.01GB降至StreamBP的11.99GB(D=20),中间内存从25.15GB降至1.13GB [14] - 单卡A800-80GB测试显示,最大序列长度达梯度检查点的2.5-5.5倍,标准BP的23-36倍 [22][25] - 14B模型SFT训练中,序列长度从梯度检查点的23提升至StreamBP的84.6,32B模型从0.4提升至16.3 [26] 应用兼容性 - 支持SFT、GRPO、PPO、DPO等LLM目标函数,可集成至现有训练框架 [6][20] - 分布式训练下序列长度提升5-5.6倍,部分长序列场景速度较梯度检查点提升10.9%-12.9% [25][28] - 开源代码适配Transformer层和lmhead层,已提供PyTorch实现 [12]