​Tiger:一个“抠”到极致的优化器
创始人
2024-06-02 02:41:56
0

bc3c32cd314cfe1b316243b9012de27a.gif

©PaperWeekly 原创 · 作者 | 苏剑林

单位 | 追一科技

研究方向 | NLP、神经网络

这段时间笔者一直在实验《Google新搜出的优化器Lion:效率与效果兼得的“训练狮”》所介绍的 Lion 优化器。之所以对 Lion 饶有兴致,是因为它跟笔者之前的关于理想优化器的一些想法不谋而合,但当时笔者没有调出好的效果,而 Lion 则做好了。

相比标准的 Lion,笔者更感兴趣的是它在 时的特殊例子,这里称之为“Tiger”。Tiger 只用到了动量来构建更新量,根据《隐藏在动量中的梯度累积:少更新几步,效果反而更好?》[1] 的结论,此时我们不新增一组参数来“无感”地实现梯度累积!这也意味着在我们有梯度累积需求时,Tiger 已经达到了显存占用的最优解,这也是“Tiger”这个名字的来源(Tight-fisted Optimizer,抠门的优化器,不舍得多花一点显存)。

此外,Tiger 还加入了我们的一些超参数调节经验,以及提出了一个防止模型出现 NaN(尤其是混合精度训练下)的简单策略。我们的初步实验显示,Tiger 的这些改动,能够更加友好地完成模型(尤其是大模型)的训练。

f75fe3e8b8a5a9359c7523f3be2bb885.png

基本形式

Tiger 的更新规则为

13ba572d8baeba56bd01318adfc3ba92.png

相比 Lion,它就是选择了参数 ;相比 SignSGD [2],它则是新增了动量和权重衰减。

参考实现:

https://github.com/bojone/tiger

下表对比了 Tiger、Lion 和 AdamW 的更新规则:

5e671f3519de4676fb39098e0f9e7685.png

可见 Tiger 是三者之中的极简者。

e90a93c54c992ff33174ec6c8f301a2e.png

超参设置

尽管 Tiger 已经相当简化,但仍有几个超参数要设置,分别是滑动平均率 、学习率 以及权重衰减率 ,下面我们分别讨论这几个参数的选择。

fa71b7f3624fe4c7e91f6c04bd52d7f4.png

滑动平均率

比较简单的是 。我们知道,在基本形式上 Tiger 相当于 Lion 取 的特殊情形,那么一个直觉是 Tiger 应当取 。在 Lion 的原论文中,对于 CV 任务有 ,所以我们建议 CV 任务取 ;而对于 NLP 任务则有 ,所以我们建议 NLP 任务取 。

833e66789a1dc8f1ee8d8029254080be.png

学习率

对于学习率,Tiger 参考了 Amos、LAMB [3] 等工作,将学习率分两种情况设置。第一种是线性层的 bias 项和 Normalization 的 beta、gamma 参数,这类参数的特点是运算是 element-wise,我们建议学习率选取为全局相对学习率 的一半;第二种主要就是线性层的 kernel 矩阵,这类参数的特点是以矩阵的身份去跟向量做矩阵乘法运算,我们建议学习率选取为全局相对学习率 乘以参数本身的 (Root Mean Square):

e4b5707fc37132ba4fad24d37a09a16c.png

其中

5aa4e7c99faf0a58093eca6c762a10aa.png

这样设置的好处是我们把参数的尺度分离了出来,使得学习率的调控可以交给一个比较通用的“全局相对学习率” ——大致可以理解为每一步的相对学习幅度,是一个对于模型尺度不是特别敏感的量。

换句话说,我们在 base 版模型上调好的 ,基本上可以不改动地用到 large 版模型。注意 带有下标 ,所以它包含了整个学习率的 schedule,包括 Wamrup 以及学习率衰减策略等,笔者的设置经验是 ,至于怎么 Warmup 和衰减,那就是大家根据自己的任务而设了,别人无法代劳。笔者给的 tiger 实现,内置了一个分段线性学习率策略,理论上可以用它模拟任意的 。

08f1e58e8c4b4778b38404e6a03cf382.png

权重衰减率

最后是权重衰减率 ,这个 Lion 论文最后一页也给出了一些参考设置,一般来说 也就设为常数,笔者常用的是 0.01。特别的是,不建议对前面说的 bias、beta、gamma 这三类参数做权重衰减,或者即便要做, 也要低一个数量级以上。因为从先验分布角度来看,权重衰减是参数的高斯先验, 跟参数方差是反比关系,而 bias、beta、gamma 的方差显然要比 kernel 矩阵的方差大,所以它们的 应该更小。

