[Paper Review] An Image Is Worth 16x16 Words : Transformers for Image Recognition at Scale (Vision Transformer)
์๋ณธ ๊ฒ์๊ธ: https://velog.io/@euisuk-chung/Paper-Review-An-Image-Is-Worth-16x16-Words-Transformers-for-Image-Recognition-at-Scale-Vision-Transformer
์ ์ ์ด์
์๋ ํ์ธ์! ์ค๋ ๋ ผ๋ฌธ๋ฆฌ๋ทฐ, ์ฝ๋๋ฆฌ๋ทฐํด๋ณผ ๋ ผ๋ฌธ์ โAn Image Is Worth 16x16 Words: Transformers for Image Recognition at Scaleโ ๋ก, ์ปดํจํฐ ๋น์ ์์ Transformer์ Attention์ด ์ฐ์ด๊ฒ ๋ ๊ฒฐ์ ์ ๊ณ๊ธฐ(?)๊ฐ ๋ ๋ ผ๋ฌธ์ ๋๋ค. ์ต๊ทผ ์ด์ชฝ ๋ถ์ผ์ ๊ด์ฌ์ด ๋ง๋ค ๋ณด๋ ์ค๋์ ์ด ๋ ผ๋ฌธ์ ๋ฆฌ๋ทฐํ๊ฒ ๋์์ต๋๋ค.
๋ ผ๋ฌธ๋ฆฌ๋ทฐ
Background
(Self) Attention
- Attention์ ๊ธฐ๋ณธ ์์ด๋์ด๋ ๋์ฝ๋์์ ์ถ๋ ฅ ๋จ์ด๋ฅผ ์์ธกํ๋ ๋งค ์์ (time step)๋ง๋ค, ์ธ์ฝ๋์์์ ์ ์ฒด ์ ๋ ฅ ๋ฌธ์ฅ์ ๋ค์ ํ ๋ฒ ์ฐธ๊ณ ํฉ๋๋ค.
- ๋จ, ์ ์ฒด ์ ๋ ฅ ๋ฌธ์ฅ์ ์ ๋ถ ๋ค ๋์ผํ ๋น์จ๋ก ์ฐธ๊ณ ํ๋ ๊ฒ์ด ์๋๋ผ, ํด๋น ์์ ์์ ์์ธกํด์ผ ํ ๋จ์ด์ ์ฐ๊ด์ด ์๋ ์ ๋ ฅ ๋จ์ด ๋ถ๋ถ์ ์ข ๋ ์ง์ค(attention)ํด์ ๋ณด๊ฒ ๋ฉ๋๋ค.
- ์ฃผ์ด์ง โ์ฟผ๋ฆฌ(Query)โ์ ๋ํด์ ๋ชจ๋ โํค(Key)โ์์ ์ ์ฌ๋๋ฅผ ๊ฐ๊ฐ ๊ตฌํฉ๋๋ค. - ๊ทธ๋ฆฌ๊ณ ๊ตฌํด๋ธ ์ ์ฌ๋๋ฅผ ๊ฐ์ค์น๋ก ํ์ฌ ํค์ ๋งตํ๋์ด ์๋ ๊ฐ๊ฐ์ โ๊ฐ(Value)โ์ ๋ฐ์ํฉ๋๋ค. ๊ทธ๋ฆฌ๊ณ ์ ์ฌ๋๊ฐ ๋ฐ์๋ โ๊ฐ(Value)โ์ ๋ชจ๋ ๊ฐ์คํฉํ์ฌ ๋ฆฌํดํ๊ฒ ๋ฉ๋๋ค.
- โ์ฟผ๋ฆฌ(Query)โ, โํค(Key)โ, โ๊ฐ(Value)โ์ ์ ์๋ ์์ด๋ก ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
Transformer
- ๊ธฐ์กด RNN ๊ธฐ๋ฐ seq2seq ๋ชจ๋ธ์์๋ ์ด์ ์์ ์ ์ฐ์ฐ์ด ๋๋๊ธฐ ์ ์๋ ๋ค์ ์์ ์ ์ฐ์ฐ์ด ๋ถ๊ฐ๋ฅํ์ฌ ๋ณ๋ ฌํ(parallelize) ๋ ์ฐ์ฐ์ฒ๋ฆฌ๊ฐ ๋ถ๊ฐ๋ฅํ์ต๋๋ค.
- RNN๊ตฌ์กฐ๋ ๊ณ ์ง์ ๋ฌธ์ ์ธ Long-term dependency ๋ฌธ์ ๊ฐ ๋ฐ์ํ์๊ณ , ์ด๋ ๊ณง ํ์ ์คํ (time step)์ด ๊ธธ์ด์ง ์๋ก ์ํ์ค ์ฒ๋ฆฌ์ ์ฑ๋ฅ์ด ๋จ์ด์ง์ ์๋ฏธํฉ๋๋ค.
- ์ด๋ฌํ ๋ฌธ์ ์ ๋ค์ ๋ณด์ํ๊ธฐ ์ํด Attention๋ง์ผ๋ก Encoder, Decoder ๊ตฌ์กฐ๋ฅผ ๋ง๋ค์ด ์ํ์ค๋ฅผ ์ฒ๋ฆฌํ๋ ๋ชจ๋ธ์ ์ ์ํจ์ผ๋ก์จ ํ์ต ์๋๊ฐ ๋งค์ฐ ๋น ๋ฅด๋ฉฐ ์ฑ๋ฅ๋ ์ฐ์ํ Transformer๊ฐ ์ ์๋์์ต๋๋ค.
- ๋ฟ๋ง ์๋๋ผ ์ฌ๋ฌ๊ฐ์ Head๋ฅผ ์ฌ์ฉํ๋ Multi-head Attention์ ํตํด ๋ค์ํ aspect์ ๋ํด์ ๋ชจ๋ธ์ด ํ์ตํ ์ ์๋๋ก ํ์์ต๋๋ค.
Image Recognition (Classification)
- Image Recognition (Classification)์ ์ด๋ฏธ์ง๋ฅผ ์๊ณ ๋ฆฌ์ฆ์ ์ ๋ ฅ(input)ํด์ฃผ๋ฉด, ๊ทธ ์ด๋ฏธ์ง๊ฐ ์ํ๋ class lable์ ์ถ๋ ฅ(output)ํด์ฃผ๋ task๋ฅผ ์๋ฏธํฉ๋๋ค.
-
์๋ ๊ทธ๋ฆผ ์ฒ๋ผ ๊ณ ์์ด ์ฌ์ง์ ๋ฃ์ด์ฃผ๋ฉด ๊ณ ์์ด ๋ผ๊ณ ์ธ์(๋ถ๋ฅ)ํด๋ ๋๋ค.
- ์์ ๊ทธ๋ฆผ์ฒ๋ผ ์ฌ์ง๊ป CV(Computer Vision) ๋๋ฉ์ธ์์๋ CNN(Convolutional Neural Network)๋ฅผ ์ฌ์ฉํ ๋ชจ๋ธ๋ค์ด ๋ง์ด ์ฌ์ฉ๋์ด ์ค๊ณ ์์์ต๋๋ค. (Ex. ResNet, UNet, EfficientNet ๋ฑ)
- ํ์ง๋ง, NLP(Natural Language Processing) ๋๋ฉ์ธ์์์ Self-Attention๊ณผ Transformer์ ์ฑ์ฅ์ผ๋ก ์ธํด CNN๊ณผ Attention์ ํจ๊ป ์ด์ฉํ๋ ค๋ ์ถ์ธ๊ฐ ์ฆ๊ฐํ๊ณ ์์ต๋๋ค. ๋ณธ ๋ ผ๋ฌธ(์ฐ๊ตฌ) ์ญ์ ๊ทธ๋ฌํ ์๋ ์ค ํ๋์ ๋๋ค.
Vision Transformer
- Vision Transformer์ ๊ฐ๋ ์ Transformer๊ฐ ์ด๋ป๊ฒ ์๋ํ๋์ง ์๋ ์ฌ๋๋ค์ด๋ผ๋ฉด ์ฝ๊ฒ ์ ๊ทผํ ์ ์์ต๋๋ค. ์๋ ๊ทธ๋ฆผ์ด ์์งํ ๋ณธ ๋ ผ๋ฌธ์ ์ ๋ถ์ด๊ธฐ ๋๋ฌธ์ด์ฃ .
- Vision Transformer๋ image recognition task์ ์์ด์ Convolution์ ์์ ์์ ๊ณ , Transformer Encoder๋ง์ ์ฌ์ฉํ์์ต๋๋ค. ๊ฐ๊ฐ์ ์์๋ ์๋์ ๊ฐ์ต๋๋ค.
0. Prerequisites
- ํ์ ๋ผ์ด๋ธ๋ฌ๋ฆฌ IMPORT
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import torch
import torch.nn as nn
from torch import Tensor
from torchsummary import summary
import torch.optim as optim
from torch.optim import lr_scheduler
import torchvision.datasets as dset
import torchvision.transforms as T
from torchvision.transforms import Compose, Resize, ToTensor
from torch.utils.data import DataLoader
from torch.utils.data import sampler
import numpy as np
import os
import copy
import cv2
import matplotlib.pyplot as plt
from tqdm import tqdm
%pip install einops
from einops import rearrange, repeat, reduce
from einops.layers.torch import Rearrange, Reduce
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dtype = torch.long
- Define Helper Function
1
2
def pair(t):
return t if isinstance(t, tuple) else (t, t)
- Define
PreNorm
Class
1
2
3
4
5
6
7
8
9
# Define PreNorm
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
- Define
FeedForward
Class
1
2
3
4
5
6
7
8
9
10
11
12
13
# Define FeedForward
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
- Define
Attention
Class
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
# Define Attention
class Attention(nn.Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
super().__init__()
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)
self.heads = heads
self.scale = dim_head ** -0.5
self.attend = nn.Softmax(dim = -1)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
) if project_out else nn.Identity()
def forward(self, x):
b, n, _, h = *x.shape, self.heads
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
attn = self.attend(dots)
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
- Define
Transformer
Class
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# Define Transformer
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
]))
def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return x
step 1. Splitting Image into fixed-size patches
- ๊ฐ์ฅ ๋จผ์ ์ด๋ฏธ์ง๋ฅผ ๊ณ ์ ๋ ์ฌ์ด์ฆ์ ํจ์น๋ค๋ก ๋ถํ ํ์ฌ ๋ชจ๋ธ์ ๋ฃ์ด์ค๋๋ค.
step 2. Linearly embed each patches
-
๊ฐ๊ฐ์ ์ด๋ฏธ์ง ํจ์น๋ค์ ๋ํด Linear Embedding์ ์ํํด์ค๋๋ค. (D์ฐจ์์ผ๋ก)
step 3. Add positional embedding
- ์ด์ ๊ฐ๊ฐ์ ์ด๋ฏธ์ง ํจ์น๋ค์ด ์ด๋ค ์์น์ ์๋๊ฐ์ ๋ํ ์ ๋ณด๋ ๋ชจ๋ธ์ ๋ฃ์ด์ฃผ์ด์ผ๊ฒ ์ฃ ? ์ด๋ฐ ์์น์ ๋ํ ์ ๋ณด๋ฅผ ์ฐ๋ฆฌ๋ position embedding์ด๋ผ๊ณ ํ๋ฉฐ, ์์์ ๊ตฌํ Embedding์ ๋ถ์ฌ์ฃผ๊ฒ ๋ฉ๋๋ค.
step 4. Feed embedding vector into Transformer Encoder
- ๊ฐ๊ฐ์ ์ด๋ฏธ์ง ํจ์น๋ค์ ๋ํ ์์น ์ ๋ณด์ ์๋ฐฐ๋ฉ ๊ฐ์ Transformer Encoder๋ก ๋ฃ์ด์ค๋๋ค. Transformer Encoder๋ ์๋ ๊ทธ๋ฆผ(์ฐ์ธก)๊ณผ ๊ฐ์ด ๊ตฌ์ฑ๋์ด ์์ต๋๋ค.
step 5. Use [CLS] token for Classification
- ์ด๋ฏธ ๋์น ์ฑ์ ๋ถ๋ค๋ ์๊ฒ ์ง๋ง, Transformer Encoder์ ๋ค์ด๊ฐ ๊ฐ๊ฐ์ ์ด๋ฏธ์ง ํจ์น๋ค์ ๋ํ ์์น ์ ๋ณด์ ์๋ฐฐ๋ฉ ๊ฐ ์ธ์๋ ์์
[0,*]
์ด ์๋ ๊ฒ์ ํ์ธํ์ค ์ ์๋ ๋ฐ ์ด๋ ์ ์ฒด ์ด๋ฏธ์ง์ ๋ชจ๋ ์ ๋ณด๋ฅผ ๋ด๊ณ ์๋ ํ ํฐ์ด๋ผ๊ณ ๋ณผ ์ ์์ต๋๋ค. (A.K.A.[CLS]
ํ ํฐ) ์ด๋ฒ ๋จ๊ณ์์๋ ์ด๋ฌํ[CLS]
ํ ํฐ์ ์ฌ์ฉํ์ฌ MLP(Multi-layer Perceptron)์ ํ์ Classificatin์ ์ํํ๊ฒ ๋ฉ๋๋ค.
๐ ์ฌ๊ธฐ์ ์ ๊น!
einops
๋ผ์ด๋ธ๋ฌ๋ฆฌ?
- Einstein notation ์ ๋ณต์กํ ํ ์ ์ฐ์ฐ์ ํ๊ธฐํ๋ ๋ฐฉ๋ฒ์ ๋๋ค. ๋ฅ๋ฌ๋์์ ์ฐ์ด๋ ๋ง์ ์ฐ์ฐ์ Einstein notation ์ผ๋ก ์ฝ๊ฒ ํ๊ธฐํ ์ ์์ต๋๋ค.
- einops (https://github.com/arogozhnikov/einops)๋ pytorch, tensorflow ๋ฑ ์ฌ๋ฌ ํ๋ ์์ํฌ๋ฅผ ๋์์ ์ง์ํ๋ ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ก ์ด๋ฌํ Einstein notation์ ์ฌ์ฉํ ์ ์๊ฒํฉ๋๋ค.
๐ ์ฌ๊ธฐ์ ์ ๊น!
Rearrange
ํจ์?
- Rearrange ํจ์๋ shape๋ฅผ ์ฝ๊ฒ ๋ณํํด์ฃผ๋ ํจ์๋ผ๊ณ ์๊ฐํ๋ฉด ๋ฉ๋๋ค.
๋ฐ์ ๊ทธ๋ฆผ์ผ๋ก ์ด๋ป๊ฒ ์๋ํ๋์ง ์ง๊ด์ ์ผ๋ก ํ์ธํด๋ณด์์ฃ !
๐ ์ฌ๊ธฐ์ ์ ๊น!
einsum
ํจ์?
- Einsum ํ๊ธฐ๋ฒ์ ํน์ํ Domain Specific Language๋ฅผ ์ด์ฉํด ์ด ๋ชจ๋ ํ๋ ฌ, ์ฐ์ฐ์ ํ๊ธฐํ๋ ๋ฐฉ๋ฒ์ ๋๋ค.
- ์ฝ๊ฒ ๋งํด ์ฐ๋ฆฌ๊ฐ ๊ตฌํ๊ณ ์ถ์ ํ๋ ฌ ์ฐ์ฐ์ ์ง๊ด์ ์ผ๋ก ์ ์ํด์ ๊ตฌํ๊ฒ ํด์ฃผ๋ ํจ์์ ๋๋ค.
๋ช ๊ฐ์ง ์์๋ก ์ดํด๋ณด์์ฃ (given X(matrix), Y(matrix))
Transpose :
np.einsum("ij->ji", X)
Matrix sum :
np.einsum("ij->", X)
Matrix row sum :
np.einsum("ij->i", X)
Matrix col sum :
np.einsum("ij->j", X)
Matrix Multiplication :
np.einsum('ij,j->i', X, Y)
Batched Matrix Multiplication :
np.einsum('bik,bkj->bij', X, Y)
Vision Transformer ์ฝ๋
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
# ViT Class
class ViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
super().__init__() # super()๋ก ๊ธฐ๋ฐ ํด๋์ค์ __init__ ๋ฉ์๋ ํธ์ถ
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)
# assert ๋ฌธ : ๋ค์ ์กฐ๊ฑด์ด True๊ฐ ์๋๋ฉด AssertError๋ฅผ ๋ฐ์
# patch size ์กฐ๊ฑด
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (image_height // patch_height) * (image_width // patch_width)
patch_dim = channels * patch_height * patch_width
# pooling ์กฐ๊ฑด
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width), # 3 -> 2
nn.Linear(patch_dim, dim), # Linear Projection
)
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim)) # position embedding ์ด๊ธฐํ
self.cls_token = nn.Parameter(torch.randn(1, 1, dim)) # [CLS] ํ ํฐ ์ด๊ธฐํ
self.dropout = nn.Dropout(emb_dropout) # Dropout ์ ์
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout) # Transformer ์ ์ธ
self.pool = pool
self.to_latent = nn.Identity() # ๋๋ค์ฑ ์ ๊ฑฐ
self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
def forward(self, img):
x = self.to_patch_embedding(img)
b, n, _ = x.shape
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
x = torch.cat((cls_tokens, x), dim=1)
x += self.pos_embedding[:, :(n + 1)]
x = self.dropout(x)
x = self.transformer(x)
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
x = self.to_latent(x)
return self.mlp_head(x)
Reference
- An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale, ICLR 2019 - Alexey Dosovitskiy et. al.
- ๋ฅ ๋ฌ๋์ ์ด์ฉํ ์์ฐ์ด ์ฒ๋ฆฌ ์ ๋ฌธ - ์ ์์ค ์ธ 1์ธ (https://wikidocs.net/book/2155)
-
The Illustrated Transformer -
Jay Alammar (https://jalammar.github.io/illustrated-transformer)
- ViT Source Code - lucidrains (https://github.com/lucidrains/vit-pytorch/blob/64a2ef6462bde61db4dd8f0887ee71192b273692/vit_pytorch/vit.py)