[Graph] 4์ฅ. Graph Neural Networks: Algorithms
์๋ณธ ๊ฒ์๊ธ: https://velog.io/@euisuk-chung/Graph-Neural-Networks-Algorithms
-
Introduction
๊ทธ๋ํ ๊ตฌ์กฐ ๋ฐ์ดํฐ๋ ๋ณต์กํ ๊ด๊ณ์ ์ํธ์์ฉ์ ๋ชจ๋ธ๋งํ๋ ๋ฐ ๋งค์ฐ ์ ์ฉํฉ๋๋ค. ์ด๋ฌํ ๋ฐ์ดํฐ๋ฅผ ํจ๊ณผ์ ์ผ๋ก ๋ถ์ํ๊ณ ํ์ตํ๊ธฐ ์ํด ๊ทธ๋ํ ์ ๊ฒฝ๋ง(Graph Neural Networks, GNN)๊ณผ ๊ทธ๋ํ ์๋ฒ ๋ฉ(Graph Embedding) ๊ธฐ๋ฒ์ด ๊ฐ๋ฐ๋์์ต๋๋ค.
1.1 ๊ทธ๋ํ ์ ๊ฒฝ๋ง ๋ชจ๋ธ vs ๊ทธ๋ํ ์๋ฒ ๋ฉ
- ๊ทธ๋ํ ์ ๊ฒฝ๋ง ๋ชจ๋ธ: ๊ทธ๋ํ์ ๊ตฌ์กฐ์ ๋ ธ๋์ ํน์ฑ์ ๋์์ ๊ณ ๋ คํ์ฌ ํ์ตํ๋ ์ ๊ฒฝ๋ง ๋ชจ๋ธ์ ๋๋ค. ๋ํ์ ์ผ๋ก GCN, GRN, GAT๊ฐ ์์ต๋๋ค.
- ๊ทธ๋ํ ์๋ฒ ๋ฉ: ๊ทธ๋ํ์ ๊ตฌ์กฐ์ ์ ๋ณด๋ฅผ ์ ์ฐจ์ ๋ฒกํฐ ๊ณต๊ฐ์ ๋งคํํ๋ ๊ธฐ๋ฒ์ ๋๋ค. DeepWalk, Node2Vec, GraphSAGE ๋ฑ์ด ์์ต๋๋ค.
1.2 ์ฃผ์ ์ฐจ์ด์
- ํ์ต ๋ฐฉ์: GNN์ ๋ค์ํ ํ์ต ๋ฐฉ์์ ์ฌ์ฉํ๋ฉฐ, ๊ทธ๋ํ ์๋ฒ ๋ฉ์ ์ฃผ๋ก ๋น์ง๋ ํ์ต์ ์ฌ์ฉํฉ๋๋ค.
- ํน์ฑ ์ ๋ณด ํ์ฉ: GNN์ ๋ ธ๋์ ํน์ฑ ์ ๋ณด๋ฅผ ์ง์ ํ์ฉํ์ง๋ง, ๊ทธ๋ํ ์๋ฒ ๋ฉ์ ์ฃผ๋ก ๊ตฌ์กฐ์ ์ ๋ณด๋ง ์ฌ์ฉํฉ๋๋ค.
- ๋ชจ๋ธ ๋ณต์ก์ฑ: GNN์ ๋ ๋ณต์กํ ๊ตฌ์กฐ๋ฅผ ๊ฐ์ง๋ฉฐ, ๊ทธ๋ํ ์๋ฒ ๋ฉ์ ์๋์ ์ผ๋ก ๋จ์ํฉ๋๋ค.
- ๊ท๋ฉ์ ํ์ต: GNN์ ๊ท๋ฉ์ ํ์ต์ด ๊ฐ๋ฅํ์ง๋ง, ๋๋ถ๋ถ์ ๊ทธ๋ํ ์๋ฒ ๋ฉ์ ๋ณํ์ ๋ฐฉ์์ ์ฌ์ฉํฉ๋๋ค.
- ๋์ ๊ทธ๋ํ ์ฒ๋ฆฌ: GNN์ ๋์ ๊ทธ๋ํ ์ฒ๋ฆฌ์ ๋ ์ ํฉํ ๋ชจ๋ธ์ด ์์ต๋๋ค.
-
ํํ๋ ฅ: GNN์ ์ง์ญ์ ๊ตฌ์กฐ์ ์ ์ญ์ ๊ตฌ์กฐ๋ฅผ ๋ชจ๋ ํฌ์ฐฉํ ์ ์์ง๋ง, ๊ทธ๋ํ ์๋ฒ ๋ฉ์ ์ฃผ๋ก ์ง์ญ์ ๊ตฌ์กฐ์ ์ด์ ์ ๋ง์ถฅ๋๋ค.
-
Graph Convolutional Networks (GCNs)
GCN์ CNN์ ๊ฐ๋ ์ ๊ทธ๋ํ ๋ฐ์ดํฐ์ ํ์ฅํ ๋ชจ๋ธ๋ก, ๋ ธ๋์ ํน์ฑ๊ณผ ๊ทธ๋ํ ๊ตฌ์กฐ๋ฅผ ๋์์ ๊ณ ๋ คํฉ๋๋ค. GCN์ ์ฃผ๋ก ์คํํธ๋ด ๊ธฐ๋ฐ๊ณผ ๊ณต๊ฐ ๊ธฐ๋ฐ ๋ฐฉ๋ฒ์ผ๋ก ๋๋ฉ๋๋ค.
2.1 ์คํํธ๋ด ๊ธฐ๋ฐ ๋ฐฉ๋ฒ
์คํํธ๋ด ๊ธฐ๋ฐ GCN์ ๊ทธ๋ํ ๋ผํ๋ผ์์(Laplacian) ํ๋ ฌ์ ๊ณ ์ ๋ฒกํฐ๋ฅผ ์ฌ์ฉํ์ฌ ๊ทธ๋ํ์ ์ฃผํ์ ๋๋ฉ์ธ์์ ํฉ์ฑ๊ณฑ ์ฐ์ฐ์ ์ํํฉ๋๋ค. ๊ธฐ๋ณธ ์์ด๋์ด๋ ๊ทธ๋ํ ์ ํธ๋ฅผ ์ฃผํ์ ๋๋ฉ์ธ์์ ๋ณํํ๊ณ , ํํฐ๋งํ ํ ๋ค์ ๊ณต๊ฐ ๋๋ฉ์ธ์ผ๋ก ๋ณํํ๋ ๊ฒ์ ๋๋ค.
- ๊ทธ๋ํ ๋ผํ๋ผ์์ ํ๋ ฌ: L=DโAL = D - AL=DโA (์ฌ๊ธฐ์ DDD๋ ์ฐจ์ ํ๋ ฌ, AAA๋ ์ธ์ ํ๋ ฌ)
- ์คํํธ๋ด ํฉ์ฑ๊ณฑ: gฮธโx=Ugฮธ(ฮ)UTxg_{\theta} \star x = U g_{\theta}(\Lambda) U^T xgฮธโโx=Ugฮธโ(ฮ)UTx
- UUU: ๋ผํ๋ผ์์์ ๊ณ ์ ๋ฒกํฐ
- ฮ\Lambdaฮ: ๋ผํ๋ผ์์์ ๊ณ ์ ๊ฐ
2.2 ๊ณต๊ฐ ๊ธฐ๋ฐ ๋ฐฉ๋ฒ
๊ณต๊ฐ ๊ธฐ๋ฐ GCN์ ์ง์ ์ด์ ๋ ธ๋๋ค์ ํน์ฑ์ ์ฌ์ฉํ์ฌ ํฉ์ฑ๊ณฑ ์ฐ์ฐ์ ์ํํฉ๋๋ค. Kipf์ Welling์ GCN์ด ๋ํ์ ์ ๋๋ค. ์ด ๋ชจ๋ธ์ ์ธ์ ํ๋ ฌ์ ์ฌ์ฉํ์ฌ ์ด์ ๋ ธ๋์ ์ ๋ณด๋ฅผ ์ง๊ณํ๊ณ , ์ด๋ฅผ ํตํด ๊ฐ ๋ ธ๋์ ํน์ฑ์ ์ ๋ฐ์ดํธํฉ๋๋ค.
- Kipf์ Welling์ GCN:
- ๋ ธ๋ ํน์ฑ ํ๋ ฌ XXX์ ์ธ์ ํ๋ ฌ AAA ์ฌ์ฉ
- ์ ๊ทํ ์ธ์ ํ๋ ฌ A^=Dโ12ADโ12\hat{A} = D^{-\frac{1}{2}} A D^{-\frac{1}{2}}A^=Dโ21โADโ21โ
- ์ ๋ฐ์ดํธ ์: H(l+1)=ฯ(A^H(l)W(l))H^{(l+1)} = \sigma(\hat{A} H^{(l)} W^{(l)})H(l+1)=ฯ(A^H(l)W(l))
2.3 GCN ๊ตฌํ ์์ (PyTorch)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
class GCNLayer(nn.Module):
def __init__(self, in_features, out_features):
super(GCNLayer, self).__init__()
self.linear = nn.Linear(in_features, out_features)
def forward(self, X, A):
X = torch.mm(A, X)
X = self.linear(X)
return F.relu(X)
class GCN(nn.Module):
def __init__(self, in_features, hidden_features, out_features):
super(GCN, self).__init__()
self.layer1 = GCNLayer(in_features, hidden_features)
self.layer2 = GCNLayer(hidden_features, out_features)
def forward(self, X, A):
X = self.layer1(X, A)
X = self.layer2(X, A)
return X
# ๋ฐ์ดํฐ ์ค๋น
X = torch.randn(5, 10) # 5๊ฐ์ ๋
ธ๋, 10์ฐจ์ ํน์ฑ
A = torch.eye(5) + torch.rand(5, 5) # ๊ฐ๋จํ ์ธ์ ํ๋ ฌ ์์
# ๋ชจ๋ธ ์ด๊ธฐํ ๋ฐ ํ์ต
model = GCN(10, 16, 2)
optimizer = optim.Adam(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()
# Forward pass
output = model(X, A)
-
Graph Recurrent Networks (GRNs)
Graph Recurrent Networks (GRNs)์ RNN์ ๊ฐ๋ ์ ๊ทธ๋ํ์ ์ ์ฉํ ๋ชจ๋ธ๋ก, ์ฃผ๋ก ์๊ฐ์ ์ธ ์์๋ ๊ณ์ธต์ ๊ตฌ์กฐ๋ฅผ ๊ฐ์ง๋ ๊ทธ๋ํ ๋ฐ์ดํฐ์ ์ ์ฉํฉ๋๋ค.
3.1 ๊ฒ์ดํธ๋ ๊ทธ๋ํ ์ ๊ฒฝ๋ง (Gated Graph Neural Networks, GGNN)
GGNN์ GRU(Gated Recurrent Unit)๋ฅผ ๊ทธ๋ํ์ ์ ์ฉํ์ฌ ์ํ์ค ์ถ๋ ฅ์ด ํ์ํ ๋ฌธ์ ์ ์ ํฉํฉ๋๋ค.
- GGNN ๊ตฌ์กฐ: ๊ฐ ๋ ธ๋๋ GRU๋ฅผ ํตํด ์์ ์ ์ํ๋ฅผ ์ ๋ฐ์ดํธํ๋ฉฐ, ์ธ์ ๋ ธ๋๋ก๋ถํฐ์ ๋ฉ์์ง๋ฅผ ๋ฐ์ ์ด๋ฅผ ํตํฉํฉ๋๋ค.
- ์
๋ฐ์ดํธ ์:
- ๋ฉ์์ง ์ ๋ฌ: Mt=AHtM_t = A H_tMtโ=AHtโ
- GRU ์ ๋ฐ์ดํธ: Ht+1=GRU(Ht,Mt)H_{t+1} = \text{GRU}(H_t, M_t)Ht+1โ=GRU(Htโ,Mtโ)
3.2 GGNN ๊ตฌํ ์์ (PyTorch)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
class GGNNLayer(nn.Module):
def __init__(self, in_features, out_features):
super(GGNNLayer, self).__init__()
self.gru = nn.GRU(in_features, out_features)
def forward(self, X, A):
M = torch.mm(A, X)
X, _ = self.gru(M.unsqueeze(0))
return X.squeeze(0)
class GGNN(nn.Module):
def __init__(self, in_features, hidden_features, out_features):
super(GGNN, self).__init__()
self.layer1 = GGNNLayer(in_features, hidden_features)
self.layer2 = GGNNLayer(hidden_features, out_features)
def forward(self, X, A):
X = self.layer1(X, A)
X = self.layer2(X, A)
return X
3.3 ํธ๋ฆฌ LSTM (Tree LSTM)
Tree LSTM์ LSTM์ ํธ๋ฆฌ ๊ตฌ์กฐ์ ์ ์ฉํ ๋ชจ๋ธ๋ก, ๊ณ์ธต์ ๊ตฌ์กฐ๋ฅผ ๊ฐ์ง ๋ฐ์ดํฐ์ ์ ํฉํฉ๋๋ค. ํธ๋ฆฌ ๊ตฌ์กฐ์์๋ ๋ถ๋ชจ ๋ ธ๋์ ์์ ๋ ธ๋ ๊ฐ์ ๊ด๊ณ๋ฅผ ๋ชจ๋ธ๋งํฉ๋๋ค.
Tree LSTM ๊ตฌ์กฐ
๊ฐ ๋ ธ๋๋ ๋ถ๋ชจ ๋ ธ๋์ ์์ ๋ ธ๋์ ์ํ๋ฅผ ํตํฉํ์ฌ ์์ ์ ์ํ๋ฅผ ์ ๋ฐ์ดํธํฉ๋๋ค. ์ด๋ฅผ ํตํด ํธ๋ฆฌ ๊ตฌ์กฐ๋ฅผ ๋ฐ๋ผ ์ ๋ณด๋ฅผ ์ ํํ๊ณ , ํธ๋ฆฌ์ ๋ฃจํธ ๋ ธ๋์์ ์ ์ฒด ํธ๋ฆฌ์ ์ ๋ณด๋ฅผ ์ง์ฝํ ์ ์์ต๋๋ค.
์ ๋ฐ์ดํธ ์
-
์์ ๋ ธ๋ ์ํ ์งํฉ: C={hjโฃjโchildren(i)}C = {h_j j \in \text{children}(i)}C={hjโโฃjโchildren(i)} -
๋ถ๋ชจ ๋ ธ๋ ์ํ ์งํฉ: P={hpโฃp=parent(i)}P = {h_p p = \text{parent}(i)}P={hpโโฃp=parent(i)} - ์ํ ์ ๋ฐ์ดํธ: hi=LSTM(C,P,hi)h_i = \text{LSTM}(C, P, h_i)hiโ=LSTM(C,P,hiโ)
๊ฐ ๋ ธ๋ iii์ ์ํ๋ ๋ค์๊ณผ ๊ฐ์ด ์ ๋ฐ์ดํธ๋ฉ๋๋ค:
- iii ๋ ธ๋์ ์์ ๋ ธ๋ ์ํ ํฉ: hjh_jhjโ๋ iii์ ์์ ๋ ธ๋ ์ํ์ ๋๋ค.
- iii ๋ ธ๋์ ๋ถ๋ชจ ๋ ธ๋ ์ํ ํฉ: hph_phpโ๋ iii์ ๋ถ๋ชจ ๋ ธ๋ ์ํ์ ๋๋ค.
- iii ๋ ธ๋์ ์๋ก์ด ์ํ: hih_ihiโ๋ LSTM์ ํตํด ์ ๋ฐ์ดํธ๋ฉ๋๋ค.
Tree LSTM์ ํต์ฌ ์์
\[i*t = \sigma(W\_i x\_t + \sum*{j \in \text{children}(i)} U*i h\_j + b\_i) f*{tj} = \sigma(W*f x\_t + \sum*{j \in \text{children}(i)} U*f h\_j + b\_f) o\_t = \sigma(W\_o x\_t + \sum*{j \in \text{children}(i)} U*o h\_j + b\_o) u\_t = \tanh(W\_u x\_t + \sum*{j \in \text{children}(i)} U*u h\_j + b\_u) c\_t = i\_t \odot u\_t + \sum*{j \in \text{children}(i)} f\_{tj} \odot c\_j h\_t = o\_t \odot \tanh(c\_t)\]์ฌ๊ธฐ์ it,ftj,ot,uti_t, f_{tj}, o_t, u_titโ,ftjโ,otโ,utโ๋ ๊ฐ๊ฐ ์ ๋ ฅ ๊ฒ์ดํธ, ์์ ๊ฒ์ดํธ, ์ถ๋ ฅ ๊ฒ์ดํธ, ์ ๋ฐ์ดํธ ๊ฒ์ดํธ๋ฅผ ๋ํ๋ด๋ฉฐ, ctc_tctโ์ hth_thtโ๋ ๊ฐ๊ฐ ์ ์ํ์ ํ๋ ์ํ๋ฅผ ๋ํ๋ ๋๋ค.
์ฝ๋ ์์ (PyTorch)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import torch
import torch.nn as nn
import torch.nn.functional as F
class TreeLSTMCell(nn.Module):
def __init__(self, in_features, out_features):
super(TreeLSTMCell, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.W_i = nn.Linear(in_features, out_features)
self.U_i = nn.Linear(out_features, out_features)
self.W_f = nn.Linear(in_features, out_features)
self.U_f = nn.Linear(out_features, out_features)
self.W_o = nn.Linear(in_features, out_features)
self.U_o = nn.Linear(out_features, out_features)
self.W_u = nn.Linear(in_features, out_features)
self.U_u = nn.Linear(out_features, out_features)
def forward(self, x, child_h, child_c):
child_h_sum = torch.sum(child_h, dim=0)
child_c_sum = torch.sum(child_c, dim=0)
i = torch.sigmoid(self.W_i(x) + self.U_i(child_h_sum))
f = torch.sigmoid(self.W_f(x) + self.U_f(child_h_sum))
o = torch.sigmoid(self.W_o(x) + self.U_o(child_h_sum))
u = torch.tanh(self.W_u(x) + self.U_u(child_h_sum))
c = i * u + f * child_c_sum
h = o * torch.tanh(c)
return h, c
class TreeLSTM(nn.Module):
def __init__(self, in_features, out_features):
super(TreeLSTM, self).__init__()
self.cell = TreeLSTMCell(in_features, out_features)
def forward(self, x, children_h, children_c):
h, c = self.cell(x, children_h, children_c)
return h, c
3.4 ๊ทธ๋ํ LSTM (Graph LSTM)
๊ทธ๋ํ LSTM์ LSTM์ ์ผ๋ฐ ๊ทธ๋ํ์ ํ์ฅํ ๋ชจ๋ธ๋ก, ๋ ธ๋ ์์๋ฅผ ๊ฒฐ์ ํ๋ ๋ฐฉ๋ฒ์ด ์ค์ํฉ๋๋ค. ๊ทธ๋ํ ๊ตฌ์กฐ์์๋ ์ํ ๊ฒฝ๋ก๊ฐ ์์ ์ ์์ผ๋ฏ๋ก, ์์๋ฅผ ์ ์ํ๋ ๊ฒ์ด ํต์ฌ์ ๋๋ค.
๊ทธ๋ํ LSTM ๊ตฌ์กฐ
๊ทธ๋ํ LSTM์์๋ ๊ฐ ๋ ธ๋๊ฐ ์์ ์ ์ด์ ๋ ธ๋๋ก๋ถํฐ ์ ๋ณด๋ฅผ ๋ฐ์ ์ํ๋ฅผ ์ ๋ฐ์ดํธํฉ๋๋ค. ๋ ธ๋ ์์๋ฅผ ๊ฒฐ์ ํ๋ ๋ค์ํ ๋ฐฉ๋ฒ์ด ์์ผ๋ฉฐ, ์ด๋ ๋ชจ๋ธ์ ์ฑ๋ฅ์ ์ค์ํ ์ํฅ์ ๋ฏธ์นฉ๋๋ค.
์ ๋ฐ์ดํธ ์
๊ฐ ๋ ธ๋ iii์ ์ํ๋ ๋ค์๊ณผ ๊ฐ์ด ์ ๋ฐ์ดํธ๋ฉ๋๋ค:
\[h*i = \text{LSTM}(h*{\text{neighbors}(i)}, x\_i)\]์ฌ๊ธฐ์ hneighbors(i)h_{\text{neighbors}(i)}hneighbors(i)โ๋ ๋ ธ๋ iii์ ์ด์ ๋ ธ๋๋ค์ ์ํ๋ฅผ ๋ํ๋ด๋ฉฐ, xix_ixiโ๋ ๋ ธ๋ iii์ ์ ๋ ฅ ํน์ฑ์ ๋ํ๋ ๋๋ค.
๊ทธ๋ํ LSTM์ ํต์ฌ ์์
\[i*t = \sigma(W\_i x\_t + U\_i h*{t-1} + b*i) f\_t = \sigma(W\_f x\_t + U\_f h*{t-1} + b*f) o\_t = \sigma(W\_o x\_t + U\_o h*{t-1} + b*o) u\_t = \tanh(W\_u x\_t + U\_u h*{t-1} + b*u) c\_t = f\_t \odot c*{t-1} + i\_t \odot u\_t\] \[h\_t = o\_t \odot \tanh(c\_t)\]์ฝ๋ ์์ (PyTorch)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
class GraphLSTMCell(nn.Module):
def __init__(self, in_features, out_features):
super(GraphLSTMCell, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.W_i = nn.Linear(in_features, out_features)
self.U_i = nn.Linear(out_features, out_features)
self.W_f = nn.Linear(in_features, out_features)
self.U_f = nn.Linear(out_features, out_features)
self.W_o = nn.Linear(in_features, out_features)
self.U_o = nn.Linear(out_features, out_features)
self.W_u = nn.Linear(in_features, out_features)
self.U_u = nn.Linear(out_features, out_features)
def forward(self, x, neighbor_h, neighbor_c):
neighbor_h_sum = torch.sum(neighbor_h, dim=0)
neighbor_c_sum = torch.sum(neighbor_c, dim=0)
i = torch.sigmoid(self.W_i(x) + self.U_i(neighbor_h_sum))
f = torch.sigmoid(self.W_f(x) + self.U_f(neighbor_h_sum))
o = torch.sigmoid(self.W_o(x) + self.U_o(neighbor_h_sum))
u = torch.tanh(self.W_u(x) + self.U_u(neighbor_h_sum))
c = i * u + f * neighbor_c_sum
h = o * torch.tanh(c)
return h, c
class GraphLSTM(nn.Module):
def __init__(self, in_features, out_features):
super(GraphLSTM, self).__init__()
self.cell = GraphLSTMCell(in_features, out_features)
def forward(self, x, neighbors_h, neighbors_c):
h, c = self.cell(x, neighbors_h, neighbors_c)
return h, c
-
Graph Attention Networks (GAT)
Graph Attention Networks (GAT)๋ ์ฃผ์ ๋ฉ์ปค๋์ฆ์ ๊ทธ๋ํ์ ์ ์ฉํ ๋ชจ๋ธ์ ๋๋ค. GAT๋ ์๊ธฐ ์ฃผ์(self-attention) ๋ฉ์ปค๋์ฆ์ ์ฌ์ฉํ์ฌ ๊ฐ ๋ ธ๋์ ์ด์ ๋ ธ๋๋ค์ ์๋ก ๋ค๋ฅธ ๊ฐ์ค์น๋ฅผ ๋ถ์ฌํฉ๋๋ค.
4.1 ์ฃผ์ ํน์ง
- self-attention ๋ฉ์ปค๋์ฆ: ๊ฐ ๋ ธ๋์ ์ด์ ๋ ธ๋๋ค์ ๋ํด ๊ฐ์ค์น๋ฅผ ๊ณ์ฐํ์ฌ, ์ค์ํ ์ด์ ๋ ธ๋์ ์ ๋ณด๋ฅผ ๋ ๋ง์ด ๋ฐ์ํฉ๋๋ค.
- multi-head-attention ๋ฉ์ปค๋์ฆ: ์ฌ๋ฌ ์ฃผ์ ํค๋๋ฅผ ์ฌ์ฉํ์ฌ ์์ ์ฑ์ ํฅ์์ํค๊ณ , ๊ฐ ํค๋์ ์ถ๋ ฅ์ ๊ฒฐํฉํ์ฌ ์ต์ข ์ถ๋ ฅ์ ๋ง๋ญ๋๋ค.
4.2 GAT ๊ตฌํ ์์ (PyTorch)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
class GATLayer(nn.Module):
def __init__(self, in_features, out_features, num_heads=1):
super(GATLayer, self).__init__()
self.num_heads = num_heads
self.attention_heads = nn.ModuleList([nn.Linear(in_features, out_features) for _ in range(num_heads)])
self.attention_coeffs = nn.ModuleList([nn.Linear(2 * out_features, 1) for _ in range(num_heads)])
def forward(self, X, A):
outputs = []
for head, coeff in zip(self.attention_heads, self.attention_coeffs):
H = head(X)
N = X.size(0)
H_repeated_in_chunks = H.repeat_interleave(N, dim=0)
H_repeated_alternating = H.repeat(N, 1)
all_combinations_matrix = torch.cat([H_repeated_in_chunks, H_repeated_alternating], dim=1)
e = F.leaky_relu(coeff(all_combinations_matrix).view(N, N))
attention = F.softmax(e.masked_fill(A == 0, -1e9), dim=1)
outputs.append(torch.matmul(attention, H))
return torch.cat(outputs, dim=1)
class GAT(nn.Module):
def __init__(self, in_features, hidden_features, out_features, num_heads=1):
super(GAT, self).__init__()
self.layer1 = GATLayer(in_features, hidden_features, num_heads)
self.layer2 = GATLayer(hidden_features * num_heads, out_features, num_heads)
def forward(self, X, A):
X = self.layer1(X, A)
X = self.layer2(X, A)
return X
-
๊ฒฐ๋ก
์ด ์ฑํฐ์์๋ Graph Neural Networks์ ์ฃผ์ ์๊ณ ๋ฆฌ์ฆ์ธ GCN, GRN, GAT์ ๋ํด ์ดํด๋ณด์์ต๋๋ค. ๊ฐ ์๊ณ ๋ฆฌ์ฆ์ ๊ทธ๋ํ ๋ฐ์ดํฐ๋ฅผ ์ฒ๋ฆฌํ๋ ๊ณ ์ ํ ๋ฐฉ์์ ๊ฐ์ง๊ณ ์์ผ๋ฉฐ, ๋ค์ํ ๊ทธ๋ํ ๊ด๋ จ ์์ ์ ์ ์ฉ๋ ์ ์์ต๋๋ค.
- GCN์ ๋ ธ๋์ ํน์ฑ๊ณผ ๊ทธ๋ํ ๊ตฌ์กฐ๋ฅผ ๋์์ ๊ณ ๋ คํ์ฌ ํจ๊ณผ์ ์ธ ๋ ธ๋ ํํ์ ํ์ตํฉ๋๋ค.
- GRN์ ์๊ฐ์ ๋๋ ๊ณ์ธต์ ๊ตฌ์กฐ๋ฅผ ๊ฐ์ง ๊ทธ๋ํ ๋ฐ์ดํฐ๋ฅผ ์ฒ๋ฆฌํ๋ ๋ฐ ์ ํฉํฉ๋๋ค.
- GAT๋ ์ฃผ์ ๋ฉ์ปค๋์ฆ์ ํตํด ์ด์ ๋ ธ๋๋ค์ ์ค์๋๋ฅผ ํ์ตํ์ฌ ๋ ์ ์ฐํ ์ ๋ณด aggregation์ ๊ฐ๋ฅํ๊ฒ ํฉ๋๋ค.
์ด๋ฌํ Graph Neural Networks ์๊ณ ๋ฆฌ์ฆ๋ค์ ๊ฐ๊ฐ์ ์ฅ๋จ์ ์ ๊ฐ์ง๊ณ ์์ผ๋ฉฐ, ๋ฌธ์ ์ ํน์ฑ์ ๋ฐ๋ผ ์ ์ ํ ๋ชจ๋ธ์ ์ ํํ๋ ๊ฒ์ด ์ค์ํฉ๋๋ค.
- GCN์ ๊ตฌํ์ด ๊ฐ๋จํ๊ณ ํจ๊ณผ์ ์ด์ง๋ง, ๊น์ ๋ ์ด์ด๋ฅผ ์๊ธฐ ์ด๋ ต๋ค๋ ๋จ์ ์ด ์์ต๋๋ค.
- GRN์ ์ํ์ค ๋ฐ์ดํฐ๋ ํธ๋ฆฌ ๊ตฌ์กฐ ๋ฐ์ดํฐ์ ๊ฐ์ ์ ๋ณด์ด์ง๋ง, ์ผ๋ฐ์ ์ธ ๊ทธ๋ํ์ ์ ์ฉํ ๋๋ ๋ ธ๋ ์์ ๊ฒฐ์ ์ด ์ด๋ ค์ธ ์ ์์ต๋๋ค.
- GAT๋ ๋ ธ๋ ๊ฐ์ ์ค์๋๋ฅผ ํ์ตํ ์ ์์ด ์ ์ฐ์ฑ์ด ๋์ง๋ง, ๊ณ์ฐ ๋ณต์ก๋๊ฐ ์๋์ ์ผ๋ก ๋์ต๋๋ค.