[ํ์ดํ ์น] ํ์ดํ ์น๋ก CNN ๋ชจ๋ธ์ ๊ตฌํํด๋ณด์! (ResNetํธ)
์๋ณธ ๊ฒ์๊ธ: https://velog.io/@euisuk-chung/ํ์ดํ ์น-ํ์ดํ ์น๋ก-CNN-๋ชจ๋ธ์-๊ตฌํํด๋ณด์-ResNetํธ
์๋
ํ์ธ์! ์ง๋๋ฒ ํฌ์คํธ์ธ VGGNet๊ณผ GoogleNet ์ดํ๋ก ์ค๋์ ResNet
๊ด๋ จ ํฌ์คํธ์
๋๋ค.
2๋ฒ์ ๊ฑธ์น ํฌ์คํ
์์ ์๊ฐ๋๋ ธ๋ค์ํผ ์ปดํจํฐ ๋น์ ๋ํ ์ค์ ILSVRC (Imagenet Large Scale Visual Recognition Challenges)
์ด๋ผ๋ ๋ํ๊ฐ ์๋๋ฐ, ๋ณธ ๋ํ๋ ๊ฑฐ๋ ์ด๋ฏธ์ง๋ฅผ 1000๊ฐ์ ์๋ธ์ด๋ฏธ์ง๋ก ๋ถ๋ฅํ๋ ๊ฒ์ ๋ชฉ์ ์ผ๋ก ํฉ๋๋ค. ์๋ ๊ทธ๋ฆผ์ CNN๊ตฌ์กฐ์ ๋์คํ๋ฅผ ์ด๋์๋ ์ด์ฐฝ๊ธฐ ๋ชจ๋ธ๋ค๋ก AlexNet (2012) - VGGNet (2014) - GoogleNet (2014) - ResNet (2015) ์์ผ๋ก ๊ณ๋ณด๋ฅผ ์ด์ด๋๊ฐ์ต๋๋ค.
Source : https://icml.cc/2016/tutorials/
์์ ๊ทธ๋ฆผ์์ layers๋ CNN layer์ ๊ฐ์(๊น์ด)๋ฅผ ์๋ฏธํ๋ฉฐ ์ง๊ด์ ์ธ ์ดํด๋ฅผ ์ํด์ ์๋์ฒ๋ผ ๊ทธ๋ฆผ์ ๊ทธ๋ ค๋ณด์์ต๋๋ค.
ResNet ๊ฐ์
์๊ฐ
ResNet์ด ์๊ฐ๋ ๋ ผ๋ฌธ์ ์ ๋ชฉ์ Going Deeper with Convolutions๋ก, ๋ค์ ๋งํฌ์์ ํ์ธํด๋ณด์ค ์ ์์ต๋๋ค. (๋งํฌ)
ResNet์ ์ ์๋ค์ ์ผ์ ์์ค ์ด์์ ๊น์ด๊ฐ ๋๋ฉด ์คํ๋ ค ์์ ๋ชจ๋ธ๋ณด๋ค ๊น์ ๋ชจ๋ธ์ ์ฑ๋ฅ์ด ๋ ๋จ์ด์ง๋ค๋ ๊ฒ์ ์๋ ๊ทธ๋ฆผ๊ณผ ๊ฐ์ด ํ์ธํ ์ ์์์ต๋๋ค.
Plane network 20-layer์ 56-layer์ train error์ test error (๋ ผ๋ฌธ ๋ฐ์ท)
์ด๋ฌํ ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ๊ธฐ ์ํด ์์ฐจ ํ์ต(residual learning)์ด๋ผ๋ ๋ฐฉ๋ฒ์ ํตํด ๋ชจ๋ธ ์ฑ๋ฅ์ ํฅ์์ํจ ๊ฒ์ด ๋ฐ๋ก ResNet
์
๋๋ค. ์์ด๋์ด๋ ์ ๋ง ์ฌํํ๋ฐ์. ํน์ ์์น์์ ์
๋ ฅ์ด ๋ค์ด์์ ๋ ํฉ์ฑ๊ณฑ ์ฐ์ฐ์ ํต๊ณผํ ๊ฒฐ๊ณผ์ ์
๋ ฅ์ผ๋ก ๋ค์ด์จ ๊ฒฐ๊ณผ ๋๊ฐ์ง๋ฅผ ๋ํด์ ๋ค์ ๋ ์ด์ด์ ์ ๋ฌํ๋ ๊ฒ์ด ResNet์ ํต์ฌ์
๋๋ค. (์๋ ๊ทธ๋ฆผ ์ฐธ๊ณ )
Residual Learning (๋ ผ๋ฌธ ๋ฐ์ท)
์ ๊ทธ๋ฆผ์์ ๋ณผ ์ ์๋ค์ํผ ์์ฐจ ํ์ต์ ์ด์ ๋จ๊ณ์์ ๋ฝ์๋ ํน์ฑ๋ค์ ๋ณํ์ํค์ง ์๊ณ , ๊ทธ๋๋ก ๋ค์ ๋จ๊ณ๋ก ์ ๋ฌํ์ฌ ๋ํด์ฃผ๊ธฐ ๋๋ฌธ์ ์์์ ํ์ตํ low-level ํน์ง๊ณผ ๋ค์์ ํ์ตํ high-level ํน์ง์ ๋ชจ๋ ๋ค์ block(๋จ๊ณ)๋ก ์ ๋ฌํ ์ ์๋ค๋ ์ฅ์ ์ ๊ฐ์ง๊ณ ์์ต๋๋ค. ์ด์ GoogleNet์ ๊ฒฝ์ฐ, Neural Network์ Vanishing Gradient ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ๊ธฐ ์ํด Auxilary Classifier๋ฅผ ์ฌ์ฉํ์์ต๋๋ค. ํ์ง๋ง, ResNet์ ๊ฒฝ์ฐ ๋ํ๊ธฐ ์ฐ์ฐ์ ๊ธฐ์ธ๊ธฐ๊ฐ 1์ด๊ธฐ ๋๋ฌธ์ ์ญ์ ํ ์ loss๊ฐ ์ค์ง ์๊ณ , ๋ชจ๋ธ ์๊น์ง ์ ์ ํ๊ฐ ๋๋ค๋ ํน์ง์ด ์์ด์ GoogleNet๊ณผ๋ ๋ค๋ฅด๊ฒ Auxilary Classifier๊ฐ ๋ณ๋๋ก ํ์ํ์ง ์์ต๋๋ค.
๋ชจ๋ธ ๊ตฌ์กฐ
Overall Network
๋ ผ๋ฌธ์์๋ VGG-19, 34-layer Plain (without residual) ๋ชจ๋ธ๊ณผ 34-layer Residual ๋ชจ๋ธ์ ๋ค์๊ณผ ๊ฐ์ด ์๊ฐํํ๊ณ ์์ต๋๋ค.
VGG-19, 34-layer Plain & Residual (๋ ผ๋ฌธ ๋ฐ์ท)
์ ๊ทธ๋ฆผ์์ ์ค์ ์ featuremap์ dimension์ด ๋ฐ๋์ง ์์ ๊ทธ๋ฅ ๋ํด์ฃผ๋ ๊ฒฝ์ฐ์ด๊ณ , ์ ์ ์ ์ ๋ ฅ๋จ๊ณผ ์ถ๋ ฅ๋จ์ dimension์ ์ฐจ์ด๋ก ์ธํด ์ด๋ฅผ ๋ง์ถฐ์ค ์ ์๋ ํ ํฌ๋์ด ์ถ๊ฐ์ ์ผ๋ก ๋ํด์ง shortcut connection์ ๋๋ค.
์๋ํ๋ ๋ ผ๋ฌธ์์ ์ ์ํ๋ ๋ค์ํ ์ ํ์ ResNet๊ตฌ์กฐ๋ค์ ๋๋ค. ์ ๊ทธ๋ฆผ์ ์์๋ ์๋ ๊ทธ๋ฆผ์์ 34-layer ๋ชจ๋ธ๊ณผ ๋์ผํฉ๋๋ค.
ResNet 19, 34, 50, 101, 152 layer
Plain Network
Plain Network์ ๋ค์๊ณผ ๊ฐ์ ๊ท์น์ ๋ฐ๋ผ ๋ง๋ค์ด์ก์ต๋๋ค:
- ๊ฐ์ ํฌ๊ธฐ์ output feature map์ ๊ฐ๊ณ ์๋ค๋ฉด, ๊ฐ์ ์์ filters๋ฅผ ๊ฐ๋๋ก ํฉ๋๋ค.
- ๋ง์ฝ feature map size๊ฐ ๋ฐ์ผ๋ก ์ค์ด๋ค์๋ค๋ฉด, time-complexity๋ฅผ ์ ์งํ๊ธฐ ์ํด filters์ ์๋ ๋ ๋ฐฐ๊ฐ ๋๋๋ก ํฉ๋๋ค.
- Downsampling์ ํ๊ธฐ ์ํด์ stride๊ฐ 2์ธ conv layers๋ฅผ ํต๊ณผ์์ผ์ค๋๋ค.
- 1x1 convolution์ ๊ฒฝ์ฐ, ๋์ผํ ์ฌ์ด์ฆ์ feature map์ ์ ์งํ๊ธฐ ์ํด ๋ณ๋์ padding์ด ํ์์์ต๋๋ค.
- ํ์ง๋ง, 3x3 convolution์ ๊ฒฝ์ฐ, ๋์ผํ ์ฌ์ด์ฆ์ feature map์ ์ ์งํ๊ธฐ ์ํด size 1์ padding์ด ํ์ํ๊ฒ ๋ฉ๋๋ค.
- Network์ ๋ง์ง๋ง ๋จ์๋ Global Average Pooling(GAP)๋ฅผ ์ํํ๋ฉฐ, ImageNet Classification์ ๋ชฉ์ ์ผ๋ก ํ๊ธฐ ๋๋ฌธ์ 1000-way-fully-connected layer๋ก ์ด๋ฃจ์ด์ ธ ์์ต๋๋ค.
Residual Network
Residual Network
Residual Network(ResNet)์ ๊ธฐ๋ณธ์ ์ธ ์กฐ๊ฑด์ ์์ plain network์ ๋์ผํฉ๋๋ค. ํ๊ฐ์ง ๋ค๋ฅธ ์ ์ ๊ฐ๊ฐ์ block๋ค์ด ๋๋ ๋๋ง๋ค shortcut connection ์ถ๊ฐ๋๋ค๋ ์ ์ ๋๋ค.
- input๊ณผ output์ ์ฐจ์์ด ๊ฐ๋ค๋ฉด identity shortcut์ ๋ฐ๋ก ์ฌ์ฉ๋ ์ ์์ต๋๋ค. (1)
- ํ์ง๋ง, ์ฐจ์์ด ๋ค๋ฅด๋ค๋ฉด identity shortcut์ ๋ฐ๋ก ์ฌ์ฉํ ์ ์์ต๋๋ค. Identity Shortcut ๋์ Projection Shortcut์ด ์ฌ์ฉ๋๊ฒ ๋ฉ๋๋ค. (By using 1x1 Convolution) (2)
Shortcuts Comparison
- ํด๋น ๋ ผ๋ฌธ์์๋ shortcut์ ์ฌ์ฉ๋ฐฉ๋ฒ์ ๋ฐ๋ฅธ ์ฑ๋ฅ์ ์๋์ ๊ฐ์ด ๋น๊ตํฉ๋๋ค.
- (A) Increasing Dimension์ Zero Padding์ ํ์ฉํ Shortcut์ ์ฌ์ฉ
- (B) Increasing Dimension์ Projection Shortcut์ ์ฌ์ฉ
- (C) ๋ชจ๋ Shortcut์ Projection Shortcut์ผ๋ก ๋์ฒดํ์ฌ ์ฌ์ฉ
Table 3์ ๋ณด๋ฉด 3๊ฐ์ง ์ต์ ๋ชจ๋ Plain Network๋ณด๋ค ์ฑ๋ฅ์ด ์ข์ผ๋ฉฐ, A < B < C์์ผ๋ก ์ฑ๋ฅ์ด ์ข์ ๊ฒ์ ํ์ธํ ์ ์์ต๋๋ค. ๋ ผ๋ฌธ์์๋ B๊ฐ A๋ณด๋ค ๋์ ์ด์ ๋ฅผ A์ zero-padding๊ณผ์ ์ residual learning์ด ์๊ธฐ ๋๋ฌธ์ด๋ผ๊ณ ๋งํฉ๋๋ค. ๊ทธ๋ฆฌ๊ณ , C๊ฐ B๋ณด๋ค ์ข์ ์ด์ ๋ก๋ extra parameters๊ฐ ๋ ๋ง๊ธฐ ๋๋ฌธ์ ์ด๋ ์ฑ๋ฅ ํฅ์์ผ๋ก ์ด์ด์ก๋ค๊ณ ์ด์ผ๊ธฐ ํฉ๋๋ค.
A, B, C์์์ ์์ ์ฐจ์ด๋ฅผ ํตํด ์ ์ ์๋ ๊ฒ์ Projection Shortcut์ ๋ณธ ๋ ผ๋ฌธ์์ ๋ฌธ์ ์ผ๊ณ ์๋ degradation ๋ฌธ์ ๋ฅผ address ํ๋ ๊ฒ์ ๋ณธ์ง์ด ์๋๋ผ๋ ๊ฒ์ ๋ณด์ฌ์ค๋๋ค. ๋ํ, extra parameter๊ฐ ์ถ๊ฐ๋๋ C๋ memory & time complexity ๋ฅผ ์ค์ด๊ธฐ ์ํด ์ฌ์ฉ๋์ง ์์์ต๋๋ค.
Deeper Bottleneck Architecture
๋ณธ ๋ ผ๋ฌธ์์ ์ ์๋ค์ Layer ๊ฐ ๊น์ด์ง๋ฉด training time ์ด ์ฆ๊ฐํ๋ ๊ฒ์ ๋ฐ๊ฒฌํ์๊ณ , ์ด๋ฅผ ๊ณ ๋ คํ์ฌ Residual Block์ ์๋์ ๊ฐ์ด 1x1 Convolution์ ํ์ฉํ์ฌ ๊ฐ์ ํ Bottleneck Block์ ์ ์ํ์์ต๋๋ค.
Bottleneck Block์ 1x1, 3x3, 1x1 convolution์ผ๋ก ๊ตฌ์ฑ๋ 3๊ฐ์ Layer๋ฅผ ์์ ๊ตฌ์กฐ๋ก, Basic Block ๋ณด๋ค Layer ์๊ฐ 1๊ฐ ๋ ๋ง์ง๋ง, time complexity๋ ๋น์ทํ๋ค๋ ํน์ง์ ๊ฐ๊ณ ์์ต๋๋ค. ์ด๋ Bottleneck Block์๋ ์์์ ์๊ฐํ ์ต์ B๋ฅผ ์ ์ฉํ์์ต๋๋ค.
์ด๋ฐ ๋ฐฉ๋ฒ์ ์ ์ฉํ์ฌ ๊น์ ๋ชจ๋ธ(50-layer, 101-layer, 152-layer์ ์ ์ฉํด๋ณธ ๊ฒฐ๊ณผ, ๊ธฐ์กด์ degradation์ ๋ฌธ์ ๊ฐ ๋ฐ์ํ์ง ์๊ณ , ๊น์ด๊ฐ ๋ ๊น์ด์ง์ ๋ฐ๋ผ ๋ ์ข์์ง๋ ๊ฒ์ ํ์ธํ ์ ์์์ต๋๋ค.
์คํ
CIFAR 10
๋จผ์ CIFAR10 ๋ฐ์ดํฐ์ ๋ํ์ฌ ์คํํ ๊ฒฐ๊ณผ๋ ๋ค์๊ณผ ๊ฐ์ต๋๋ค. ์ ์ ์ Training Error๋ฅผ ์๋ฏธํ๊ณ , ์ค์ ์ Test Error๋ฅผ ์๋ฏธํฉ๋๋ค.
- Figure 6์ ์ข์ธก์ ์๋ ๊ทธ๋ํ๋ residual ์ฐ์ฐ์ ์ฌ์ฉํ์ง ์์ plain network๋ฅผ ์ฌ์ฉํ์ ๋์ Error์ ๋๋ค. ์ด๋ฅผ ์ดํด๋ณด๋ฉด layer๊ฐ ๊น์ ์๋ก Error๊ฐ ๋์ ๊ฒ์ ํ์ธํ ์ ์์ต๋๋ค. (Degradation ๋ฌธ์ )
- Figure 6์ ์ค์์ ์๋ ๊ทธ๋ํ๋ residual ์ฐ์ฐ์ ์ฌ์ฉํ residual network๋ฅผ ์ฌ์ฉํ์ ๋์ Error์ ๋๋ค. ์ด๋ฅผ ์ดํด๋ณด๋ฉด layer๊ฐ ๊น์ ์๋ก Error๊ฐ ๋ฎ์ ๊ฒ์ ํ์ธํ ์ ์์ต๋๋ค.
- Figure 6์ ์ฐ์ธก์ ์๋ ๊ทธ๋ํ๋ 1202-layer residual network์ 110-layer residual network๋ก, ์ ์ฌํ training error ๋ณด์์ง๋ง test ์ฑ๋ฅ์ ๋ ์ข์ง ์์ ๊ฒ์ผ๋ก ๋ณด์ Overfitting์ด ๋ฐ์ํ ๊ฒ์ ํ์ธํ ์ ์์์ต๋๋ค.
PASCAL VOC & MS COCO
๊ฐ๊ฐ PASCAL VOC 2007/2012 ๋ฐ์ดํฐ์ MS COCO ๋ฐ์ดํฐ๋ฅผ ์ฌ์ฉํ Object Detection์ ์์ด์๋ VGGNet์ ์ฌ์ฉํ ๊ฒ๋ณด๋ค ResNet์ ์ฌ์ฉํ ๊ฒ์ด ๋ ์ข์ ์ฑ๋ฅ์ด ๋์ค๋ ๊ฒ์ ํ์ธํ ์ ์์ต๋๋ค.
PASCAL VOC 2007/2012
MS COCO
์ฝ๋
์ด๋ฒ ํฌ์คํธ์์๋ ResNet-50์ ๊ตฌํํด๋ณด๋ ์๊ฐ์ ๊ฐ๊ฒ ์ต๋๋ค.
๋ผ์ด๋ธ๋ฌ๋ฆฌ
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.init as init
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
import tqdm
from tqdm.auto import trange
ํ์ดํผํ๋ผ๋ฏธํฐ
1
2
3
batch_size = 50
learning_rate = 0.0002
num_epoch = 100
Load CIFAR-10
1
2
3
4
5
6
7
8
9
10
11
12
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
# define dataset
cifar10_train = datasets.CIFAR10(root="../Data/", train=True, transform=transform, target_transform=None, download=True)
cifar10_test = datasets.CIFAR10(root="../Data/", train=False, transform=transform, target_transform=None, download=True)
# define loader
train_loader = DataLoader(cifar10_train,batch_size=batch_size, shuffle=True, num_workers=2, drop_last=True)
test_loader = DataLoader(cifar10_test,batch_size=batch_size, shuffle=False, num_workers=2, drop_last=True)
# define classes
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
Basic Module
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
def conv_block_1(in_dim,out_dim, activation,stride=1):
model = nn.Sequential(
nn.Conv2d(in_dim,out_dim, kernel_size=1, stride=stride),
nn.BatchNorm2d(out_dim),
activation,
)
return model
def conv_block_3(in_dim,out_dim, activation, stride=1):
model = nn.Sequential(
nn.Conv2d(in_dim,out_dim, kernel_size=3, stride=stride, padding=1),
nn.BatchNorm2d(out_dim),
activation,
)
return model
Bottleneck Module
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
class BottleNeck(nn.Module):
def __init__(self,in_dim,mid_dim,out_dim,activation,down=False):
super(BottleNeck,self).__init__()
self.down=down
# ํน์ฑ์ง๋์ ํฌ๊ธฐ๊ฐ ๊ฐ์ํ๋ ๊ฒฝ์ฐ
if self.down:
self.layer = nn.Sequential(
conv_block_1(in_dim,mid_dim,activation,stride=2),
conv_block_3(mid_dim,mid_dim,activation,stride=1),
conv_block_1(mid_dim,out_dim,activation,stride=1),
)
# ํน์ฑ์ง๋ ํฌ๊ธฐ + ์ฑ๋์ ๋ง์ถฐ์ฃผ๋ ๋ถ๋ถ
self.downsample = nn.Conv2d(in_dim,out_dim,kernel_size=1,stride=2)
# ํน์ฑ์ง๋์ ํฌ๊ธฐ๊ฐ ๊ทธ๋๋ก์ธ ๊ฒฝ์ฐ
else:
self.layer = nn.Sequential(
conv_block_1(in_dim,mid_dim,activation,stride=1),
conv_block_3(mid_dim,mid_dim,activation,stride=1),
conv_block_1(mid_dim,out_dim,activation,stride=1),
)
# ์ฑ๋์ ๋ง์ถฐ์ฃผ๋ ๋ถ๋ถ
self.dim_equalizer = nn.Conv2d(in_dim,out_dim,kernel_size=1)
def forward(self,x):
if self.down:
downsample = self.downsample(x)
out = self.layer(x)
out = out + downsample
else:
out = self.layer(x)
if x.size() is not out.size():
x = self.dim_equalizer(x)
out = out + x
return out
Define ResNet-50
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
# 50-layer
class ResNet(nn.Module):
def __init__(self, base_dim, num_classes=10):
super(ResNet, self).__init__()
self.activation = nn.ReLU()
self.layer_1 = nn.Sequential(
nn.Conv2d(3,base_dim,7,2,3),
nn.ReLU(),
nn.MaxPool2d(3,2,1),
)
self.layer_2 = nn.Sequential(
BottleNeck(base_dim,base_dim,base_dim*4,self.activation),
BottleNeck(base_dim*4,base_dim,base_dim*4,self.activation),
BottleNeck(base_dim*4,base_dim,base_dim*4,self.activation,down=True),
)
self.layer_3 = nn.Sequential(
BottleNeck(base_dim*4,base_dim*2,base_dim*8,self.activation),
BottleNeck(base_dim*8,base_dim*2,base_dim*8,self.activation),
BottleNeck(base_dim*8,base_dim*2,base_dim*8,self.activation),
BottleNeck(base_dim*8,base_dim*2,base_dim*8,self.activation,down=True),
)
self.layer_4 = nn.Sequential(
BottleNeck(base_dim*8,base_dim*4,base_dim*16,self.activation),
BottleNeck(base_dim*16,base_dim*4,base_dim*16,self.activation),
BottleNeck(base_dim*16,base_dim*4,base_dim*16,self.activation),
BottleNeck(base_dim*16,base_dim*4,base_dim*16,self.activation),
BottleNeck(base_dim*16,base_dim*4,base_dim*16,self.activation),
BottleNeck(base_dim*16,base_dim*4,base_dim*16,self.activation,down=True),
)
self.layer_5 = nn.Sequential(
BottleNeck(base_dim*16,base_dim*8,base_dim*32,self.activation),
BottleNeck(base_dim*32,base_dim*8,base_dim*32,self.activation),
BottleNeck(base_dim*32,base_dim*8,base_dim*32,self.activation),
)
self.avgpool = nn.AvgPool2d(1,1)
self.fc_layer = nn.Linear(base_dim*32,num_classes)
def forward(self, x):
out = self.layer_1(x)
out = self.layer_2(out)
out = self.layer_3(out)
out = self.layer_4(out)
out = self.layer_5(out)
out = self.avgpool(out)
out = out.view(batch_size,-1)
out = self.fc_layer(out)
return out
Train
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
device = torch.device("cuda:0")
model = ResNet(base_dim=64).to(device)
loss_func = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(),lr=learning_rate)
loss_arr = []
for i in trange(num_epoch):
for j,[image,label] in enumerate(train_loader):
x = image.to(device)
y_= label.to(device)
optimizer.zero_grad()
output = model.forward(x)
loss = loss_func(output,y_)
loss.backward()
optimizer.step()
if i % 10 ==0:
print(loss)
loss_arr.append(loss.cpu().detach().numpy())
์ฑ๋ฅ (epoch = 100)
Train Loss
Test Accuracy
Accuracy of Test Data: 74.33999633789062%