[Graph] Chapter 4. Graph Neural Networks: Algorithms

Posted by Euisuk's Dev Log on July 18, 2024

Original post: https://velog.io/@euisuk-chung/Graph-Neural-Networks-Algorithms

  1. Introduction

Graph-structured data is extremely useful for modeling complex relationships and interactions. To analyze and learn from such data effectively, Graph Neural Networks (GNNs) and graph embedding techniques have been developed.

1.1 Graph Neural Network Models vs. Graph Embedding

  • Graph neural network models: neural networks that learn from the graph structure and the node features at the same time. Representative examples include GCN, GRN, and GAT.
  • Graph embedding: techniques that map the structural information of a graph into a low-dimensional vector space. Examples include DeepWalk, Node2Vec, and GraphSAGE.

1.2 Key Differences

  1. Training paradigm: GNNs use a variety of training schemes, while graph embedding methods mostly rely on unsupervised learning.
  2. Use of feature information: GNNs exploit node features directly, whereas graph embedding methods mainly use structural information only.
  3. Model complexity: GNNs have more complex architectures; graph embedding methods are comparatively simple.
  4. Inductive learning: GNNs support inductive learning, while most graph embedding methods are transductive.
  5. Dynamic graphs: some GNN variants are better suited to handling dynamic graphs.
  6. Expressiveness: GNNs can capture both local and global structure, while graph embedding methods mostly focus on local structure.

  2. Graph Convolutional Networks (GCNs)

GCN์€ CNN์˜ ๊ฐœ๋…์„ ๊ทธ๋ž˜ํ”„ ๋ฐ์ดํ„ฐ์— ํ™•์žฅํ•œ ๋ชจ๋ธ๋กœ, ๋…ธ๋“œ์˜ ํŠน์„ฑ๊ณผ ๊ทธ๋ž˜ํ”„ ๊ตฌ์กฐ๋ฅผ ๋™์‹œ์— ๊ณ ๋ คํ•ฉ๋‹ˆ๋‹ค. GCN์€ ์ฃผ๋กœ ์ŠคํŽ™ํŠธ๋Ÿด ๊ธฐ๋ฐ˜๊ณผ ๊ณต๊ฐ„ ๊ธฐ๋ฐ˜ ๋ฐฉ๋ฒ•์œผ๋กœ ๋‚˜๋‰ฉ๋‹ˆ๋‹ค.

2.1 Spectral-Based Methods

Spectral-based GCNs perform convolution in the graph's frequency domain using the eigenvectors of the graph Laplacian matrix. The basic idea is to transform the graph signal into the frequency domain, apply a filter, and then transform the result back into the spatial domain (see the sketch after the list below).

  • ๊ทธ๋ž˜ํ”„ ๋ผํ”Œ๋ผ์‹œ์•ˆ ํ–‰๋ ฌ: L=Dโˆ’AL = D - AL=Dโˆ’A (์—ฌ๊ธฐ์„œ DDD๋Š” ์ฐจ์ˆ˜ ํ–‰๋ ฌ, AAA๋Š” ์ธ์ ‘ ํ–‰๋ ฌ)
  • ์ŠคํŽ™ํŠธ๋Ÿด ํ•ฉ์„ฑ๊ณฑ: gฮธโ‹†x=Ugฮธ(ฮ›)UTxg_{\theta} \star x = U g_{\theta}(\Lambda) U^T xgฮธโ€‹โ‹†x=Ugฮธโ€‹(ฮ›)UTx
    • UUU: ๋ผํ”Œ๋ผ์‹œ์•ˆ์˜ ๊ณ ์œ ๋ฒกํ„ฐ
    • ฮ›\Lambdaฮ›: ๋ผํ”Œ๋ผ์‹œ์•ˆ์˜ ๊ณ ์œ ๊ฐ’
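
As a small illustration of the idea, here is a minimal sketch assuming a toy undirected 4-node graph; the low-pass filter below is an arbitrary choice made purely for demonstration:

import torch

# Toy undirected graph on 4 nodes (a 4-cycle)
A = torch.tensor([[0., 1., 0., 1.],
                  [1., 0., 1., 0.],
                  [0., 1., 0., 1.],
                  [1., 0., 1., 0.]])
D = torch.diag(A.sum(dim=1))            # degree matrix
L = D - A                               # graph Laplacian

# Eigendecomposition: L = U diag(Lambda) U^T
eigvals, U = torch.linalg.eigh(L)

x = torch.randn(4)                      # a graph signal (one value per node)
x_hat = U.T @ x                         # transform to the frequency domain
g = (eigvals < eigvals.mean()).float()  # arbitrary low-pass filter g_theta(Lambda)
x_filtered = U @ (g * x_hat)            # filter, then transform back to the spatial domain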

2.2 Spatial-Based Methods

Spatial-based GCNs perform the convolution directly over the features of neighboring nodes; the GCN of Kipf and Welling is the representative example. This model uses the adjacency matrix to aggregate information from neighboring nodes and updates each node's features accordingly.

  • Kipf and Welling's GCN:
    • Uses the node feature matrix $X$ and the adjacency matrix $A$
    • Symmetrically normalized adjacency matrix: $\hat{A} = \tilde{D}^{-\frac{1}{2}} \tilde{A} \tilde{D}^{-\frac{1}{2}}$, where $\tilde{A} = A + I$ adds self-loops and $\tilde{D}$ is its degree matrix
    • Layer update rule: $H^{(l+1)} = \sigma(\hat{A} H^{(l)} W^{(l)})$ (a sketch of computing $\hat{A}$ follows this list)

2.3 GCN Implementation Example (PyTorch)

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class GCNLayer(nn.Module):
    def __init__(self, in_features, out_features):
        super(GCNLayer, self).__init__()
        self.linear = nn.Linear(in_features, out_features)
    
    def forward(self, X, A):
        # Aggregate neighbor features via the adjacency matrix, then apply a linear transform
        X = torch.mm(A, X)
        X = self.linear(X)
        return F.relu(X)

class GCN(nn.Module):
    def __init__(self, in_features, hidden_features, out_features):
        super(GCN, self).__init__()
        self.layer1 = GCNLayer(in_features, hidden_features)
        self.layer2 = GCNLayer(hidden_features, out_features)
    
    def forward(self, X, A):
        X = self.layer1(X, A)
        X = self.layer2(X, A)
        return X

# ๋ฐ์ดํ„ฐ ์ค€๋น„
X = torch.randn(5, 10)  # 5๊ฐœ์˜ ๋…ธ๋“œ, 10์ฐจ์› ํŠน์„ฑ
A = torch.eye(5) + torch.rand(5, 5)  # ๊ฐ„๋‹จํ•œ ์ธ์ ‘ ํ–‰๋ ฌ ์˜ˆ์‹œ

# ๋ชจ๋ธ ์ดˆ๊ธฐํ™” ๋ฐ ํ•™์Šต
model = GCN(10, 16, 2)
optimizer = optim.Adam(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

# Forward pass
output = model(X, A)
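
The snippet stops at the forward pass; a minimal sketch of one training step follows, assuming hypothetical integer class labels y for the five nodes:

# Hypothetical labels for the 5 nodes (2 classes)
y = torch.tensor([0, 1, 0, 1, 1])

loss = criterion(output, y)   # cross-entropy over the per-node predictions
optimizer.zero_grad()
loss.backward()
optimizer.step()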

  3. Graph Recurrent Networks (GRNs)

Graph Recurrent Networks (GRNs)์€ RNN์˜ ๊ฐœ๋…์„ ๊ทธ๋ž˜ํ”„์— ์ ์šฉํ•œ ๋ชจ๋ธ๋กœ, ์ฃผ๋กœ ์‹œ๊ฐ„์ ์ธ ์ˆœ์„œ๋‚˜ ๊ณ„์ธต์  ๊ตฌ์กฐ๋ฅผ ๊ฐ€์ง€๋Š” ๊ทธ๋ž˜ํ”„ ๋ฐ์ดํ„ฐ์— ์œ ์šฉํ•ฉ๋‹ˆ๋‹ค.

3.1 ๊ฒŒ์ดํŠธ๋œ ๊ทธ๋ž˜ํ”„ ์‹ ๊ฒฝ๋ง (Gated Graph Neural Networks, GGNN)

GGNN์€ GRU(Gated Recurrent Unit)๋ฅผ ๊ทธ๋ž˜ํ”„์— ์ ์šฉํ•˜์—ฌ ์‹œํ€€์Šค ์ถœ๋ ฅ์ด ํ•„์š”ํ•œ ๋ฌธ์ œ์— ์ ํ•ฉํ•ฉ๋‹ˆ๋‹ค.

  • GGNN structure: each node updates its own state through a GRU, receiving messages from adjacent nodes and integrating them.
  • Update equations:
    • Message passing: $M_t = A H_t$
    • GRU update: $H_{t+1} = \text{GRU}(H_t, M_t)$

3.2 GGNN Implementation Example (PyTorch)

class GGNNLayer(nn.Module):
    def __init__(self, in_features, out_features):
        super(GGNNLayer, self).__init__()
        self.gru = nn.GRU(in_features, out_features)
    
    def forward(self, X, A):
        # Message passing: aggregate neighbor features via the adjacency matrix
        M = torch.mm(A, X)
        # Treat the N nodes as a batch with sequence length 1 for the GRU
        X, _ = self.gru(M.unsqueeze(0))
        return X.squeeze(0)

class GGNN(nn.Module):
    def __init__(self, in_features, hidden_features, out_features):
        super(GGNN, self).__init__()
        self.layer1 = GGNNLayer(in_features, hidden_features)
        self.layer2 = GGNNLayer(hidden_features, out_features)
    
    def forward(self, X, A):
        X = self.layer1(X, A)
        X = self.layer2(X, A)
        return X
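
A minimal usage sketch, reusing the toy X (5 nodes, 10-dimensional features) and A tensors from the GCN example above:

ggnn = GGNN(10, 16, 2)
out = ggnn(X, A)
print(out.shape)  # torch.Size([5, 2]): one output vector per node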

3.3 Tree LSTM

Tree LSTM์€ LSTM์„ ํŠธ๋ฆฌ ๊ตฌ์กฐ์— ์ ์šฉํ•œ ๋ชจ๋ธ๋กœ, ๊ณ„์ธต์  ๊ตฌ์กฐ๋ฅผ ๊ฐ€์ง„ ๋ฐ์ดํ„ฐ์— ์ ํ•ฉํ•ฉ๋‹ˆ๋‹ค. ํŠธ๋ฆฌ ๊ตฌ์กฐ์—์„œ๋Š” ๋ถ€๋ชจ ๋…ธ๋“œ์™€ ์ž์‹ ๋…ธ๋“œ ๊ฐ„์˜ ๊ด€๊ณ„๋ฅผ ๋ชจ๋ธ๋งํ•ฉ๋‹ˆ๋‹ค.

Tree LSTM Structure

๊ฐ ๋…ธ๋“œ๋Š” ๋ถ€๋ชจ ๋…ธ๋“œ์™€ ์ž์‹ ๋…ธ๋“œ์˜ ์ƒํƒœ๋ฅผ ํ†ตํ•ฉํ•˜์—ฌ ์ž์‹ ์˜ ์ƒํƒœ๋ฅผ ์—…๋ฐ์ดํŠธํ•ฉ๋‹ˆ๋‹ค. ์ด๋ฅผ ํ†ตํ•ด ํŠธ๋ฆฌ ๊ตฌ์กฐ๋ฅผ ๋”ฐ๋ผ ์ •๋ณด๋ฅผ ์ „ํŒŒํ•˜๊ณ , ํŠธ๋ฆฌ์˜ ๋ฃจํŠธ ๋…ธ๋“œ์—์„œ ์ „์ฒด ํŠธ๋ฆฌ์˜ ์ •๋ณด๋ฅผ ์ง‘์•ฝํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

์—…๋ฐ์ดํŠธ ์‹

  • ์ž์‹ ๋…ธ๋“œ ์ƒํƒœ ์ง‘ํ•ฉ: C={hjโˆฃjโˆˆchildren(i)}C = {h_j j \in \text{children}(i)}C={hjโ€‹โˆฃjโˆˆchildren(i)}
  • ๋ถ€๋ชจ ๋…ธ๋“œ ์ƒํƒœ ์ง‘ํ•ฉ: P={hpโˆฃp=parent(i)}P = {h_p p = \text{parent}(i)}P={hpโ€‹โˆฃp=parent(i)}
  • ์ƒํƒœ ์—…๋ฐ์ดํŠธ: hi=LSTM(C,P,hi)h_i = \text{LSTM}(C, P, h_i)hiโ€‹=LSTM(C,P,hiโ€‹)

๊ฐ ๋…ธ๋“œ iii์˜ ์ƒํƒœ๋Š” ๋‹ค์Œ๊ณผ ๊ฐ™์ด ์—…๋ฐ์ดํŠธ๋ฉ๋‹ˆ๋‹ค:

  • iii ๋…ธ๋“œ์˜ ์ž์‹ ๋…ธ๋“œ ์ƒํƒœ ํ•ฉ: hjh_jhjโ€‹๋Š” iii์˜ ์ž์‹ ๋…ธ๋“œ ์ƒํƒœ์ž…๋‹ˆ๋‹ค.
  • iii ๋…ธ๋“œ์˜ ๋ถ€๋ชจ ๋…ธ๋“œ ์ƒํƒœ ํ•ฉ: hph_phpโ€‹๋Š” iii์˜ ๋ถ€๋ชจ ๋…ธ๋“œ ์ƒํƒœ์ž…๋‹ˆ๋‹ค.
  • iii ๋…ธ๋“œ์˜ ์ƒˆ๋กœ์šด ์ƒํƒœ: hih_ihiโ€‹๋Š” LSTM์„ ํ†ตํ•ด ์—…๋ฐ์ดํŠธ๋ฉ๋‹ˆ๋‹ค.

Tree LSTM์˜ ํ•ต์‹ฌ ์ˆ˜์‹

\[i_t = \sigma\Big(W_i x_t + \sum_{j \in \text{children}(t)} U_i h_j + b_i\Big)\]
\[f_{tj} = \sigma(W_f x_t + U_f h_j + b_f)\]
\[o_t = \sigma\Big(W_o x_t + \sum_{j \in \text{children}(t)} U_o h_j + b_o\Big)\]
\[u_t = \tanh\Big(W_u x_t + \sum_{j \in \text{children}(t)} U_u h_j + b_u\Big)\]
\[c_t = i_t \odot u_t + \sum_{j \in \text{children}(t)} f_{tj} \odot c_j\]
\[h_t = o_t \odot \tanh(c_t)\]

Here $i_t$, $f_{tj}$, $o_t$, and $u_t$ denote the input gate, the forget gate for child $j$, the output gate, and the candidate (update) gate, respectively, while $c_t$ and $h_t$ denote the cell state and the hidden state.

Code Example (PyTorch)

import torch
import torch.nn as nn
import torch.nn.functional as F

class TreeLSTMCell(nn.Module):
    def __init__(self, in_features, out_features):
        super(TreeLSTMCell, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.W_i = nn.Linear(in_features, out_features)
        self.U_i = nn.Linear(out_features, out_features)
        self.W_f = nn.Linear(in_features, out_features)
        self.U_f = nn.Linear(out_features, out_features)
        self.W_o = nn.Linear(in_features, out_features)
        self.U_o = nn.Linear(out_features, out_features)
        self.W_u = nn.Linear(in_features, out_features)
        self.U_u = nn.Linear(out_features, out_features)
    
    def forward(self, x, child_h, child_c):
        # Sum the hidden and cell states of all children (Child-Sum Tree-LSTM style)
        child_h_sum = torch.sum(child_h, dim=0)
        child_c_sum = torch.sum(child_c, dim=0)
        
        i = torch.sigmoid(self.W_i(x) + self.U_i(child_h_sum))
        # Simplification: a single forget gate acts on the summed child cell states,
        # rather than one forget gate per child as in the equations above
        f = torch.sigmoid(self.W_f(x) + self.U_f(child_h_sum))
        o = torch.sigmoid(self.W_o(x) + self.U_o(child_h_sum))
        u = torch.tanh(self.W_u(x) + self.U_u(child_h_sum))
        
        c = i * u + f * child_c_sum
        h = o * torch.tanh(c)
        
        return h, c

class TreeLSTM(nn.Module):
    def __init__(self, in_features, out_features):
        super(TreeLSTM, self).__init__()
        self.cell = TreeLSTMCell(in_features, out_features)
    
    def forward(self, x, children_h, children_c):
        h, c = self.cell(x, children_h, children_c)
        return h, c
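
A minimal usage sketch for a single node with two children; the tensor shapes are assumptions made for illustration:

# Hypothetical example: a node with 10-dimensional input features and two children
tree_lstm = TreeLSTM(10, 16)
x = torch.randn(1, 10)              # input features of the current node
children_h = torch.randn(2, 1, 16)  # hidden states of the two child nodes
children_c = torch.randn(2, 1, 16)  # cell states of the two child nodes
h, c = tree_lstm(x, children_h, children_c)
print(h.shape, c.shape)             # both torch.Size([1, 16])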

3.4 Graph LSTM

๊ทธ๋ž˜ํ”„ LSTM์€ LSTM์„ ์ผ๋ฐ˜ ๊ทธ๋ž˜ํ”„์— ํ™•์žฅํ•œ ๋ชจ๋ธ๋กœ, ๋…ธ๋“œ ์ˆœ์„œ๋ฅผ ๊ฒฐ์ •ํ•˜๋Š” ๋ฐฉ๋ฒ•์ด ์ค‘์š”ํ•ฉ๋‹ˆ๋‹ค. ๊ทธ๋ž˜ํ”„ ๊ตฌ์กฐ์—์„œ๋Š” ์ˆœํ™˜ ๊ฒฝ๋กœ๊ฐ€ ์žˆ์„ ์ˆ˜ ์žˆ์œผ๋ฏ€๋กœ, ์ˆœ์„œ๋ฅผ ์ •์˜ํ•˜๋Š” ๊ฒƒ์ด ํ•ต์‹ฌ์ž…๋‹ˆ๋‹ค.

Graph LSTM Structure

๊ทธ๋ž˜ํ”„ LSTM์—์„œ๋Š” ๊ฐ ๋…ธ๋“œ๊ฐ€ ์ž์‹ ์˜ ์ด์›ƒ ๋…ธ๋“œ๋กœ๋ถ€ํ„ฐ ์ •๋ณด๋ฅผ ๋ฐ›์•„ ์ƒํƒœ๋ฅผ ์—…๋ฐ์ดํŠธํ•ฉ๋‹ˆ๋‹ค. ๋…ธ๋“œ ์ˆœ์„œ๋ฅผ ๊ฒฐ์ •ํ•˜๋Š” ๋‹ค์–‘ํ•œ ๋ฐฉ๋ฒ•์ด ์žˆ์œผ๋ฉฐ, ์ด๋Š” ๋ชจ๋ธ์˜ ์„ฑ๋Šฅ์— ์ค‘์š”ํ•œ ์˜ํ–ฅ์„ ๋ฏธ์นฉ๋‹ˆ๋‹ค.

์—…๋ฐ์ดํŠธ ์‹

๊ฐ ๋…ธ๋“œ iii์˜ ์ƒํƒœ๋Š” ๋‹ค์Œ๊ณผ ๊ฐ™์ด ์—…๋ฐ์ดํŠธ๋ฉ๋‹ˆ๋‹ค:

\[h_i = \text{LSTM}(h_{\text{neighbors}(i)}, x_i)\]

Here $h_{\text{neighbors}(i)}$ denotes the states of the neighboring nodes of node $i$, and $x_i$ denotes the input features of node $i$.

๊ทธ๋ž˜ํ”„ LSTM์˜ ํ•ต์‹ฌ ์ˆ˜์‹

\[i_t = \sigma(W_i x_t + U_i h_{t-1} + b_i)\]
\[f_t = \sigma(W_f x_t + U_f h_{t-1} + b_f)\]
\[o_t = \sigma(W_o x_t + U_o h_{t-1} + b_o)\]
\[u_t = \tanh(W_u x_t + U_u h_{t-1} + b_u)\]
\[c_t = f_t \odot c_{t-1} + i_t \odot u_t\]
\[h_t = o_t \odot \tanh(c_t)\]

Code Example (PyTorch)

class GraphLSTMCell(nn.Module):
    def __init__(self, in_features, out_features):
        super(GraphLSTMCell, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.W_i = nn.Linear(in_features, out_features)
        self.U_i = nn.Linear(out_features, out_features)
        self.W_f = nn.Linear(in_features, out_features)
        self.U_f = nn.Linear(out_features, out_features)
        self.W_o = nn.Linear(in_features, out_features)
        self.U_o = nn.Linear(out_features, out_features)
        self.W_u = nn.Linear(in_features, out_features)
        self.U_u = nn.Linear(out_features, out_features)
    
    def forward(self, x, neighbor_h, neighbor_c):
        # Sum the hidden and cell states of all neighboring nodes
        neighbor_h_sum = torch.sum(neighbor_h, dim=0)
        neighbor_c_sum = torch.sum(neighbor_c, dim=0)
        
        i = torch.sigmoid(self.W_i(x) + self.U_i(neighbor_h_sum))
        f = torch.sigmoid(self.W_f(x) + self.U_f(neighbor_h_sum))
        o = torch.sigmoid(self.W_o(x) + self.U_o(neighbor_h_sum))
        u = torch.tanh(self.W_u(x) + self.U_u(neighbor_h_sum))
        
        c = i * u + f * neighbor_c_sum
        h = o * torch.tanh(c)
        
        return h, c

class GraphLSTM(nn.Module):
    def __init__(self, in_features, out_features):
        super(GraphLSTM, self).__init__()
        self.cell = GraphLSTMCell(in_features, out_features)
    
    def forward(self, x, neighbors_h, neighbors_c):
        h, c = self.cell(x, neighbors_h, neighbors_c)
        return h, c
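
As with the Tree LSTM, a minimal usage sketch for a single node with three neighbors; the shapes are assumptions made for illustration:

graph_lstm = GraphLSTM(10, 16)
x = torch.randn(1, 10)               # input features of the current node
neighbors_h = torch.randn(3, 1, 16)  # hidden states of the three neighboring nodes
neighbors_c = torch.randn(3, 1, 16)  # cell states of the three neighboring nodes
h, c = graph_lstm(x, neighbors_h, neighbors_c)
print(h.shape, c.shape)              # both torch.Size([1, 16])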

  4. Graph Attention Networks (GAT)

Graph Attention Networks (GAT) apply the attention mechanism to graphs. GAT uses a self-attention mechanism to assign different weights to each node's neighboring nodes.

4.1 Key Features

  • Self-attention mechanism: attention weights are computed over each node's neighboring nodes, so that information from more important neighbors is reflected more strongly (the attention coefficients are given after this list).
  • Multi-head attention mechanism: several attention heads are used to improve stability, and the outputs of the heads are combined into the final output.
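
For reference, the attention coefficients in the original GAT formulation are computed as follows, where $W$ is a shared linear transformation, $\vec{a}$ is the learnable attention vector, $\|$ denotes concatenation, and $\mathcal{N}(i)$ is the neighborhood of node $i$:

\[e_{ij} = \text{LeakyReLU}\big(\vec{a}^{\,T} [W h_i \, \| \, W h_j]\big)\]
\[\alpha_{ij} = \frac{\exp(e_{ij})}{\sum_{k \in \mathcal{N}(i)} \exp(e_{ik})}, \qquad h_i' = \sigma\Big(\sum_{j \in \mathcal{N}(i)} \alpha_{ij} W h_j\Big)\]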

4.2 GAT Implementation Example (PyTorch)

class GATLayer(nn.Module):
    def __init__(self, in_features, out_features, num_heads=1):
        super(GATLayer, self).__init__()
        self.num_heads = num_heads
        self.attention_heads = nn.ModuleList([nn.Linear(in_features, out_features) for _ in range(num_heads)])
        self.attention_coeffs = nn.ModuleList([nn.Linear(2 * out_features, 1) for _ in range(num_heads)])
    
    def forward(self, X, A):
        outputs = []
        for head, coeff in zip(self.attention_heads, self.attention_coeffs):
            H = head(X)
            N = X.size(0)
            # Build all N*N pairs (h_i || h_j) so every node pair can be scored
            H_repeated_in_chunks = H.repeat_interleave(N, dim=0)
            H_repeated_alternating = H.repeat(N, 1)
            all_combinations_matrix = torch.cat([H_repeated_in_chunks, H_repeated_alternating], dim=1)
            e = F.leaky_relu(coeff(all_combinations_matrix).view(N, N))
            # Mask out non-edges before the softmax so attention stays on actual neighbors
            attention = F.softmax(e.masked_fill(A == 0, -1e9), dim=1)
            outputs.append(torch.matmul(attention, H))
        # Concatenate the outputs of all attention heads
        return torch.cat(outputs, dim=1)

class GAT(nn.Module):
    def __init__(self, in_features, hidden_features, out_features, num_heads=1):
        super(GAT, self).__init__()
        self.layer1 = GATLayer(in_features, hidden_features, num_heads)
        self.layer2 = GATLayer(hidden_features * num_heads, out_features, num_heads)
    
    def forward(self, X, A):
        X = self.layer1(X, A)
        X = self.layer2(X, A)
        return X
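
A minimal usage sketch, reusing the toy X and A tensors from the GCN example; the hidden size and head count are arbitrary choices for illustration:

gat = GAT(in_features=10, hidden_features=8, out_features=2, num_heads=2)
out = gat(X, A)
print(out.shape)  # torch.Size([5, 4]): the two heads' 2-dimensional outputs are concatenated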

  5. Conclusion

์ด ์ฑ•ํ„ฐ์—์„œ๋Š” Graph Neural Networks์˜ ์ฃผ์š” ์•Œ๊ณ ๋ฆฌ์ฆ˜์ธ GCN, GRN, GAT์— ๋Œ€ํ•ด ์‚ดํŽด๋ณด์•˜์Šต๋‹ˆ๋‹ค. ๊ฐ ์•Œ๊ณ ๋ฆฌ์ฆ˜์€ ๊ทธ๋ž˜ํ”„ ๋ฐ์ดํ„ฐ๋ฅผ ์ฒ˜๋ฆฌํ•˜๋Š” ๊ณ ์œ ํ•œ ๋ฐฉ์‹์„ ๊ฐ€์ง€๊ณ  ์žˆ์œผ๋ฉฐ, ๋‹ค์–‘ํ•œ ๊ทธ๋ž˜ํ”„ ๊ด€๋ จ ์ž‘์—…์— ์ ์šฉ๋  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

  • GCN์€ ๋…ธ๋“œ์˜ ํŠน์„ฑ๊ณผ ๊ทธ๋ž˜ํ”„ ๊ตฌ์กฐ๋ฅผ ๋™์‹œ์— ๊ณ ๋ คํ•˜์—ฌ ํšจ๊ณผ์ ์ธ ๋…ธ๋“œ ํ‘œํ˜„์„ ํ•™์Šตํ•ฉ๋‹ˆ๋‹ค.
  • GRN์€ ์‹œ๊ฐ„์  ๋˜๋Š” ๊ณ„์ธต์  ๊ตฌ์กฐ๋ฅผ ๊ฐ€์ง„ ๊ทธ๋ž˜ํ”„ ๋ฐ์ดํ„ฐ๋ฅผ ์ฒ˜๋ฆฌํ•˜๋Š” ๋ฐ ์ ํ•ฉํ•ฉ๋‹ˆ๋‹ค.
  • GAT๋Š” ์ฃผ์˜ ๋ฉ”์ปค๋‹ˆ์ฆ˜์„ ํ†ตํ•ด ์ด์›ƒ ๋…ธ๋“œ๋“ค์˜ ์ค‘์š”๋„๋ฅผ ํ•™์Šตํ•˜์—ฌ ๋” ์œ ์—ฐํ•œ ์ •๋ณด aggregation์„ ๊ฐ€๋Šฅํ•˜๊ฒŒ ํ•ฉ๋‹ˆ๋‹ค.

Each of these Graph Neural Network algorithms has its own strengths and weaknesses, so it is important to choose the model that fits the characteristics of the problem.

  • GCN์€ ๊ตฌํ˜„์ด ๊ฐ„๋‹จํ•˜๊ณ  ํšจ๊ณผ์ ์ด์ง€๋งŒ, ๊นŠ์€ ๋ ˆ์ด์–ด๋ฅผ ์Œ“๊ธฐ ์–ด๋ ต๋‹ค๋Š” ๋‹จ์ ์ด ์žˆ์Šต๋‹ˆ๋‹ค.
  • GRN์€ ์‹œํ€€์Šค ๋ฐ์ดํ„ฐ๋‚˜ ํŠธ๋ฆฌ ๊ตฌ์กฐ ๋ฐ์ดํ„ฐ์— ๊ฐ•์ ์„ ๋ณด์ด์ง€๋งŒ, ์ผ๋ฐ˜์ ์ธ ๊ทธ๋ž˜ํ”„์— ์ ์šฉํ•  ๋•Œ๋Š” ๋…ธ๋“œ ์ˆœ์„œ ๊ฒฐ์ •์ด ์–ด๋ ค์šธ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
  • GAT๋Š” ๋…ธ๋“œ ๊ฐ„์˜ ์ค‘์š”๋„๋ฅผ ํ•™์Šตํ•  ์ˆ˜ ์žˆ์–ด ์œ ์—ฐ์„ฑ์ด ๋†’์ง€๋งŒ, ๊ณ„์‚ฐ ๋ณต์žก๋„๊ฐ€ ์ƒ๋Œ€์ ์œผ๋กœ ๋†’์Šต๋‹ˆ๋‹ค.

