擴散模型以其令人印象深刻的生成高質(zhì)量圖像的能力而聞名,它們是流行的文本到圖像模型(例如DALL-E、Stable Diffusion和Midjourney)中使用的主要架構(gòu)。
然而,擴散模型不只是用于生成圖像。Meta公司、普林斯頓大學和德克薩斯大學奧斯汀分校的研究人員最近聯(lián)合發(fā)表的一篇研究報告表明,擴散模型可以幫助創(chuàng)建更好的強化學習系統(tǒng)。
該報告引入了一種使用基于擴散的世界模型來訓練強化學習代理的技術(shù)。擴散世界模型(DWM)通過預測未來多個步驟的環(huán)境,增強了當前基于模型的強化學習系統(tǒng)。
無模型的強化學習vs基于模型的強化學習
無模型的強化學習算法直接從與環(huán)境的交互中學習策略或價值函數(shù),而無需預測未來環(huán)境。與其相反,基于模型的強化學習算法通過世界模型來模擬它們的環(huán)境。這些模型使他們能夠預測他們的行為將如何影響他們的環(huán)境,并相應地調(diào)整政策。
基于模型的強化學習的一個關(guān)鍵優(yōu)勢是它需要更少的來自真實環(huán)境的數(shù)據(jù)樣本。這對于自動駕駛汽車和機器人等應用尤其有用。在這些應用中,從現(xiàn)實世界收集數(shù)據(jù)可能成本高昂或者存在風險。
然而,基于模型的強化學習高度依賴于世界模型的準確性。在實踐中,世界模型中的不準確性導致基于模型的強化學習系統(tǒng)比無模型的強化學習表現(xiàn)得更差。
傳統(tǒng)的世界模型使用單步動態(tài)(one-step dynamics)模式,這意味著它們只能根據(jù)當前狀態(tài)和動作預測獎勵和下一個狀態(tài)。當規(guī)劃未來的多個步驟時,強化學習系統(tǒng)使用自己的輸出遞歸地調(diào)用模型。這種方法帶來的問題是,小誤差可能在多個步驟中疊加,使長期預測變得不可靠和不準確。
擴散世界模型(DWM)的前提是學會一次預測未來的多個步驟。如果做得正確,這種方法可以減少長期預測中的錯誤,并提高基于模型的強化學習算法的性能。
擴散世界模型的工作原理
擴散世界模型的工作原理很簡單:它們通過反轉(zhuǎn)一個逐漸向數(shù)據(jù)添加噪聲的過程來學習生成數(shù)據(jù)。例如,當訓練生成圖像時,擴散世界模型會逐漸向圖像添加噪聲層,然后嘗試反轉(zhuǎn)過程并預測原始圖像。通過重復這個過程并添加更多的噪聲層,它學會了從純噪聲中生成高質(zhì)量的圖像。條件擴散模型通過將模型的輸出條件轉(zhuǎn)化為特定輸入(例如圖像附帶的字幕)來添加一層控制。這使開發(fā)人員能夠為這些模型提供文本描述并接收相應的圖像。
但是,雖然擴散模型以其生成高質(zhì)量圖像的能力而聞名,但它們也可以應用于其他數(shù)據(jù)類型。
擴散世界模型(DWM)使用相同的原理來預測強化學習系統(tǒng)的長期結(jié)果。擴散世界模型(DWM)以當前狀態(tài)、操作和預期回報為條件,而不是文本描述。它的輸出是多個步驟的狀態(tài)和對未來的獎勵。
擴散世界模型(DWM)框架有兩個訓練階段。在第一階段,擴散模型在從環(huán)境中收集的一系列軌跡上進行訓練。它從一個強大的世界模型中學習,可以一次預測多個步驟,使其在長期模擬中比其他基于模型的方法更穩(wěn)定。
在第二階段,使用Actor-Critic 算法和擴散世界模型訓練離線強化學習策略。使用離線強化學習消除了訓練過程中在線交互的需求,從而提高了速度,降低了成本和風險。
對于每個步驟,代理使用擴散世界模型(DWM)來生成未來的軌跡,并模擬其動作的回報。研究人員稱之為“擴散模型價值擴展”(Diffusion MVE)。雖然強化學習系統(tǒng)在訓練期間使用擴散世界模型(DWM),但生成的策略是無模型的,這具有更快推理的好處。
研究人員寫道:“擴散模型價值擴展(Diffusion MVE)可以解釋為通過生成建模對離線強化學習進行的值正則化,或者可以解釋為使用合成數(shù)據(jù)進行離線Q學習的一種方法。”
在更高的層面,擴散世界模型(DWM)背后的主要思想是預測未來世界的多個狀態(tài)。因此,可以用另一個序列模型替換擴散模型。研究人員也對Transformer模型進行了實驗,但發(fā)現(xiàn)擴散世界模型(DWM)更有效。
運行擴散世界模型(DWM)
為了測試擴散世界模型(DWM)的有效性,研究人員將其與基于模型的強化學習系統(tǒng)和無模型的強化學習系統(tǒng)進行了比較。他們從D4RL數(shù)據(jù)集中試驗了三種不同的算法和九種運動任務。
結(jié)果表明,擴散世界模型(DWM)比單步世界模型顯著提高了44%的性能。當單步世界模型應用于無模型強化學習算法時,它通常會降低性能。然而,研究人員發(fā)現(xiàn),當與擴散世界模型(DWM)結(jié)合使用時,無模型強化系統(tǒng)的表現(xiàn)優(yōu)于原始版本。
研究人員寫道:“這要歸功于擴散模型的強大表現(xiàn)力和對整個序列的一次性預測,這規(guī)避了傳統(tǒng)的單步動態(tài)模型在多個步驟推出時的復合誤差問題。我們的方法實現(xiàn)了最先進的(SOTA)性能,消除了基于模型算法和無模型算法之間的差距。”
擴散世界模型(DWM)是在非生成任務中使用生成模型的更廣泛趨勢的一部分。在過去的一年,由于生成式人工智能模型的進步,機器人研究取得了飛躍式的進展。大型語言模型正在幫助彌合自然語言命令和機器人運動命令之間的差距。Transformers還幫助研究人員將從不同形態(tài)和設(shè)置中收集的數(shù)據(jù)整合在一起,并訓練可以推廣到不同機器人和任務的模型。