54f6b92172239e81d8da7bfbeb390c6d.png

d7a8597edc8d9a554ea48d36eb319aeb.png

梯度累积

对于很多算力有限的读者来说,通过梯度累积来增大 batch_size 是训练大模型时不可避免的一步。标准的梯度累积需要新增一组参数,用来缓存历史梯度,这意味着在梯度累积的需求下,Adam 新增的参数是 3 组,Lion 是 2 组,而即便是不加动量的 AdaFactor [4] 也有 1.x 组(但说实话 AdaFactor 不加动量,收敛会慢很多,所以考虑速度的话,加一组动量就变为 2.x 组)。

而对于 Tiger 来说,它的更新量只用到了动量和原参数,根据《隐藏在动量中的梯度累积:少更新几步,效果反而更好?》[1],我们可以通过如下改动,将梯度累积内置在 Tiger 中:

a9d045dc154a3c5516d0301dde59c814.png

这里的 是判断 能否被 整除的示性函数

d720de4ca1b412b4204377362ef40e2a.png

可以看到,这仅仅相当于修改了滑动平均率 和学习率 ,几乎不增加显存成本,整个过程是完全“无感”的,这是笔者认为的 Tiger 的最大的魅力。

需要指出的是,尽管 Lion 跟 Tiger 很相似,但是 Lion 并不能做到这一点,因为 时,Lion 的更新需要用到动量以及当前批的梯度,这两个量需要用不同的参数缓存,而 Tiger 的更新只用到了动量,因此满足这一点。类似滴,SGDM 优化器也能做到一点,但是它没有 操作,这意味着学习率的自适应能力不够好,在 Transformer 等模型上的效果通常不如意(参考《Why are Adaptive Methods Good for Attention Models?》[5])。

43cb05386fa9f970129c38377e2d53a4.png

全半精度

对于大模型来说,混合精度训练是另一个常用的“利器”(参考《在 bert4keras 中使用混合精度和 XLA 加速训练》[6])。混合精度,简单来说就是模型计算部分用半精度的 FP16,模型参数的储存和更新部分用单精度的 FP32。之所以模型参数要用 FP32,是因为担心更新过程中参数的更新量过小,下溢出了 FP16 的表示范围(大致是 ),导致某些参数长期不更新,模型训练进度慢甚至无法正常训练。

然而,Tiger(Lion 也一样)对更新量做了 运算,这使得理论上我们可以全用半精度训练!分析过程并不难。

首先,只要对 Loss 做适当的缩放,那么可以做到梯度 不会溢出 FP16 的表示范围;而动量 只是梯度的滑动平均,梯度不溢出,它也不会溢出, 只能是 ,更加不会溢出了;之后,我们只需要保证学习率不小于 ,那么更新量就不会下溢了,事实上我们也不会将学习率调得这么小。因此,Tiger 的整个更新过程都是在 FP16 表示范围内的,因此理论上我们可以直接用全 FP16 精度训练而不用担心溢出问题。

1d37ad1bb6edb4bfb156a56eecdd5447.png

防止NaN

然而,笔者发现对于同样的配置,在 FP32 下训练正常,但切换到混合精度或者半精度后有时会训练失败,具体表现后 Loss 先降后升然后 NaN,这我们之前在《在 bert4keras 中使用混合精度和 XLA 加速训练》[6] 也讨论过。虽然有一些排查改进的方向(比如调节 epsilon 和无穷大的值、缩放 loss 等),但有时候把该排查的都排查了,还会出现这样的情况。

经过调试,笔者发现出现这种情况时,主要是对于某些 batch 梯度变为 NaN,但此时模型的参数和前向计算还是正常的。于是笔者就想了个简单的应对策略:对梯度出现 NaN 时,跳过这一步更新,并且对参数进行轻微收缩,如下

5f7152fa7a1f96617ac527413773bf65.png

其中 代表收缩率,笔者取 , 则是参数的初始化中心,一般就是 gamma 取 1,其他参数都是 0。经过这样处理后,模型的 loss 会有轻微上升,但一般能够恢复正常训练,不至于从头再来。个人的实验结果显示,这样处理能够缓解一部分 NaN 的问题。

当然,该技巧一般的使用场景是同样配置下 FP32 能够正常训练,并且已经做好了 epsilon、无穷大等混合精度调节,万般无奈之下才不得已使用的。如果模型本身超参数设置有问题(比如学习率过大),连 FP32 都会训练到 NaN,那么就不要指望这个技巧能够解决问题了。此外,有兴趣的读者,还可以尝试改进这个技巧,比如收缩之后可以再加上一点噪声来增加参数的多样性,等等。

1c852cf5a7b35e0804581f197245b76f.png

