大模型基础知识复习

大模型-Norm

LLM常用的Norm方法有LayerNorm

LayerNorm 与 BatchNorm的区别

为什么Pre-Norm不如Post-Norm

Pre Norm与Post Norm之间的对比是一个“老生常谈”的话题了,本博客就多次讨论过这个问题,比如文章《浅谈Transformer的初始化、参数化与标准化》《模型优化漫谈:BERT的初始标准差为什么是0.02?》等。目前比较明确的结论是:同一设置之下,Pre Norm结构往往更容易训练,但最终效果通常不如Post Norm。Pre Norm更容易训练好理解,因为它的恒等路径更突出,但为什么它效果反而没那么好呢?

笔者之前也一直没有好的答案,直到前些时间在知乎上看到 @唐翔昊 的一个回复后才“恍然大悟”,原来这个问题竟然有一个非常直观的理解!本文让我们一起来学习一下。

基本结论

Pre Norm和Post Norm的式子分别如下:

PreNorm:xt+1=xt+Ft(Norm(xt))(1)Pre Norm: x_{t+1}=x_t+F_t(Norm(x_t)) \tag{1}

PostNorm:xt+1=Norm(xt+Ft(xt))(2)Post Norm: x_{t+1}=Norm(x_t+F_t(x_t)) \tag{2}

在Transformer中,这里的NormNorm主要指Layer Normalization,但在一般的模型中,它也可以是Batch Normalization、Instance Normalization等,相关结论本质上是通用的。

在笔者找到的资料中,显示Post Norm优于Pre Norm的工作有两篇,一篇是《Understanding the Difficulty of Training Transformers》,一篇是《RealFormer: Transformer Likes Residual Attention》。另外,笔者自己也做过对比实验,显示Post Norm的结构迁移性能更加好,也就是说在Pretraining中,Pre Norm和Post Norm都能做到大致相同的结果,但是Post Norm的Finetune效果明显更好。

可能读者会反问《On Layer Normalization in the Transformer Architecture》不是显示Pre Norm要好于Post Norm吗?这是不是矛盾了?其实这篇文章比较的是在完全相同的训练设置下Pre Norm的效果要优于Post Norm,这只能显示出Pre Norm更容易训练,因为Post Norm要达到自己的最优效果,不能用跟Pre Norm一样的训练配置(比如Pre Norm可以不加Warmup但Post Norm通常要加),所以结论并不矛盾。

直观理解

为什么Pre Norm的效果不如Post Norm?知乎上 @唐翔昊 给出的答案是:Pre Norm的深度有“水分”!也就是说,一个LL层的Pre Norm模型,其实际等效层数不如LL层的Post Norm模型,而层数少了导致效果变差了。

具体怎么理解呢?很简单,对于Pre Norm模型我们迭代得到:

xt+1=xt+Ft(Norm(xt))=xt1+Ft1(Norm(xt1))+Ft(Norm(xt))=...=x0+F0(Norm(x0))++Ft1(Norm(xt1))+Ft(Norm(xt))(3)\begin{aligned} x_{t+1}&=x_t+F_t(Norm(x_t)) \\ &=x_{t-1}+F_{t-1}(Norm(x_{t-1}))+F_t(Norm(x_t)) \\ &=... \\ &=x0+F0(Norm(x0))+⋯+F_{t-1}(Norm(x_{t-1}))+F_t(Norm(x_t)) \tag{3} \end{aligned}

其中每一项都是同一量级的,那么有xt+1=O(t+1)x_{t+1}=O(t+1),也就是说第t+1t+1层跟第tt层的差别就相当于t+1t+1tt的差别,当tt较大时,两者的相对差别是很小的,因此

Ft(Norm(xt))+Ft+1(Norm(xt+1))Ft(Norm(xt))+Ft+1(Norm(xt))=(11)(FtFt+1)(Norm(xt))\begin{aligned} &F_t(Norm(x_t))+F_{t+1}(Norm(x_{t+1})) \\ &≈F_t(Norm(x_t))+F_{t+1}(Norm(x_t)) \\ &=\begin{pmatrix} 1 & 1 \end{pmatrix} \begin{pmatrix} F_t \\ F_{t+1} \end{pmatrix}(Norm(x_t)) \end{aligned}

这个意思是说,当tt比较大时,xt,xt+1x_t,x_{t+1}相差较小,所以Ft+1(Norm(xt+1))F_{t+1}(Norm(x_{t+1}))Ft+1(Norm(xt))F_{t+1}(Norm(x_t))很接近,因此原本一个tt层的模型与t+1t+1层和,近似等效于一个更宽的tt层模型,所以在Pre Norm中多层叠加的结果更多是增加宽度而不是深度,层数越多,这个层就越“虚”。

说白了,Pre Norm结构无形地增加了模型的宽度而降低了模型的深度,而我们知道深度通常比宽度更重要,所以是无形之中的降低深度导致最终效果变差了。而Post Norm刚刚相反,在《浅谈Transformer的初始化、参数化与标准化》中我们就分析过,它每Norm一次就削弱一次恒等分支的权重,所以Post Norm反而是更突出残差分支的,因此Post Norm中的层数更加“足秤”,一旦训练好之后效果更优。

相关工作

前段时间号称能训练1000层Transformer的DeepNet想必不少读者都听说过,在其论文《DeepNet: Scaling Transformers to 1,000 Layers》中对Pre Norm的描述是:

However, the gradients of Pre-LN at bottom layers tend to be larger than at top layers, leading to a degradation in performance compared with Post-LN.

不少读者当时可能并不理解这段话的逻辑关系,但看了前一节内容的解释后,想必会有新的理解。

简单来说,所谓“the gradients of Pre-LN at bottom layers tend to be larger than at top layers”,就是指Pre Norm结构会过度倾向于恒等分支(bottom layers),从而使得Pre Norm倾向于退化(degradation)为一个“浅而宽”的模型,最终不如同一深度的Post Norm。这跟前面的直观理解本质上是一致的。

文章小结

本文主要分享了“为什么Pre Norm的效果不如Post Norm”的一个直观理解。


© Chasing Powered by Hexo Theme Stellar 1.29.1
本博客所有文章除特别声明外,均采用 CC BY-NC-SA 4.0 许可协议,转载请注明出处。

本站由丘盾系统丘盾系统提供CDN加速

| 总访问量12204次 | 总访客数人次

发表了 8 篇文章 · 总计 5k 字