近幾年來,常用深度網(wǎng)絡(luò)的實現(xiàn),如多層感知機(MLP)、卷積神經(jīng)網(wǎng)絡(luò)(CNN)、循環(huán)神經(jīng)網(wǎng)絡(luò)(RNN)等的實現(xiàn)幾乎已經(jīng)形成了規(guī)范(如數(shù)據(jù)預(yù)處理、輸入輸出數(shù)據(jù)格式、代碼的設(shè)計模式等)。然而,較晚出現(xiàn)的圖神經(jīng)網(wǎng)絡(luò)卻還沒有形成一套規(guī)范體系。例如,Github上不同的GNN實現(xiàn)有著多種不同的數(shù)據(jù)結(jié)構(gòu)來存放輸入的圖。 雖然在圖卷積網(wǎng)絡(luò)(GCN)、圖注意力網(wǎng)絡(luò)(GAT)等許多圖神經(jīng)網(wǎng)絡(luò)的理論中,每一層圖神經(jīng)網(wǎng)絡(luò)就是節(jié)點與鄰節(jié)點特征的融合。直觀上說,用循環(huán)遍歷每個節(jié)點的鄰節(jié)點,按照一定的規(guī)律加權(quán)平均就可以實現(xiàn)這些網(wǎng)絡(luò)(如下圖所示)。然而實際上,這樣的實現(xiàn)方式與TensorFlow和PyTorch等深度學(xué)習(xí)框架并不兼容。由于要利用GPU的并行計算能力,這些深度學(xué)習(xí)框架需要我們將數(shù)據(jù)規(guī)整為整齊的矩陣,用矩陣運算而不是循環(huán)來實現(xiàn)深度網(wǎng)絡(luò)。 為了將圖神經(jīng)網(wǎng)絡(luò)的實現(xiàn)用矩陣運算形式實現(xiàn),不同的算法可能需要采用不同的設(shè)計模式。例如GCN通常使用稀疏矩陣來實現(xiàn),而GAT的一些版本由于需要使用Attention矩陣,稀疏矩陣在一些情況下就失效了。 為了解決這個問題,pytorch_geometric(https://github.com/rusty1s/pytorch_geometric)使用了一種基于邊的實現(xiàn)方法。該方法使用scatter操作實現(xiàn)了上述的“用循環(huán)遍歷每個節(jié)點的鄰節(jié)點,按照一定的規(guī)律加權(quán)平均”的操作。該實現(xiàn)依賴于pytorch_scatter(https://github.com/rusty1s/pytorch_scatter)。 用(i, j)表示一個邊,假設(shè)一個圖中有8條邊,我們用index表示i(起始點)的集合,用to_index表示j(目標點)的集合,用input表示to_index特征的集合,那么,一個簡化版GCN(沒有權(quán)重計算,以所有鄰節(jié)點的平均值為輸出;也沒有全連接層)的示意圖如下: 第一行index表示邊的起始點,第二行是目標點的特征(鄰節(jié)點的特征向量,這里簡化為標量)。在GCN過程中,我們其實是根據(jù)邊的起始點來聚合目標點的特征的(以起始點為核心,聚合與其相鄰的鄰節(jié)點的特征值),因此,我們對具有相同起始點(index)的特征(input)進行聚合(相加)即可完成上述操作。在pytorch_scatter中,上述操作可以用下面一行代碼實現(xiàn):
其中,src對應(yīng)input(鄰節(jié)點特征向量集合)。 除了加法,pytorch_scatter還集成了許多其它的聚合操作。因此,pytorch_geometric基于pytorch_scatter構(gòu)建了一個名為MessagePassing的類:
該類可以根據(jù)輸入的邊、特征和指定的聚合方式來對鄰節(jié)點進行聚合。因此,在pytorch_geometric中,GCN和GAT的實現(xiàn)都是一個繼承了MessagePassing的子類,分別實現(xiàn)了GCN和GAT的權(quán)重計算。這樣的實現(xiàn)大幅度簡化了GNN實現(xiàn)的門檻,使用者只要關(guān)注于權(quán)重的計算,而不需要干涉具體的與鄰節(jié)點融合的過程。 另外,由于框架的輸入的圖的邊,而不是鄰接矩陣,避免了大量的不存在的邊對網(wǎng)絡(luò)性能的干擾(內(nèi)存占用、計算效率)。例如經(jīng)典的GAT實現(xiàn)會讓非鄰節(jié)點參與計算,為其賦予一個非常小的權(quán)重來降低其對效果的干擾,這樣GAT的計算效率就會大大降低。 除了MessagePassing,pytorch_geometric還實現(xiàn)了使用其他機制的許多GNN。我們會在以后的文章中介紹。 參考鏈接:
|
|