近幾年來(lái),常用深度網(wǎng)絡(luò)的實(shí)現(xiàn),如多層感知機(jī)(MLP)、卷積神經(jīng)網(wǎng)絡(luò)(CNN)、循環(huán)神經(jīng)網(wǎng)絡(luò)(RNN)等的實(shí)現(xiàn)幾乎已經(jīng)形成了規(guī)范(如數(shù)據(jù)預(yù)處理、輸入輸出數(shù)據(jù)格式、代碼的設(shè)計(jì)模式等)。然而,較晚出現(xiàn)的圖神經(jīng)網(wǎng)絡(luò)卻還沒(méi)有形成一套規(guī)范體系。例如,Github上不同的GNN實(shí)現(xiàn)有著多種不同的數(shù)據(jù)結(jié)構(gòu)來(lái)存放輸入的圖。 雖然在圖卷積網(wǎng)絡(luò)(GCN)、圖注意力網(wǎng)絡(luò)(GAT)等許多圖神經(jīng)網(wǎng)絡(luò)的理論中,每一層圖神經(jīng)網(wǎng)絡(luò)就是節(jié)點(diǎn)與鄰節(jié)點(diǎn)特征的融合。直觀上說(shuō),用循環(huán)遍歷每個(gè)節(jié)點(diǎn)的鄰節(jié)點(diǎn),按照一定的規(guī)律加權(quán)平均就可以實(shí)現(xiàn)這些網(wǎng)絡(luò)(如下圖所示)。然而實(shí)際上,這樣的實(shí)現(xiàn)方式與TensorFlow和PyTorch等深度學(xué)習(xí)框架并不兼容。由于要利用GPU的并行計(jì)算能力,這些深度學(xué)習(xí)框架需要我們將數(shù)據(jù)規(guī)整為整齊的矩陣,用矩陣運(yùn)算而不是循環(huán)來(lái)實(shí)現(xiàn)深度網(wǎng)絡(luò)。 為了將圖神經(jīng)網(wǎng)絡(luò)的實(shí)現(xiàn)用矩陣運(yùn)算形式實(shí)現(xiàn),不同的算法可能需要采用不同的設(shè)計(jì)模式。例如GCN通常使用稀疏矩陣來(lái)實(shí)現(xiàn),而GAT的一些版本由于需要使用Attention矩陣,稀疏矩陣在一些情況下就失效了。 為了解決這個(gè)問(wèn)題,pytorch_geometric(https://github.com/rusty1s/pytorch_geometric)使用了一種基于邊的實(shí)現(xiàn)方法。該方法使用scatter操作實(shí)現(xiàn)了上述的“用循環(huán)遍歷每個(gè)節(jié)點(diǎn)的鄰節(jié)點(diǎn),按照一定的規(guī)律加權(quán)平均”的操作。該實(shí)現(xiàn)依賴于pytorch_scatter(https://github.com/rusty1s/pytorch_scatter)。 用(i, j)表示一個(gè)邊,假設(shè)一個(gè)圖中有8條邊,我們用index表示i(起始點(diǎn))的集合,用to_index表示j(目標(biāo)點(diǎn))的集合,用input表示to_index特征的集合,那么,一個(gè)簡(jiǎn)化版GCN(沒(méi)有權(quán)重計(jì)算,以所有鄰節(jié)點(diǎn)的平均值為輸出;也沒(méi)有全連接層)的示意圖如下: 第一行index表示邊的起始點(diǎn),第二行是目標(biāo)點(diǎn)的特征(鄰節(jié)點(diǎn)的特征向量,這里簡(jiǎn)化為標(biāo)量)。在GCN過(guò)程中,我們其實(shí)是根據(jù)邊的起始點(diǎn)來(lái)聚合目標(biāo)點(diǎn)的特征的(以起始點(diǎn)為核心,聚合與其相鄰的鄰節(jié)點(diǎn)的特征值),因此,我們對(duì)具有相同起始點(diǎn)(index)的特征(input)進(jìn)行聚合(相加)即可完成上述操作。在pytorch_scatter中,上述操作可以用下面一行代碼實(shí)現(xiàn):
其中,src對(duì)應(yīng)input(鄰節(jié)點(diǎn)特征向量集合)。 除了加法,pytorch_scatter還集成了許多其它的聚合操作。因此,pytorch_geometric基于pytorch_scatter構(gòu)建了一個(gè)名為MessagePassing的類:
該類可以根據(jù)輸入的邊、特征和指定的聚合方式來(lái)對(duì)鄰節(jié)點(diǎn)進(jìn)行聚合。因此,在pytorch_geometric中,GCN和GAT的實(shí)現(xiàn)都是一個(gè)繼承了MessagePassing的子類,分別實(shí)現(xiàn)了GCN和GAT的權(quán)重計(jì)算。這樣的實(shí)現(xiàn)大幅度簡(jiǎn)化了GNN實(shí)現(xiàn)的門檻,使用者只要關(guān)注于權(quán)重的計(jì)算,而不需要干涉具體的與鄰節(jié)點(diǎn)融合的過(guò)程。 另外,由于框架的輸入的圖的邊,而不是鄰接矩陣,避免了大量的不存在的邊對(duì)網(wǎng)絡(luò)性能的干擾(內(nèi)存占用、計(jì)算效率)。例如經(jīng)典的GAT實(shí)現(xiàn)會(huì)讓非鄰節(jié)點(diǎn)參與計(jì)算,為其賦予一個(gè)非常小的權(quán)重來(lái)降低其對(duì)效果的干擾,這樣GAT的計(jì)算效率就會(huì)大大降低。 除了MessagePassing,pytorch_geometric還實(shí)現(xiàn)了使用其他機(jī)制的許多GNN。我們會(huì)在以后的文章中介紹。 參考鏈接:
|
|
來(lái)自: LibraryPKU > 《機(jī)器學(xué)習(xí)》