实验结果

不考虑梯度累积带来的显存优化,Tiger 就是 Lion 的一个特例,可以预估 Tiger 的最佳效果肯定是不如 Lion 的最佳效果的,那么效果下降的幅度是否在可接受范围内呢?综合到目前为止多方的实验结果,笔者暂时得出的结论是:

713ee97333d4ed506ef55e5655ce1b08.png

也就是说,考虑效果 Lion 最优,考虑显存占用 Tiger 最优(启用梯度累积时),效果上 Tiger 不逊色于 AdamW,所以 Tiger 替代 AdamW 时没有太大问题的。

具体实验结果包括几部分。第一部分实验来自 Lion 的论文《Symbolic Discovery of Optimization Algorithms》[7],论文中的 Figure 12 对比了 Lion、Tiger、AdamW 在不同尺寸的语言模型上的效果:

dc991254d0bf61356e63770e33416d9d.png

▲ Lion、Tiger(Ablation)、AdamW 在语言模型任务上的对比

这里的 Ablation0.95、Ablation0.98,就是 Tiger的 分别取 0.95、0.98。可以看到,对于 small级别模型,两个 Tiger 持平 AdamW,而在 middle 和 large 级别上,两个 Tiger 都超过了 AdamW。但正如前面所说, 取两者的均值 0.965,有可能还会有进一步的提升。

至于在 CV 任务上,原论文给出了 Table 7:

603c109cc787dcf1dea09321ff997764.png

▲ Lion、Tiger(Ablation)、AdamW在图像分类任务上的对比

同样地,这里的 Ablation0.9、Ablation0.99,就是 Tiger 的 分别取 0.9、0.99。在这个表中,Tiger 跟 AdamW 有明显差距。但是考虑到作者只实验了 0.9、0.99 两个 ,而笔者推荐的是 ,所以笔者跟原作者取得了联系,请他们做了补充实验,他们回复的结果是“ 分别取 0.92、0.95、0.98 时,在 ViT-B/16 上 ImageNet 的结果都是 80.0% 左右”,那么对比上图,就可以确定在精调 时,在 CV 任务上 Tiger 应该也可以追平 AdamW 的。

最后是笔者自己的实验。笔者常用的是 LAMB 优化器,它的效果基本跟 AdamW 持平,但相对更稳定,而且对不同的初始化适应性更好,因此笔者更乐意使用 LAMB。特别地,LAMB 的学习率设置可以完全不改动地搬到 Tiger 中。笔者用 Tiger 重新训练了之前的 base 版 GAU-α 模型,训练曲线跟之前的对比如下:

bdef824a0f755ceaea435fa69d487388.png

▲ 笔者在GAU-α上的对比实验(loss曲线)

327bd611c08337da17be12486b9b7799.png

▲ 笔者在GAU-α上的对比实验(accuracy曲线)

可以看到,Tiger 确实可以取得比 LAMB 更优异的表现。

43ec46f4acaa3adf1424e12e5499931d.png

未来工作

Tiger 还有改进空间吗?肯定有,想法其实有很多,但都没来得及一一验证,大家有兴趣的可以帮忙继续做下去。

在《Google新搜出的优化器Lion:效率与效果兼得的“训练狮”》中,笔者对 运算的评价是:

Lion 通过 操作平等地对待了每一个分量,使得模型充分地发挥了每一个分量的作用,从而有更好的泛化性能。如果是 SGD,那么更新的大小正比于它的梯度,然而有些分量梯度小,可能仅仅是因为它没初始化好,而并非它不重要,所以 Lion 的 操作算是为每个参数都提供了“恢复活力”甚至“再创辉煌”的机会。

然而,细思之下就会发现,这里其实有一个改进空间。“平等地对待了每一个分量”在训练的开始阶段是很合理的,它保留了模型尽可能多的可能。然而,如果一个参数长时间的梯度都很小,那么很有可能这个参数真的是“烂泥扶不上墙”,即已经优化到尽头了,这时候如果还是“平等地对待了每一个分量”,那么就对那些梯度依然较大的“上进生”分量不公平了,而且很可能导致模型震荡。

一个符合直觉的想法是,优化器应该随着训练的推进,慢慢从 Tiger 退化为 SGD。为此,我们可以考虑将更新量设置为

583d6ceae6686a8ef5b9b4e6bda55bfc.png

这里的绝对值和幂运算都是 element-wise 的, 是从 1 到 0 的单调递减函数,当 时对应 Tiger,当 时对应 SGDM。

