2021年,研究人員在訓(xùn)練一系列微型模型時(shí)取得了一個(gè)驚人的發(fā)現(xiàn),即模型經(jīng)過(guò)長(zhǎng)時(shí)間的訓(xùn)練后,會(huì)有一個(gè)變化,從開(kāi)始只會(huì)「記憶訓(xùn)練數(shù)據(jù)」,轉(zhuǎn)變?yōu)閷?duì)沒(méi)見(jiàn)過(guò)的數(shù)據(jù)也表現(xiàn)出很強(qiáng)的泛化能力。
這種現(xiàn)象被稱為「領(lǐng)悟(grokking)」,如下圖所示,模型在長(zhǎng)時(shí)間擬合訓(xùn)練數(shù)據(jù)后,「領(lǐng)悟」現(xiàn)象會(huì)突然出現(xiàn)。
既然微型模型有這種特性,那么更復(fù)雜一點(diǎn)的模型在經(jīng)過(guò)更長(zhǎng)時(shí)間的訓(xùn)練后,是否也會(huì)突然出現(xiàn)「領(lǐng)悟」現(xiàn)象?最近大型語(yǔ)言模型(LLM)發(fā)展迅猛,它們看起來(lái)對(duì)世界有著豐富的理解力,很多人認(rèn)為L(zhǎng)LM只是在重復(fù)所記憶的訓(xùn)練內(nèi)容,這一說(shuō)法正確性如何,我們?cè)撊绾闻袛郘LM是輸出記憶內(nèi)容,還是對(duì)輸入數(shù)據(jù)進(jìn)行了很好的泛化?
為了更好的了解這一問(wèn)題,本文來(lái)自谷歌的研究者撰寫(xiě)了一篇博客,試圖弄清楚大模型突然出現(xiàn)「領(lǐng)悟」現(xiàn)象的真正原因。
本文先從微型模型的訓(xùn)練動(dòng)態(tài)開(kāi)始,他們?cè)O(shè)計(jì)了一個(gè)具有24個(gè)神經(jīng)元的單層MLP,訓(xùn)練它們學(xué)會(huì)做模加法(modular addition)任務(wù),我們只需知道這個(gè)任務(wù)的輸出是周期性的,其形式為(a+b)mod n。
MLP模型權(quán)重如下圖所示,研究發(fā)現(xiàn)模型的權(quán)重最初非常嘈雜,但隨著時(shí)間的增加,開(kāi)始表現(xiàn)出周期性。
如果將單個(gè)神經(jīng)元的權(quán)重可視化,這種周期性變化更加明顯:
別小看周期性,權(quán)重的周期性表明該模型正在學(xué)習(xí)某種數(shù)學(xué)結(jié)構(gòu),這也是模型從記憶數(shù)據(jù)轉(zhuǎn)變?yōu)榫哂蟹夯芰Φ年P(guān)鍵。很多人對(duì)這一轉(zhuǎn)變感到迷惑,為什么模型會(huì)從記憶數(shù)據(jù)模式轉(zhuǎn)變?yōu)榉夯瘮?shù)據(jù)模式。
用01序列進(jìn)行實(shí)驗(yàn)
為了判斷模型是在泛化還是記憶,該研究訓(xùn)練模型預(yù)測(cè)30個(gè)1和0隨機(jī)序列的前三位數(shù)字中是否有奇數(shù)個(gè)1。例如000110010110001010111001001011為0,而010110010110001010111001001011為1。這基本就是一個(gè)稍微棘手的XOR運(yùn)算問(wèn)題,帶有一些干擾噪聲。如果模型在泛化,那么應(yīng)該只使用序列的前三位數(shù)字;而如果模型正在記憶訓(xùn)練數(shù)據(jù),那么它還會(huì)使用后續(xù)數(shù)字。
該研究使用的模型是一個(gè)單層MLP,在1200個(gè)序列的固定批上進(jìn)行訓(xùn)練。起初,只有訓(xùn)練準(zhǔn)確率有所提高,即模型會(huì)記住訓(xùn)練數(shù)據(jù)。與模運(yùn)算一樣,測(cè)試準(zhǔn)確率本質(zhì)上是隨機(jī)的,隨著模型學(xué)會(huì)通用解決方案而急劇上升。
通過(guò)01序列問(wèn)題這個(gè)簡(jiǎn)單的示例,我們可以更容易地理解為什么會(huì)發(fā)生這種情況。原因就是模型在訓(xùn)練期間會(huì)做兩件事:最小化損失和權(quán)重衰減。在模型泛化之前,訓(xùn)練損失實(shí)際上會(huì)略有增加,因?yàn)樗粨Q了與輸出正確標(biāo)簽相關(guān)的損失,以獲得較低的權(quán)重。
測(cè)試損失的急劇下降使得模型看起來(lái)像是突然泛化,但如果查看模型在訓(xùn)練過(guò)程中的權(quán)重,大多數(shù)模型都會(huì)在兩個(gè)解之間平滑地插值。當(dāng)與后續(xù)分散注意力的數(shù)字相連的最后一個(gè)權(quán)重通過(guò)權(quán)重衰減被修剪時(shí),快速泛化就會(huì)發(fā)生。
「領(lǐng)悟」現(xiàn)象是什么時(shí)候發(fā)生的?
值得注意的是,「領(lǐng)悟(grokking)」是一種偶然現(xiàn)象——如果模型大小、權(quán)重衰減、數(shù)據(jù)大小和其他超參數(shù)不合適,「領(lǐng)悟」現(xiàn)象就會(huì)消失。如果權(quán)重衰減太少,模型就會(huì)對(duì)訓(xùn)練數(shù)據(jù)過(guò)渡擬合。如果權(quán)重衰減過(guò)多,模型將無(wú)法學(xué)到任何東西。
下面,該研究使用不同的超參數(shù)針對(duì)1和0任務(wù)訓(xùn)練了1000多個(gè)模型。訓(xùn)練過(guò)程充滿噪音,因此針對(duì)每組超參數(shù)訓(xùn)練了九個(gè)模型。表明只有兩類(lèi)模型出現(xiàn)「領(lǐng)悟」現(xiàn)象,藍(lán)色和黃色。
具有五個(gè)神經(jīng)元的模塊化加法
模加法a+b mod 67是周期性的,如果總和超過(guò)67,則答案會(huì)產(chǎn)生環(huán)繞現(xiàn)象,可以用一個(gè)圓來(lái)表示。為了簡(jiǎn)化問(wèn)題,該研究構(gòu)建了一個(gè)嵌入矩陣,使用cos?和sin?將a和b放置在圓上,表示為如下形式。
結(jié)果表明,模型僅用5個(gè)神經(jīng)元就可以完美準(zhǔn)確地找到解決方案:
觀察經(jīng)過(guò)訓(xùn)練的參數(shù),研究團(tuán)隊(duì)發(fā)現(xiàn)所有神經(jīng)元都收斂到大致相等的范數(shù)。如果直接繪制它們的cos?和sin?分量,它們基本上均勻分布在一個(gè)圓上。
接下來(lái)是,它是從頭開(kāi)始訓(xùn)練的,沒(méi)有內(nèi)置周期性,這個(gè)模型有很多不同的頻率。
該研究使用離散傅立葉變換(DFT)分離出頻率。就像在1和0任務(wù)中一樣,只有幾個(gè)權(quán)重起到關(guān)鍵作用:
下圖表明,在不同的頻率,模型也能實(shí)現(xiàn)「領(lǐng)悟」:
開(kāi)放問(wèn)題
現(xiàn)在,雖然我們對(duì)單層MLP解決模加法的機(jī)制及其在訓(xùn)練過(guò)程中出現(xiàn)的原因有了扎實(shí)的了解,但在記憶和泛化方面仍有許多有趣的開(kāi)放性問(wèn)題。
哪種模型的約束效果更好呢?
從廣義上講,權(quán)重衰減的確可以引導(dǎo)各種模型避免記憶訓(xùn)練數(shù)據(jù)。其他有助于避免過(guò)擬合的技術(shù)包括dropout、縮小模型,甚至數(shù)值不穩(wěn)定的優(yōu)化算法。這些方法以復(fù)雜的非線性方式相互作用,因此很難先驗(yàn)地預(yù)測(cè)哪種方法最終會(huì)誘導(dǎo)泛化。
此外,不同的超參數(shù)也會(huì)使改進(jìn)不那么突然。
為什么記憶比泛化更容易?
有一種理論認(rèn)為:記憶訓(xùn)練集的方法可能比泛化解法多得多。因此,從統(tǒng)計(jì)學(xué)上講,記憶應(yīng)該更有可能首先發(fā)生,尤其是在沒(méi)有正則化或正則化很少的情況中。正則化技術(shù)(如權(quán)重衰減)會(huì)優(yōu)先考慮某些解決方案,例如,優(yōu)先考慮「稀疏」解決方案,而不是「密集」解決方案。
研究表明,泛化與結(jié)構(gòu)良好的表征有關(guān)。然而,這不是必要條件;在求解模加法時(shí),一些沒(méi)有對(duì)稱輸入的MLP變體學(xué)習(xí)到的「循環(huán)」表征較少。研究團(tuán)隊(duì)還發(fā)現(xiàn),結(jié)構(gòu)良好的表征并不是泛化的充分條件。這個(gè)小模型(訓(xùn)練時(shí)沒(méi)有權(quán)重衰減)開(kāi)始泛化,然后轉(zhuǎn)為使用周期性嵌入的記憶。
在下圖中可以看到,如果沒(méi)有權(quán)重衰減,記憶模型可以學(xué)習(xí)更大的權(quán)重來(lái)減少損失。
甚至可以找到模型開(kāi)始泛化的超參數(shù),然后切換到記憶,然后切換回泛化。
較大的模型呢?
理解模加法的解決方案并非易事。我們有希望理解更大的模型嗎?在這條路上可能需要:
訓(xùn)練更簡(jiǎn)單的模型,具有更多的歸納偏差和更少的運(yùn)動(dòng)部件。
使用它們來(lái)解釋更大模型如何工作的費(fèi)解部分。
按需重復(fù)。
研究團(tuán)隊(duì)相信,這可能是一種更好地有效理解大型模型的的方法,此外,隨著時(shí)間的推移,這種機(jī)制化的可解釋性方法可能有助于識(shí)別模式,從而使神經(jīng)網(wǎng)絡(luò)所學(xué)算法的揭示變得容易甚至自動(dòng)化。