四月,arXiv上出現(xiàn)了一篇題為《KAN: Kolmogorov-Arnold Networks》的論文。該論文獲得約5000個贊,對于一篇學術(shù)論文來說,可謂是相當火爆。隨附的GitHub庫已有7600多個星標,且數(shù)字還在持續(xù)增長。 Kolmogorov-Arnold 網(wǎng)絡(luò)(KAN)是一種全新的神經(jīng)網(wǎng)絡(luò)構(gòu)建塊。它比多層感知器(MLP)更具表達力、更不易過擬合且更易于解釋。多層感知器在深度學習模型中無處不在。例如,我們知道它們被用于GPT-2、3以及(可能的)4等模型的Transformer模塊之間。對MLP的改進將對機器學習世界產(chǎn)生廣泛的影響。 MLPMLP實際上是一種非常古老的架構(gòu),可以追溯到50年代。其設(shè)計初衷是模仿大腦結(jié)構(gòu);由許多互聯(lián)的神經(jīng)元組成,這些神經(jīng)元將信息向前傳遞,因此得名前饋網(wǎng)絡(luò)(feed-forward network)。 MLP通常通過類似上圖的示意圖來展示。對于外行來說,這很有用,但在我看來,它并沒有傳達出真正正在發(fā)生的事情的深刻理解。用數(shù)學來表示它要容易得多。 假設(shè)有一些輸入x和一些輸出y。一個兩層的MLP將如下所示: 其中W是可學習權(quán)重的矩陣,b是偏差向量。函數(shù)f是一個非線性函數(shù)。看到這些方程,很明顯,一個MLP是一系列帶有非線性間隔的線性回歸模型。這是一個非?;镜脑O(shè)置。 盡管基本,但它表達力極強。有數(shù)學保證,MLP是通用逼近器,即:它們可以逼近任何函數(shù),類似于所有函數(shù)都可以用泰勒級數(shù)來表示。 為了訓練模型的權(quán)重,我們使用了反向傳播(backpropagation),這要歸功于自動微分(autodiff)。我不會在這里深入討論,但重要的是要注意自動微分可以對任何可微函數(shù)起作用,這在后面會很重要。 MLP的問題 MLP在廣泛的用例中被使用,但存在一些嚴重的缺點。
Kolmogorov-Arnold 網(wǎng)絡(luò)Kolmogorov-Arnold 表示定理 Kolmogorov-Arnold 表示定理的目標類似于支撐MLP的通用逼近定理,但前提不同。它本質(zhì)上說,任何多變量函數(shù)都可以用1維非線性函數(shù)的加法來表示。例如:向量v=(x1, x2)的除法運算可以用對數(shù)和指數(shù)代替: 為什么這會有用呢?這究竟實現(xiàn)了什么? 這為我們提供了一種不同但簡單的范式來開始構(gòu)建神經(jīng)網(wǎng)絡(luò)架構(gòu)。作者聲稱,這種架構(gòu)比使用多層感知器(MLP)更易于解釋、更高效地使用參數(shù),并且具有更好的泛化能力。在MLP中,非線性函數(shù)是固定的,在訓練過程中從未改變。而在KAN中,不再有權(quán)重矩陣或偏差,只有適應(yīng)數(shù)據(jù)的一維非線性函數(shù)。然后將這些非線性函數(shù)相加。我們可以堆疊越來越多的層來創(chuàng)建更復雜的函數(shù)。 B樣條(B-splines) 在KAN中表示非線性的方式中有一點重要的是需要注意的。與MLP中明確定義的非線性函數(shù)(如ReLU()、Tanh()、silu()等)不同,KAN的作者使用樣條。這些基本上是分段多項式。它們源自計算機圖形領(lǐng)域,在該領(lǐng)域中,過度參數(shù)化并不是一個問題。 樣條解決了在多個點之間平滑插值的問題。如果你熟悉機器學習理論,你會知道要在n個數(shù)據(jù)點之間完美插值,需要一個n-1階的多項式。問題是高階多項式可能變得非常曲折,看起來不平滑。
通過將分段多項式函數(shù)適應(yīng)于數(shù)據(jù)點之間的部分,樣條解決了這個問題。這里我們使用三次樣條。
對于三次樣條(樣條的一種類型),為了確保平滑,需要在數(shù)據(jù)點(或結(jié)點)的位置對一階和二階導數(shù)設(shè)置約束。數(shù)據(jù)點兩側(cè)的曲線必須在數(shù)據(jù)點處具有匹配的一階導數(shù)和二階導數(shù)。 KAN使用的是B樣條,另一種類型的樣條,具有局部性(移動一個點不會影響曲線的整體形狀)和匹配的二階導數(shù)(也稱為C2連續(xù)性)的特性。這樣做的代價是實際上不會通過這些點(除了在極端情況下)。
在機器學習中,特別是在應(yīng)用于物理學時,不經(jīng)過每一個數(shù)據(jù)點是可以接受的,因為我們預(yù)計測量會有噪聲。 這就是在KAN的計算圖的每一個邊緣發(fā)生的事情。一維數(shù)據(jù)用一組B樣條進行擬合。 進入KAN因此,現(xiàn)在我們在計算圖的每個邊緣都有一個分段的參數(shù)曲線。在每個節(jié)點,這些曲線被求和:我們之前看到,可以通過這種方式逼近任何函數(shù)。 為了訓練這樣的模型,我們可以使用標準的反向傳播。在這種情況下,作者使用的是LBFGS(Limited-memory Broyden-Fletcher-Goldfarb-Shanno),這是一種二階優(yōu)化方法(與Adam這種一階方法相比)。另一個需要注意的細節(jié)是:在每個代表一維函數(shù)的邊上,有一個B樣條,但作者還增加了一個非線性函數(shù):silu函數(shù)。 對此的解釋不是很清楚,但很可能是由于梯度消失(這是我的猜測)。 我們來試用一下我打算使用作者提供的代碼,它運行得非常出色,有許多示例可以幫助我們更好地理解它。 他們使用由以下函數(shù)生成的合成數(shù)據(jù): 定義模型
這里定義了三個參數(shù):
訓練
該庫的API非常直觀,我們可以看到我們正在使用LBFGS優(yōu)化器,訓練20步。接下來的兩個參數(shù)與網(wǎng)絡(luò)的正則化相關(guān)。 訓練后的下一步是修剪模型,這會移除低于相關(guān)性閾值的邊和節(jié)點,完成后建議重新訓練一下。然后將每個樣條邊轉(zhuǎn)換為符號函數(shù)(log、exp、sin等)。這可以手動或自動完成。庫提供了一個極好的工具,借助model.plot()方法可以看到模型內(nèi)部的情況。
一旦在每個邊上設(shè)置了符號函數(shù),就會進行最終的再訓練,以確保每個邊的仿射參數(shù)是合理的。 整個訓練過程在下面的圖表中總結(jié)。
完整的訓練代碼如下所示:
一些思考模型中有相當多的超參數(shù)可以調(diào)整。這些可以產(chǎn)生非常不同的結(jié)果。例如,在上面的示例中:將隱藏神經(jīng)元的數(shù)量從5改為6意味著KAN找不到正確的函數(shù)。
這種變化性是預(yù)期的,因為這種架構(gòu)是全新的?;藥资陼r間,人們才找到了調(diào)整MLP超參數(shù)(如學習率、批大小、初始化等)的最佳方式。 結(jié)論MLP已經(jīng)存在很長時間了,早該升級了。我們知道這種改變是可能的,大約6年前,LSTMs在序列建模中無處不在,后來被transformers作為標準的語言模型架構(gòu)構(gòu)建塊所取代。如果MLP也能發(fā)生這種變化,那將是令人興奮的。另一方面,這種架構(gòu)仍然不穩(wěn)定,而且運行效果并不是非常出色。時間將告訴我們,否能找到一種方法來繞過這種不穩(wěn)定性并釋放KAN的真正潛力,或者KAN是否會被遺忘,成為機器學習的一個小知識點。 我對這種新架構(gòu)感到非常興奮,但我也持懷疑態(tài)度。 |
|