[ํ์ดํ ์น] ํ์ดํ ์น ๊ธฐ์ด ์์ (Autograd๋)
์๋ณธ ๊ฒ์๊ธ: https://velog.io/@euisuk-chung/ํ์ดํ ์น-ํ์ดํ ์น-๊ธฐ์ด-์์-Autograd๋
์์ ํ์ ์ญ์ ํ
์ ๊ฒฝ๋ง(Neural Network)์ ์ด๋ค ์ ๋ ฅ ๋ฐ์ดํฐ์ ๋ํด ์คํ๋๋ ์ค์ฒฉ๋ ํจ์๋ค์ ์งํฉ์ฒด์ ๋๋ค. ์ ๊ฒฝ๋ง์ ์๋ 2๋จ๊ณ๋ฅผ ๊ฑฐ์ณ ํ์ต๋ฉ๋๋ค :
- ์์ ํ(Forward Propagation)
- ์ญ์ ํ(Backward Propagation)
Forward Propagation (์์ ํ)
Forward Propagation(์์ ํ) ๋จ๊ณ์์, ์ ๊ฒฝ๋ง์ ์ ๋ต์ ๋ง์ถ๊ธฐ ์ํด ์ต์ ์ ์ถ์ธก(best guess)์ ํฉ๋๋ค. ์ด๋ ๊ฒ ์ถ์ธก์ ํ๊ธฐ ์ํด์ ์ ๋ ฅ ๋ฐ์ดํฐ๋ฅผ ๊ฐ ํจ์๋ค์์ ์คํํฉ๋๋ค.
Back Propagation (์ญ์ ํ)
Back Propagation (์ญ์ ํ) ๋จ๊ณ์์, ์ ๊ฒฝ๋ง์ ์ถ์ธกํ ๊ฐ์์ ๋ฐ์ํ error์ ๋น๋กํ์ฌ ํ๋ผ๋ฏธํฐ๋ค์ ์ ์ ํ ์ ๋ฐ์ดํธํฉ๋๋ค. ์ถ๋ ฅ(output)๋ก๋ถํฐ ์ญ๋ฐฉํฅ์ผ๋ก ์ด๋ํ๋ฉด์ ์ค๋ฅ์ ๋ํ ํจ์๋ค์ ๋งค๊ฐ๋ณ์๋ค์ ๋ฏธ๋ถ๊ฐ(gradient)์ ์์งํ๊ณ , ๊ฒฝ์ฌํ๊ฐ๋ฒ(gradient descent)์ ์ฌ์ฉํ์ฌ ๋งค๊ฐ๋ณ์๋ค์ ์ต์ ํ ํฉ๋๋ค.
๋ด๋ด๋คํธ์ํฌ ํ์ต ์๊ณ ๋ฆฌ์ฆ
-
๋ชจ๋ ๊ฐ์ค์น w๋ฅผ ์์๋ก ์์ฑ
[Forward Propagation]
-
์ ๋ ฅ๋ณ์ ๊ฐ๊ณผ ์ ๋ ฅ์ธต๊ณผ ์๋์ธต ์ฌ์ด์ w๊ฐ์ ์ด์ฉํ์ฌ ์๋๋ ธ๋์ ๊ฐ์ ๊ณ์ฐ
(์ ํ๊ฒฐํฉ ํ activationํ ๊ฐ)
-
์๋๋ ธ๋์ ๊ฐ๊ณผ ์๋์ธต๊ณผ ์ถ๋ ฅ์ธต ์ฌ์ด์ w๊ฐ์ ์ด์ฉํ์ฌ ์ถ๋ ฅ๋ ธ๋์ ๊ฐ์ ๊ณ์ฐ
(์ ํ๊ฒฐํฉ ํ activationํ ๊ฐ)
[Back Propagation]
- ๊ณ์ฐ๋ ์ถ๋ ฅ๋ ธ๋์ ๊ฐ๊ณผ ์ค์ ์ถ๋ ฅ๋ณ์์ ๊ฐ์ ์ฐจ์ด๋ฅผ ์ค์ผ ์ ์๋๋ก ์๋์ธต๊ณผ ์ถ๋ ฅ์ธต ์ฌ์ด์ w๊ฐ์ ์ ๋ฐ์ดํธ
- ๊ณ์ฐ๋ ์ถ๋ ฅ๋ ธ๋์ ๊ฐ๊ณผ ์ค์ ์ถ๋ ฅ๋ณ์์ ๊ฐ์ ์ฐจ์ด๋ฅผ ์ค์ผ ์ ์๋๋ก ์ ๋ ฅ์ธต๊ณผ ์๋์ธต ์ฌ์ด์ w๊ฐ์ ์ ๋ฐ์ดํธ
- ์๋ฌ๊ฐ ์ถฉ๋ถํ ์ค์ด๋ค ๋๊น์ง 2๋ฒ ~ 5๋ฒ์ ๋ฐ๋ณต
Autograd ๊ฐ๋
pyTorch๋ฅผ ์ด์ฉํด ์ฝ๋๋ฅผ ์์ฑํ ๋ ์ด๋ฌํ ์ญ์ ํ๋ฅผ ํตํด ํ๋ผ๋ฏธํฐ๋ฅผ ์ ๋ฐ์ดํธํ๋ ๋ฐฉ๋ฒ์ ๋ฐ๋ก Autograd ์ ๋๋ค. ์ฐจ๊ทผ์ฐจ๊ทผ ์ฝ๋๋ฅผ ํตํด ์์๋ณด๋๋ก ํฉ์๋ค. Autograd์ ๋ํด ์ดํด๋ณด๊ธฐ ์ํด ๊ฐ๋จํ MLP(Mulyi-Layer Perceptron)์ ์์๋ก ์ดํด๋ณผ๊น์?
import torch
๋จผ์ pyTorch๋ฅผ ์ฌ์ฉํ๊ธฐ ์ํด์ ๋ค์๊ณผ ๊ฐ์ด pyTorch๋ฅผ importํด์ค๋๋ค. ์ด๋, torch.cuda
์ is_available()
ํจ์๋ฅผ ํตํด ํ์ฌ ํ์ด์ฌ์ด ์คํ๋๊ณ ์๋ ํ๊ฒฝ์ด GPU๋ฅผ ์ด์ฉํด์ ๊ณ์ฐ์ ํ ์ ์๋๊ฐ๋ฅผ ํ์ธํ ์ ์์ต๋๋ค.
๐ป ์ฝ๋
1
2
3
4
5
import torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device
๐ป ๊ฒฐ๊ณผ
1
2
3
4
5
# GPU ์ฌ์ฉ์ด ๊ฐ๋ฅํ ๋
device(type='cuda')
# GPU ์ฌ์ฉ์ด ๋ถ๊ฐ๋ฅํ ๋
device(type='cuda')
BATCH_SIZE
BATCH_SIZE๋ ๋ฅ๋ฌ๋ ๋ชจ๋ธ์ด ํ๋ผ๋ฏธํฐ๋ฅผ ์ ๋ฐ์ดํธํ ๋ ๊ณ์ฐ๋๋ ๋ฐ์ดํฐ ๋ฌถ์์ ๊ฐ์์ ๋๋ค. ์์์ Neural Network์ด Forward Propagation(์์ ํ)์ Backward Propagation(์ญ์ ํ)๋ฅผ ์ํํ๋ฉด์ ํ๋ผ๋ฏธํฐ๋ฅผ ์ ๋ฐ์ดํธ๋ฅผ ํ๋ค๊ณ ์๊ฐ ๋๋ ธ๋๋ฐ, ์ด๋ฌํ ์ ๋ฐ์ดํธ๋ฅผ ์ํํ๋ ๋ฐ ์ฌ์ฉ๋๋ ๋ฐ์ดํฐ ๋จ์(๊ฐฏ์)๊ฐ ๋๋ ๊ฒ์ด ๋ฐ๋ก BATCH_SIZE์ ๋๋ค. ์๋ ์์์์ BATCH_SIZE๋ก 32๋ฅผ ์ง์ ํด์คฌ๋๋ฐ, ์ด๋ ์ฝ๋ ์์ฑ์ ๋ง์๋๋ก(?) ์ ํด์ฃผ๋ ํ์ดํผํ๋ผ๋ฏธํฐ์ ๋๋ค.
๐ป ์ฝ๋
1
2
# ํ์ดํผํ๋ผ๋ฏธํฐ ์ง์
BATCH_SIZE = 32
INPUT_SIZE, HIDDEN_SIZE, OUTPUT_SIZE, LEARNING_RATE
INPUT_SIZE
๋ ๋ฅ๋ฌ๋ ๋ชจ๋ธ์ ์ ๋ ฅ๊ฐ์ ํฌ๊ธฐ์ด๋ฉฐ, ์ ๋ ฅ์ธต์ ๋ ธ๋ ์๋ฅผ ์๋ฏธํฉ๋๋ค.HIDDEN_SIZE
๋ ์ ๋ ฅ๊ฐ์ ๋ค์์ ํ๋ผ๋ฏธํฐ๋ฅผ ์ฌ์ฉํ์ฌ ๊ณ์ฐ๋๋ ๊ฐ์ ๊ฐ์๋ก, ์๋ ์ธต์ ๋ ธ๋ ์๋ฅผ ์๋ฏธํฉ๋๋ค.OUTPUT_SIZE
๋ ์๋๊ฐ์ ๋ค์์ ํ๋ผ๋ฏธํฐ๋ฅผ ์ฌ์ฉํ์ฌ ๊ณ์ฐ๋๋ ๊ฒฐ๊ณผ๊ฐ์ ๊ฐ์๋ก, ์ถ๋ ฅ ์ธต์ ๋ ธ๋ ์๋ฅผ ์๋ฏธํฉ๋๋ค.LEARNING_RATE
์ Gradient๋ฅผ ์ ๋ฐ์ดํธํ ๋ ๊ณฑํด์ฃผ๋ 0๊ณผ 1์ฌ์ด์ ์กด์ฌํ๋ ๊ฐ์ ๋๋ค. ์ข ๋ ๋๋ฆฌ์ง๋ง ์ฌ์ธํ๊ณ ์ด์ดํ ์ ๋ฐ์ดํธ๋ฅผ ์ํ๋ฉด ์์ rate์, ์ข ๋ ๋น ๋ฅด๊ฒ ์ ๋ฐ์ดํธ๋ฅผ ์ํ๋ฉด ํฐ rate๋ฅผ ์ค ์ ์์ต๋๋ค.
๐ป ์ฝ๋
1
2
3
4
5
# ํ์ดํผํ๋ผ๋ฏธํฐ ์ง์
INPUT_SIZE = 1000
HIDDEN_SIZE = 100
OUTPUT_SIZE = 2
LEARNING_RATE = 1e-6
- ํ์ดํผํ๋ผ๋ฏธํฐ๋ฅผ ์ ์ธํ์ผ๋ฉด ์คํ์ ํด๋ด์ผ๊ฒ ์ฃ ? ์ผ๋จ ์คํํ๊ฒฝ์ ์ํด ๋ค์๊ณผ ๊ฐ์ด ์์์ ๊ฐ์ผ๋ก input(X), output(Y), Weights(W1, W2)๋ฅผ ์ ์ํด์ค๋๋ค.
- ์ด๋,
requires_grad=True
๋ autograd ์ ๋ชจ๋ ์ฐ์ฐ(operation)๋ค์ ์ถ์ ํด์ผ ํ๋ค๊ณ ์๋ ค์ค๋๋ค.
๐ป ์ฝ๋
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
# ์์์ X, Y, Weight ์ ์
# x : input ๊ฐ >> (32, 1000)
x = torch.randn(BATCH_SIZE,
INPUT_SIZE,
device = device,
dtype = torch.float,
requires_grad = False)
# y : output ๊ฐ >> (32, 2)
y = torch.randn(BATCH_SIZE,
OUTPUT_SIZE,
device = device,
dtype = torch.float,
requires_grad = False)
# w1 : input -> hidden >> (1000, 100)
w1 = torch.randn(INPUT_SIZE,
HIDDEN_SIZE,
device = device,
dtype = torch.float,
requires_grad = True)
# w2 : hidden -> output >> (100, 2)
w2 = torch.randn(HIDDEN_SIZE,
OUTPUT_SIZE,
device = device,
dtype = torch.float,
requires_grad = True)
Train Model (iteration = 500)
- ๋ณธ ํฌ์คํธ์ Autograd๋ฅผ ํ์ธํด๋ณด๋ ํฌ์คํธ์ด๋ฏ๋ก, ๋จ์ํ๊ฒ for๋ฌธ์ ์ด์ฉํ์ฌ 500๋ฒ iteration์ ์ํํ๋๋ก ์ฝ๋๋ฅผ ์์ฑํ์์ต๋๋ค.
torch.mm()
: mm์ matrix multiplication์ ์ค์๋ง์ผ๋ก, ํ๋ ฌ์ ๊ณฑ์ ์ ์๋ฏธํฉ๋๋ค.torch.nn.ReLU()
: ReLUํจ์, ReLU๋ max(0, x)๋ฅผ ์๋ฏธํ๋ ํจ์์ธ๋ฐ, 0๋ณด๋ค ์์์ง๊ฒ ๋๋ฉด 0์ด ๋๊ณ , ๊ทธ ์ด์์ ๊ฐ์ ์ ์งํ๋ค๋ ํน์ง์ ๊ฐ์ง๊ณ ์์ต๋๋ค.loss.backward()
: loss์ ๋ํ์ฌ.backward()
๋ฅผ ํธ์ถํ ๊ฒ์ผ๋ก, autograd๋ ๊ฐ ํ๋ผ๋ฏธํฐ ๊ฐ์ ๋ํด ๋ฏธ๋ถ๊ฐ(gradient)์ ๊ณ์ฐํ๊ณ ์ด๋ฅผ ๊ฐ ํ ์์.grad
์์ฑ(attribute)์ ์ ์ฅํฉ๋๋ค.with torch.no_grad()
: ๋ฏธ๋ถ๊ฐ(gradient) ๊ณ์ฐ์ ์ฌ์ฉํ์ง ์๋๋ก ์ค์ ํ๋ ์ปจํ ์คํธ-๊ด๋ฆฌ์(Context-manager)์ ๋๋ค. ํด๋น ๋ชจ๋๋ ์ ๋ ฅ์ requires_grad=True๊ฐ ์์ด๋, ์ด๋ฅผ requires_grad=False๋ก ๋ฐ๊ฟ์ค๋๋ค.
๐ป ์ฝ๋
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from torch import nn
# 500 iteration
for t in range(1, 501):
# ์๋๊ฐ
hidden = nn.ReLU(x.mm(w1))
# ์์ธก๊ฐ
y_pred = hidden.mm(w2)
# ์ค์ฐจ์ ๊ณฑํฉ ๊ณ์ฐ
loss = (y_pred - y).pow(2).sum()
# iteration 100 ๋ง๋ค ๊ธฐ๋กํ๋๋ก
if t % 100 == 0:
print(t, "th Iteration: ", sep = "")
print(">>>> Loss: ", loss.item())
# Loss์ Gradient ๊ณ์ฐ
loss.backward()
# ํด๋น ์์ ์ Gradient๊ฐ์ ๊ณ ์
with torch.no_grad():
# Weight ์
๋ฐ์ดํธ
w1 -= LEARNING_RATE * w1.grad
w2 -= LEARNING_RATE * w2.grad
# Weight Gradient ์ด๊ธฐํ(0)
w1.grad.zero_()
w2.grad.zero_()
- 500๋ฒ์ ๋ฐ๋ณต๋ฌธ์ ์คํํ๋ฉด์ ์ ์ Loss๊ฐ ์ค์ด๋๋ ๊ฒ์ ํ์ธํ ์ ์์ต๋๋ค.
๐ป ๊ฒฐ๊ณผ
1
2
3
4
5
6
7
8
9
10
100th Iteration:
>>>> Loss: 926.969116210
200th Iteration:
>>>> Loss: 6.41975164413
300th Iteration:
>>>> Loss: 0.06706248223
400th Iteration:
>>>> Loss: 0.00112969405
500th Iteration:
>>>> Loss: 0.00011484944
์ฌํ ๊ฐ๋
Computational Graph (์ฐ์ฐ ๊ทธ๋ํ)
autograd
๋ ๋ฐ์ดํฐ(ํ ์)์ ๋ฐ ์คํ๋ ๋ชจ๋ ์ฐ์ฐ๋ค์ ๊ธฐ๋ก์ ๊ฐ์ฒด๋ก ๊ตฌ์ฑ๋ ๋ฐฉํฅ์ฑ ๋น์ํ ๊ทธ๋ํ(DAG; Directed Acyclic Graph)์ ์ ์ฅ(keep)ํฉ๋๋ค.- ๋ฐฉํฅ์ฑ ๋น์ํ ๊ทธ๋ํ(DAG)์ NN์ ์ ๋ฐ์ ์ธ ๊ณ์ฐ๊ณผ์ ์ ๊ทธ๋ํ๋ก ๋ํ๋ธ ๊ฒ์ผ๋ก, ์(leave)์ ์ ๋ ฅ ํ ์(๋ฐ์ดํฐ)์ด๊ณ , ๋ฟ๋ฆฌ(root)๋ ๊ฒฐ๊ณผ ํ ์(๋ฐ์ดํฐ)์ ๋๋ค.
- ์ด๋ฌํ ๋ฐฉํฅ์ฑ ๋น์ํ ๊ทธ๋ํ(DAG)๋ฅผ ๋ฟ๋ฆฌ์์๋ถํฐ ์๊น์ง ์ถ์ ํ๋ฉด ์ฐ์ ๋ฒ์น(chain rule)์ ๋ฐ๋ผ ๊ธฐ์ธ๊ธฐ(gradient)๋ฅผ ์๋์ผ๋ก ๊ณ์ฐํ ์ ์๋ ๊ตฌ์กฐ์ ๋๋ค.
์์ ํ ๋จ๊ณ ์์, autograd๋ ์๋ ๋ ๊ฐ์ง ์์ ์ ๋์์ ์ํํฉ๋๋ค.
- ์์ฒญ๋ ์ฐ์ฐ์ ์ํํ์ฌ ๊ฒฐ๊ณผ ํ ์๋ฅผ ๊ณ์ฐํ๊ณ ,
- DAG์ ์ฐ์ฐ์ gradient function์ ์ ์ง(maintain)ํฉ๋๋ค.
์ญ์ ํ ๋จ๊ณ ๋ DAG ๋ฟ๋ฆฌ(root)์์ .backward()
๊ฐ ํธ์ถ๋ ๋ ์์๋ฉ๋๋ค. autograd ๋ ์๋ ์ธ ๊ฐ์ง ์์
์ ์์ฐจ์ ์ผ๋ก ์ํํฉ๋๋ค.
- ๊ฐ
.grad_fn
์ผ๋ก๋ถํฐ gradient๋ฅผ ๊ณ์ฐ - ๊ฐ ํ
์์
.grad
์์ฑ์ ๊ณ์ฐ ๊ฒฐ๊ณผ๋ฅผ ์ ์ฅ(accumulate) - ์ฐ์ ๋ฒ์น์ ์ฌ์ฉํ์ฌ, ๋ชจ๋ ์(leaf) ํ ์๋ค๊น์ง ์ ํ(propagate)
์ฐธ๊ณ ๊ฐ๋
์ถ์ฒ : https://tutorials.pytorch.kr/beginner/nn_tutorial.html
- ์ฐ์ฐ ๊ทธ๋ํ์ autograd๋ ๋ณต์กํ ์ฐ์ฐ์๋ฅผ ์ ์ํ๊ณ ๋ํจ์(derivative)๋ฅผ ์๋์ผ๋ก ๊ณ์ฐํ๋ ๋งค์ฐ ๊ฐ๋ ฅํ ํจ๋ฌ๋ค์(paradigm)์ ๋๋ค. ํ์ง๋ง ๋๊ท๋ชจ ์ ๊ฒฝ๋ง์์๋ autograd ๊ทธ ์์ฒด๋ง์ผ๋ก๋ ๋๋ฌด ์ ์์ค(low-level)์ผ ์ ์์ต๋๋ค.
torch.nn, torch.optim
-
PyTorch๋ ์ ๊ฒฝ๋ง(neural network)๋ฅผ ์์ฑํ๊ณ ํ์ต์ํค๋ ๊ฒ์ ๋์์ฃผ๊ธฐ ์ํด์
torch.nn
,torch.optim
์ด ์ ๊ณต๋ฉ๋๋ค.-
torch.nn : ๋ค์ํ ๋ด๋ด ๋คํธ์ํฌ๋ฅผ ์์ฑํ ์ ์๋ ํจํค์ง์ ๋๋ค.
torch.nn.Module
: ํจ์์ฒ๋ผ ๋์ํ์ง๋ง, ๋ํ ์ํ(state)๋ฅผ ํฌํจํ ์ ์๋ ํธ์ถ ๊ฐ๋ฅํ ์ค๋ธ์ ํธ๋ฅผ ์์ฑํฉ๋๋ค. ์ด๋ ํฌํจ๋ Parameter๋ค์ด ์ด๋ค ๊ฒ์ธ์ง ์๊ณ , ๋ชจ๋ ๊ธฐ์ธ๊ธฐ๋ฅผ 0์ผ๋ก ์ค์ ํ๊ณ ๊ฐ์ค์น ์ ๋ฐ์ดํธ ๋ฑ์ ์ํด ๋ฐ๋ณตํ ์ ์์ต๋๋ค.torch.nn.Parameter
: Module ์ ์ญ์ ํ ๋์ ์ ๋ฐ์ดํธ๊ฐ ํ์ํ ๊ฐ์ค์น๊ฐ ์์์ ์๋ ค์ฃผ๋ ํ ์์ฉ ๋ํผ์ ๋๋ค. requires_grad ์์ฑ์ด ์ค์ ๋ ํ ์๋ง ์ ๋ฐ์ดํธ ๋ฉ๋๋ค.torch.nn.functional
: ํ์ฑํ ํจ์, ์์ค ํจ์ ๋ฑ์ ํฌํจํ๋ ๋ชจ๋์ด๊ณ , ๋ฌผ๋ก ์ปจ๋ณผ๋ฃจ์ ๋ฐ ์ ํ ๋ ์ด์ด ๋ฑ์ ๋ํด์ ์ํ๋ฅผ ์ ์ฅํ์ง์๋(non-stateful) ๋ฒ์ ์ ๋ ์ด์ด๋ฅผ ํฌํจํฉ๋๋ค.
-
torch.optim: ์์์๋
torch.no_grad()
๋ก ํ์ต ๊ฐ๋ฅํ ๋งค๊ฐ๋ณ์๋ฅผ ๊ฐ๋ ํ ์๋ค์ ์ง์ ์กฐ์ํ์ฌ ๋ชจ๋ธ์ ๊ฐ์ค์น(weight)๋ฅผ ๊ฐฑ์ ํ์์ต๋๋ค. ์ด๋ ๊ฐ๋จํ ์ต์ ํ ์๊ณ ๋ฆฌ์ฆ์์๋ ํฌ๊ฒ ๋ถ๋ด์ด ๋์ง ์์ง๋ง, ์ค์ ๋ก ์ ๊ฒฝ๋ง์ ํ์ตํ ๋๋ AdaGrad, RMSProp, Adam ๋ฑ๊ณผ ๊ฐ์ ๋ ์ ๊ตํ ์ตํฐ๋ง์ด์ (optimizer)๋ฅผ ์ฌ์ฉํ๊ณค ํฉ๋๋ค. ์ด์ PyTorch์ optim ํจํค์ง๋ ์ต์ ํ ์๊ณ ๋ฆฌ์ฆ์ ๋ํ ์์ด๋์ด๋ฅผ ์ถ์ํํ๊ณ ์ผ๋ฐ์ ์ผ๋ก ์ฌ์ฉํ๋ ์ต์ ํ ์๊ณ ๋ฆฌ์ฆ์ ๊ตฌํ์ฒด(implementation)๋ฅผ ์ ๊ณตํฉ๋๋ค.
-
Dataset, DataLoader
- ๋ฐ์ดํฐ ์ํ์ ์ฒ๋ฆฌํ๋ ์ฝ๋๋ ์ง์ ๋ถํ๊ณ ์ ์ง๋ณด์๊ฐ ์ด๋ ค์ธ ์ ์์ต๋๋ค. ๋ ๋์ ๊ฐ๋ ์ฑ(readability)๊ณผ ๋ชจ๋์ฑ(modularity)์ ์ํด ๋ฐ์ดํฐ์ ์ฝ๋๋ฅผ ๋ชจ๋ธ ํ์ต ์ฝ๋๋ก๋ถํฐ ๋ถ๋ฆฌํ๋ ๊ฒ์ด ์ด์์ ์ ๋๋ค.
-
PyTorch๋
torch.utils.data.DataLoader
์torch.utils.data.Dataset
์ ๋ ๊ฐ์ง ๋ฐ์ดํฐ ๊ธฐ๋ณธ ์์๋ฅผ ์ ๊ณตํ์ฌ ๋ฏธ๋ฆฌ ์ค๋นํด๋(pre-loaded) ๋ฐ์ดํฐ์ ๋ฟ๋ง ์๋๋ผ ๊ฐ์ง๊ณ ์๋ ๋ฐ์ดํฐ๋ฅผ ์ฌ์ฉํ ์ ์๋๋ก ํฉ๋๋ค.- torch.utils.data.Dataset: ์ํ๊ณผ ์ ๋ต(label)์ ์ ์ฅํ๊ณ , len ๋ฐ getitem ์ด ์๋ ๊ฐ์ฒด์ ์ถ์ ์ธํฐํ์ด์ค์ ๋๋ค.
- torch.utils.data.DataLoader: ๋ชจ๋ ์ข ๋ฅ์ Dataset์ ๊ธฐ๋ฐ์ผ๋ก ๋ฐ์ดํฐ์ ๋ฐฐ์น๋ค์ ์ถ๋ ฅํ๋ ๋ฐ๋ณต์(iterator)๋ฅผ ์์ฑํฉ๋๋ค.
๊ธด ๊ธ ์ฝ์ด์ฃผ์ ์ ๊ฐ์ฌํฉ๋๋ค ^~^