可能读者会吐槽这里多了 这个 schedule 要调整,问题变得复杂很多。确实如此,如果将它独立地进行调参,那么确实会引入过多的复杂度了。但我们不妨再仔细回忆一下,抛开 Warmup 阶段不算,一般情况下相对学习率 不正是一个单调递减至零的函数?

我们是否可以借助 来设计 呢?比如 不正好是一个从 1 到 0 的单调递减函数?能否用它来作为 ?当然也有可能是 、 更好,调参空间还是有的,但至少我们不用重新设计横跨整个训练进程的 schedule 了。

更发散一些,既然有时候学习率我们也可以用非单调的 schedule(比如带 restart的 cosine annealing),那么 我们是否也可以用非单调的(相当于 Tiger、SGDM 反复切换)?这些想法都有待验证。

253b116ce298308c810ad85a8b8740dd.png

文章小结

在这篇文章中,我们提出了一个新的优化器,名为 Tiger(Tight-fisted Optimizer,抠门的优化器),它在 Lion 的基础上做了一些简化,并加入了我们的一些超参数经验。特别地,在需要梯度累积的场景下,Tiger 可以达到显存占用的理论最优(抠)解!

outside_default.png

参考文献

outside_default.png

[1] https://kexue.fm/archives/8634

[2] https://arxiv.org/abs/1802.04434

[3] https://kexue.fm/archives/7094#层自适应

[4] https://kexue.fm/archives/7302

[5] https://arxiv.org/abs/1912.03194

[6] https://kexue.fm/archives/9059

[7] https://arxiv.org/abs/2302.06675

f629a19c4c4f1750363781f5e00cc525.jpeg

🔍

现在,在「知乎」也能找到我们了

进入知乎首页搜索「PaperWeekly」

点击「关注」订阅我们的专栏吧

·

·

33a0e36ca6aae66b394f539bbdfa05d2.jpeg

相关内容

热门资讯

122.(leaflet篇)l... 听老人家说:多看美女会长寿 地图之家总目录(订阅之前建议先查看该博客) 文章末尾处提供保证可运行...
育碧GDC2018程序化大世界... 1.传统手动绘制森林的问题 采用手动绘制的方法的话,每次迭代地形都要手动再绘制森林。这...
育碧GDC2018程序化大世界... 1.传统手动绘制森林的问题 采用手动绘制的方法的话,每次迭代地形都要手动再绘制森林。这...
Vue使用pdf-lib为文件... 之前也写过两篇预览pdf的,但是没有加水印,这是链接:Vu...
PyQt5数据库开发1 4.1... 文章目录 前言 步骤/方法 1 使用windows身份登录 2 启用混合登录模式 3 允许远程连接服...
Android studio ... 解决 Android studio 出现“The emulator process for AVD ...
Linux基础命令大全(上) ♥️作者:小刘在C站 ♥️个人主页:小刘主页 ♥️每天分享云计算网络运维...
再谈解决“因为文件包含病毒或潜... 前面出了一篇博文专门来解决“因为文件包含病毒或潜在的垃圾软件”的问题,其中第二种方法有...
南京邮电大学通达学院2023c... 题目展示 一.问题描述 实验题目1 定义一个学生类,其中包括如下内容: (1)私有数据成员 ①年龄 ...
PageObject 六大原则 PageObject六大原则: 1.封装服务的方法 2.不要暴露页面的细节 3.通过r...
【Linux网络编程】01:S... Socket多进程 OVERVIEWSocket多进程1.Server2.Client3.bug&...
数据结构刷题(二十五):122... 1.122. 买卖股票的最佳时机 II思路:贪心。把利润分解为每天为单位的维度,然后收...
浏览器事件循环 事件循环 浏览器的进程模型 何为进程? 程序运行需要有它自己专属的内存空间࿰...
8个免费图片/照片压缩工具帮您... 继续查看一些最好的图像压缩工具,以提升用户体验和存储空间以及网站使用支持。 无数图像压...
计算机二级Python备考(2... 目录  一、选择题 1.在Python语言中: 2.知识点 二、基本操作题 1. j...
端电压 相电压 线电压 记得刚接触矢量控制的时候,拿到板子,就赶紧去测各种波形,结...
如何使用Python检测和识别... 车牌检测与识别技术用途广泛,可以用于道路系统、无票停车场、车辆门禁等。这项技术结合了计...
带环链表详解 目录 一、什么是环形链表 二、判断是否为环形链表 2.1 具体题目 2.2 具体思路 2.3 思路的...
【C语言进阶:刨根究底字符串函... 本节重点内容: 深入理解strcpy函数的使用学会strcpy函数的模拟实现⚡strc...
Django web开发(一)... 文章目录前端开发1.快速开发网站2.标签2.1 编码2.2 title2.3 标题2.4 div和s...