[ํ์ดํ ์น] ํ์ดํ ์น๋ก CNN ๋ชจ๋ธ์ ๊ตฌํํด๋ณด์! (VGGNetํธ)
์๋ณธ ๊ฒ์๊ธ: https://velog.io/@euisuk-chung/ํ์ดํ ์น-ํ์ดํ ์น๋ก-CNN-๋ชจ๋ธ์-๊ตฌํํด๋ณด์-VGGNetํธ
์๋
ํ์ธ์! ์ค๋ ํฌ์คํ
๋ถํฐ ๋ค์๋ค์ ํฌ์คํ
๊น์ง๋ CNN ๋ชจ๋ธ์ ๋ผ๋๊ฐ ๋๋ ๋ชจ๋ธ๋ค์ธ VGGNet, GoogleNet, ResNet์ ์๊ฐํ๊ณ ์ด๋ฅผ ๊ตฌํํด๋ณด๋ ์๊ฐ์ ๊ฐ๋๋ก ํ๊ฒ ์ต๋๋ค! :) ์ด๋ฒ ํฌ์คํ
์ VGGNet
๊ด๋ จ ํฌ์คํธ์
๋๋ค.
๋จผ์ ILSVRC (Imagenet Large Scale Visual Recognition Challenges)
์ด๋ผ๋ ๋ํ๊ฐ ์๋๋ฐ, ๋ณธ ๋ํ๋ ๊ฑฐ๋ ์ด๋ฏธ์ง๋ฅผ 1000๊ฐ์ ์๋ธ์ด๋ฏธ์ง๋ก ๋ถ๋ฅํ๋ ๊ฒ์ ๋ชฉ์ ์ผ๋ก ํฉ๋๋ค. ์๋ ๊ทธ๋ฆผ์ CNN๊ตฌ์กฐ์ ๋์คํ๋ฅผ ์ด๋์๋ ์ด์ฐฝ๊ธฐ ๋ชจ๋ธ๋ค๋ก AlexNet (2012) - VGGNet (2014) - GoogleNet (2014) - ResNet (2015) ์์ผ๋ก ๊ณ๋ณด๋ฅผ ์ด์ด๋๊ฐ์ต๋๋ค.
Source : https://icml.cc/2016/tutorials/
์์ ๊ทธ๋ฆผ์์ layers๋ CNN layer์ ๊ฐ์(๊น์ด)๋ฅผ ์๋ฏธํ๋ฉฐ ์ง๊ด์ ์ธ ์ดํด๋ฅผ ์ํด์ ์๋์ฒ๋ผ ๊ทธ๋ฆผ์ ๊ทธ๋ ค๋ณด์์ต๋๋ค.
VGGNet ๊ฐ์
์๊ฐ
VGGNet์ด ์๊ฐ๋ ๋ ผ๋ฌธ์ ์ ๋ชฉ์ Very deep convolutional networks for large-scale image recognition๋ก, ๋ค์ ๋งํฌ์์ ํ์ธํด๋ณด์ค ์ ์์ต๋๋ค. ๋งํฌ
VGGNet์ ์ ๊ฒฝ๋ง์ ๊น์ด๊ฐ ๋ชจ๋ธ์ ์ฑ๋ฅ์ ๋ฏธ์น๋ ์ํฅ์ ์กฐ์ฌํ๊ธฐ ์ํด ํด๋น ์ฐ๊ตฌ๋ฅผ ์์ํ์์ผ๋ฉฐ, ์ด๋ฅผ ์ฆ๋ช ํ๊ธฐ ์ํด 3x3 convolution์ ์ด์ฉํ Deep CNNs๋ฅผ ์ ์ํ์์ต๋๋ค. VGGNet์ ILSVRC-2014 ๋ํ์์ GoogLeNet์ ์ด์ด 2๋ฑ์ ์ฐจ์งํ์์ผ๋, GoogLeNet์ ๋นํด ํจ์ฌ ๊ฐ๋จํ ๊ตฌ์กฐ๋ก ์ธํด 1๋ฑ์ธ ๋ชจ๋ธ๋ณด๋ค ๋์ฑ ๋๋ฆฌ ์ฌ์ฉ๋์๋ค๋ ํน์ง์ ๊ฐ๊ณ ์์ต๋๋ค.
์คํ์ค๊ณ
๋ชจ๋ธ์ 3x3 convolution, Max-pooling, Fully Connected Network 3๊ฐ์ง ์ฐ์ฐ์ผ๋ก๋ง ๊ตฌ์ฑ์ด ๋์ด ์์ผ๋ฉฐ ์๋ ํ์ ๊ฐ์ด A, A-LRN, B, C, D, E 5๊ฐ์ง ๋ชจ๋ธ์ ๋ํด ์คํ์ ์งํํ์์ต๋๋ค.
์ด๋ ์ฌ์ฉํ ๊ฐ๊ฐ์ window_size์ activation function์ ์ค์ ์ ์๋์ ๊ฐ์ต๋๋ค.
- 3x3 convolution filters (stride: 1)
- 2x2 Max pooling (stride : 2)
- Activation function : ReLU
๐ข ์ฌ๊ธฐ์ ์ ๊น!
์ ํ์์ conv3-64๋ผ๊ณ ์จ์๋ ๊ฒ์ 3x3์ window_size๋ฅผ ๊ฐ๊ณ ์ฌ์ฉํ window์ ๊ฐ์๊ฐ 64๊ฐ์์ ์๋ฏธํฉ๋๋ค.
์ฑ๋ฅ
์๋ ์ฑ๋ฅํ๋ฅผ ํตํด ์ฐ๋ฆฌ๋ ๊น์ด๊ฐ ๊น์ด์ง์๋ก ๋ชจ๋ธ์ ์ฑ๋ฅ์ด ์ข์์ง๋ ๊ฒ๊ณผ Local Response Normalization(LRN)์ ์ฑ๋ฅ์ ํฐ ์ํฅ์ ์ฃผ์ง ์๋๋ค๋ ์ฌ์ค์ ๋ฐ๊ฒฌํ ์ ์์ต๋๋ค.
VGGNet ๊ตฌํ
๊ทธ๋ผ VGGNet์ ๊ฐ์๋ฅผ ์ดํด๋ดค์ผ๋ ์ด๋ฒ์๋ ์ด๋ฅผ ๊ตฌํํด๋ณผ๊น์? ๊ตฌํ์ ์ ์คํ ์ค๊ณ ํ์ D์ด์ ์ ํ ์ ๊ตฌํํด๋ณด์์ต๋๋ค. ๋ค์ ํ๋ฒ ์ค๊ธ๋ก ํด๋น ๊ตฌ์กฐ๋ฅผ ์ค๋ช ํ์๋ฉด ์๋์ ๊ฐ์ต๋๋ค.
- 3x3 ํฉ์ฑ๊ณฑ ์ฐ์ฐ x2 (์ฑ๋ 64)
- 3x3 ํฉ์ฑ๊ณฑ ์ฐ์ฐ x2 (์ฑ๋ 128)
- 3x3 ํฉ์ฑ๊ณฑ ์ฐ์ฐ x3 (์ฑ๋ 256)
- 3x3 ํฉ์ฑ๊ณฑ ์ฐ์ฐ x3 (์ฑ๋ 512)
- 3x3 ํฉ์ฑ๊ณฑ ์ฐ์ฐ x3 (์ฑ๋ 512)
-
FC layer x3
- FC layer 4096
- FC layer 4096
- FC layer 1000
์ฝ๋ฉ์ ํธ์๋ฅผ ์ํด ๊ฐ๊ฐ conv layer๊ฐ 2๊ฐ ์๋ block๊ณผ 3๊ฐ ์๋ block์ ๋ฐ๋ก ์ ์ธํ๋๋ก ํ๊ฒ ์ต๋๋ค.
conv_2_block
1
2
3
4
5
6
7
8
9
def conv_2_block(in_dim,out_dim):
model = nn.Sequential(
nn.Conv2d(in_dim,out_dim,kernel_size=3,padding=1),
nn.ReLU(),
nn.Conv2d(out_dim,out_dim,kernel_size=3,padding=1),
nn.ReLU(),
nn.MaxPool2d(2,2)
)
return model
conv_3_block
1
2
3
4
5
6
7
8
9
10
11
def conv_3_block(in_dim,out_dim):
model = nn.Sequential(
nn.Conv2d(in_dim,out_dim,kernel_size=3,padding=1),
nn.ReLU(),
nn.Conv2d(out_dim,out_dim,kernel_size=3,padding=1),
nn.ReLU(),
nn.Conv2d(out_dim,out_dim,kernel_size=3,padding=1),
nn.ReLU(),
nn.MaxPool2d(2,2)
)
return model
Define VGG16
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
class VGG(nn.Module):
def __init__(self, base_dim, num_classes=10):
super(VGG, self).__init__()
self.feature = nn.Sequential(
conv_2_block(3,base_dim), #64
conv_2_block(base_dim,2*base_dim), #128
conv_3_block(2*base_dim,4*base_dim), #256
conv_3_block(4*base_dim,8*base_dim), #512
conv_3_block(8*base_dim,8*base_dim), #512
)
self.fc_layer = nn.Sequential(
# CIFAR10์ ํฌ๊ธฐ๊ฐ 32x32์ด๋ฏ๋ก
nn.Linear(8*base_dim*1*1, 4096),
# IMAGENET์ด๋ฉด 224x224์ด๋ฏ๋ก
# nn.Linear(8*base_dim*7*7, 4096),
nn.ReLU(True),
nn.Dropout(),
nn.Linear(4096, 1000),
nn.ReLU(True),
nn.Dropout(),
nn.Linear(1000, num_classes),
)
def forward(self, x):
x = self.feature(x)
#print(x.shape)
x = x.view(x.size(0), -1)
#print(x.shape)
x = self.fc_layer(x)
return x
model, loss, optimizer ์ ์ธ
1
2
3
4
5
6
7
8
9
# device ์ค์
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# VGG ํด๋์ค๋ฅผ ์ธ์คํด์คํ
model = VGG(base_dim=64).to(device)
# ์์คํจ์ ๋ฐ ์ต์ ํํจ์ ์ค์
loss_func = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
load CIFAR10 dataset
- CIFAR10์ โ๋นํ๊ธฐ(airplane)โ, โ์๋์ฐจ(automobile)โ, โ์(bird)โ, โ๊ณ ์์ด(cat)โ, โ์ฌ์ด(deer)โ, โ๊ฐ(dog)โ, โ๊ฐ๊ตฌ๋ฆฌ(frog)โ, โ๋ง(horse)โ, โ๋ฐฐ(ship)โ, โํธ๋ญ(truck)โ๋ก 10๊ฐ์ ํด๋์ค๋ก ๊ตฌ์ฑ๋์ด ์๋ ๋ฐ์ดํฐ์ ์ ๋๋ค.
- CIFAR10์ ํฌํจ๋ ์ด๋ฏธ์ง์ ํฌ๊ธฐ๋
3x32x32
๋ก, ์ด๋32x32
ํฝ์ ํฌ๊ธฐ์ ์ด๋ฏธ์ง๊ฐ 3๊ฐ ์ฑ๋(channel)์ ์์๋ก ์ด๋ค์ ธ ์๋ค๋ ๊ฒ์ ๋ปํฉ๋๋ค.
TRAIN/TEST ๋ฐ์ดํฐ์ ์ ์
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
# Transform ์ ์
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
# CIFAR10 TRAIN ๋ฐ์ดํฐ ์ ์
cifar10_train = datasets.CIFAR10(root="../Data/", train=True, transform=transform, target_transform=None, download=True)
# CIFAR10 TEST ๋ฐ์ดํฐ ์ ์
cifar10_test = datasets.CIFAR10(root="../Data/", train=False, transform=transform, target_transform=None, download=True)
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
TRAIN ๋ฐ์ดํฐ์ ์๊ฐํ
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import matplotlib.pyplot as plt
import numpy as np
# ์ด๋ฏธ์ง๋ฅผ ๋ณด์ฌ์ฃผ๊ธฐ ์ํ ํจ์
def imshow(img):
img = img / 2 + 0.5 # unnormalize
npimg = img.numpy()
plt.imshow(np.transpose(npimg, (1, 2, 0)))
plt.show()
# ํ์ต์ฉ ์ด๋ฏธ์ง๋ฅผ ๋ฌด์์๋ก ๊ฐ์ ธ์ค๊ธฐ
dataiter = iter(train_loader)
images, labels = dataiter.next()
# ์ด๋ฏธ์ง ๋ณด์ฌ์ฃผ๊ธฐ
imshow(torchvision.utils.make_grid(images))
# ์ ๋ต(label) ์ถ๋ ฅ
print(' '.join('%5s' % classes[labels[j]] for j in range(batch_size)))
Source : https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html
TRAIN & TEST
์ด์ ๋ฐ์ดํฐ์ ๋ ์ ์ํด์คฌ์ผ๋ ๋ณธ๊ฒฉ์ ์ผ๋ก ํ์ต ๋ฐ ๊ฒ์ฆ์ ์ํํด ๋ณด๋๋ก ํ๊ฒ ์ต๋๋ค. ํ์ต ์ค์ ์ ๋ค์๊ณผ ๊ฐ์ด ์ ์ํด์ฃผ์์ต๋๋ค.
1
2
3
batch_size = 100
learning_rate = 0.0002
num_epoch = 100
TRAIN
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
loss_arr = []
for i in trange(num_epoch):
for j,[image,label] in enumerate(train_loader):
x = image.to(device)
y_= label.to(device)
optimizer.zero_grad()
output = model.forward(x)
loss = loss_func(output,y_)
loss.backward()
optimizer.step()
if i % 10 ==0:
print(loss)
loss_arr.append(loss.cpu().detach().numpy())
loss ์๊ฐํ
1
2
plt.plot(loss_arr)
plt.show()
test ๊ฒฐ๊ณผ
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
# ๋ง์ ๊ฐ์, ์ ์ฒด ๊ฐ์๋ฅผ ์ ์ฅํ ๋ณ์๋ฅผ ์ง์ ํฉ๋๋ค.
correct = 0
total = 0
model.eval()
# ์ธํผ๋ฐ์ค ๋ชจ๋๋ฅผ ์ํด no_grad ํด์ค๋๋ค.
with torch.no_grad():
# ํ
์คํธ๋ก๋์์ ์ด๋ฏธ์ง์ ์ ๋ต์ ๋ถ๋ฌ์ต๋๋ค.
for image,label in test_loader:
# ๋ ๋ฐ์ดํฐ ๋ชจ๋ ์ฅ์น์ ์ฌ๋ฆฝ๋๋ค.
x = image.to(device)
y= label.to(device)
# ๋ชจ๋ธ์ ๋ฐ์ดํฐ๋ฅผ ๋ฃ๊ณ ๊ฒฐ๊ณผ๊ฐ์ ์ป์ต๋๋ค.
output = model.forward(x)
_,output_index = torch.max(output,1)
# ์ ์ฒด ๊ฐ์ += ๋ผ๋ฒจ์ ๊ฐ์
total += label.size(0)
correct += (output_index == y).sum().float()
# ์ ํ๋ ๋์ถ
print("Accuracy of Test Data: {}%".format(100*correct/total))
Accuracy of Test Data: 82.33999633789062%
๊ธด ๊ธ ์ฝ์ด์ฃผ์ ์ ๊ฐ์ฌํฉ๋๋ค ^~^