[Paper Review] Structured State Space Models for Deep Sequence Modeling
์๋ณธ ๊ฒ์๊ธ: https://velog.io/@euisuk-chung/Structured-State-Space-Models-for-Deep-Sequence-Modeling
์๊ณ์ด ๋ฐ์ดํฐ๋ฅผ ํจ์จ์ ์ผ๋ก ์ฒ๋ฆฌํ๋ ๋ฐฉ๋ฒ์ ์ง๋ ๋ช ๋ ๋์ ๋น ๋ฅด๊ฒ ๋ฐ์ ํ์ต๋๋ค. ํนํ, CMU์ ๊ณ์ Albert Gu ๊ต์๋์ ๊ธด ์๊ณ์ด ์์กด์ฑ(Long-Range Dependencies, LRDs)์ ์ฒ๋ฆฌํ๋ ๋ฐ ์ง์คํ HiPPO(2020), LSSL(2021), ๊ทธ๋ฆฌ๊ณ S4(2022)์ ๊ฐ์ ์ฐ๊ตฌ๋ค์ ํ๊ณ ๊ณ์ญ๋๋ค.
์ด๋ฒ ๊ธ์์๋ ์ฐ๊ตฌ์ ํ๋ฆ๊ณผ ๊ฐ ๋ชจ๋ธ์ ๊ธฐ์ ์ ๋ฐฐ๊ฒฝ๊ณผ ์ฃผ์ ๊ธฐ์ฌ๋ฅผ ์ค๋ช ํ๊ณ , ์ด๋ ค์ด ๊ฐ๋ ๋ค์ ํ์ด๋ด ๋๋ค. ์ด๋ฏธ์ง๋ค์ ์๋ Reference์ ์ ์ด๋ ๊ฐ์, ๋ธ๋ก๊ทธ ๋๋ ๋ ผ๋ฌธ์์์ ๋ฐ์ทํ์ฌ ํธ์ง ๋๋ ์ฌ์ฉํ์์ต๋๋ค.
Backgrounds
-
Sequence Modeling์ ํ์์ฑ
Sequence Modeling
์ ์๊ฐ์ ๋ฐ๋ฅธ ๋ฐ์ดํฐ์ ํจํด์ ๋ถ์ํ๊ณ ์์ธกํ๋ ๋ฐ ํ์์ ์ธ ๊ธฐ์ ์
๋๋ค. ์๋ฅผ ๋ค์ด, ์์ฑ ์ธ์, ๊ธ์ต ์๊ณ์ด ๋ถ์, ๋ฐ์ด์ค ์ ํธ ๋ถ์ ๋ฑ ๋ค์ํ ๋ถ์ผ์์ ์ด๋ฌํ ๊ธฐ์ ์ด ํ์ฉ๋ฉ๋๋ค. ํนํ, ๊ธด ์ํ์ค(long sequences)๋ฅผ ๋ค๋ฃจ๋ ๋ชจ๋ธ์ ์ด๋ฌํ ๋ฐ์ดํฐ๋ฅผ ํจ๊ณผ์ ์ผ๋ก ์ฒ๋ฆฌํ๊ณ ์ค์ํ ์ ๋ณด๋ฅผ ์ถ์ถํ๋ ๋ฐ ์ค์ ์ ๋ก๋๋ค.
-
Sequence Modeling์ ์ฃผ์ ๊ณผ์
๊ธด ์ํ์ค๋ฅผ ๋ค๋ฃจ๋ ๊ณผ์ ์์ ๋ ๊ฐ์ง ์ฃผ์ ๊ณผ์ ๊ฐ ์์ต๋๋ค.
- ์ฒซ์งธ, ๋ฐ์ดํฐ์ ์๊ฐ์ ์ฐ์์ฑ(time continuity)์ ์ ์งํ๋ฉด์๋ ์ด๋ฅผ ํจ๊ณผ์ ์ผ๋ก ์ฒ๋ฆฌํ ์ ์๋ ๋ชจ๋ธ์ด ํ์ํฉ๋๋ค.
- ๋์งธ, ํ์ต ๊ณผ์ ์์ ๋ฐ์ํ๋ Vanishing Gradient ๋ฌธ์ ๋ฅผ ํด๊ฒฐํด์ผ ํฉ๋๋ค.
์ด๋ RNN์ด๋ ๊ธฐ์กด์ ์์ฐจ ๋ชจ๋ธ๋ค์ด ๊ธด ์ํ์ค์์ ๋ฐ์ํ๋ ๊ณตํต์ ์ธ ๋ฌธ์ ๋ก, ์๊ฐ์ ๋ฐ๋ผ ์ ํธ๊ฐ ์ ์ฐจ ์ฝํด์ ธ ๋ชจ๋ธ ํ์ต์ ์ด๋ ค์์ ์ค๋๋ค.
-
State Space Model(SSM) ์๊ฐ
State Space Model(SSM)
์ ๋ณธ๋ ์ ์ด ์ด๋ก ์์ ์ ๋ํ ๋ชจ๋ธ๋ก, ์์คํ
์ ์ํ(state)์ ์ถ๋ ฅ์ ์ํ์ ์ผ๋ก ์ ์ํ ๊ฒ์
๋๋ค. ์ด ๋ชจ๋ธ์ ์
๋ ฅ ๋ฐ์ดํฐ(xxx)๋ฅผ ๋ฐ์ ์ํ(hhh)๋ฅผ ๊ณ์ฐํ ํ ์ด๋ฅผ ์ถ๋ ฅ(yyy)์ผ๋ก ๋ณํํ๋ ๋ ๊ฐ์ง ์ฃผ์ ๋ฐฉ์ ์์ผ๋ก ์ ์๋ฉ๋๋ค.
- SSM์ ํฌ๊ฒ 3๊ฐ์ง Representation์ผ๋ก ํํ๋ ์ ์์ต๋๋ค:
- ์ฐ์ ํํ (Continuous Representation)
- ์์ฐจ์ ํํ (Recurrent Representation)
- ํฉ์ฑ๊ณฑ ํํ (Convolution Representation)
1. ์ฐ์ ํํ (Continuous Representation)
๊ฐ์ฅ ๋จผ์ SSM์ ์ฐ์ ํํ(continuous Representation)
์ ์ฒ๋ฆฌํ ์ ์์ผ๋ฉฐ, ์ด๋ฅผ ํตํด ์ํ์ค ๋ฐ์ดํฐ์ ์ฐ์์ฑ์ ์์ฐ์ค๋ฝ๊ฒ ๋ชจ๋ธ๋งํ ์ ์์ต๋๋ค.
SSM์ ์ฃผ์ ์ํ์ ํํ์ ๋ค์๊ณผ ๊ฐ์ต๋๋ค:
- ์ํ ๋ฐฉ์ ์: hโฒ(t)=Ah(t)+Bx(t)hโ(t) = Ah(t) + Bx(t)hโฒ(t)=Ah(t)+Bx(t)
- ์ถ๋ ฅ ๋ฐฉ์ ์: y(t)=Ch(t)+Dx(t)y(t) = Ch(t) + Dx(t)y(t)=Ch(t)+Dx(t)
์ด ๋ฐฉ์ ์์ ๊ธฐ๋ฐ์ผ๋ก, SSM์ ์ ๋ ฅ ์ํ์ค๊ฐ ์ฃผ์ด์ก์ ๋ ์ด๋ฅผ ์ฒ๋ฆฌํ์ฌ ์ฐ์์ ์ธ ์ถ๋ ฅ์ ์์ฑํ๋ ๋ชจ๋ธ์ ๋๋ค.
ํ์ง๋ง, ์ฌ๊ธฐ์ ๋ฐํ๋๋ y๋ ์ฐ์๋ ์๊ณ์ด ํํ(continuous-time representation)์ ๋๋ค. ์ด๋ฅผ ๊ธฐ๊ณ ๋๋ ์ฌ๋์ด ์ดํดํ ์ ์๋ ๋ฒ์ฃผ๋ก ๊ฐ์ ธ์ค๊ธฐ ์ํด์๋ Discrete Signal๋ก discretization(์ด์ฐํ) ์์ ์ ์ํํด์ผํฉ๋๋ค.
๐ ์ด์ฐํ๋?
์ด์ฐํ(้ขๆฃๅ, discretization)๋ ์์ฉ์ํ์์, ์ฐ์์ ์ธ ํจ์, ๋ชจ๋ธ, ๋ณ์, ๋ฐฉ์ ์์ ์ด์ฐ์ ์ธ ๊ตฌ์ฑ์์๋ก ๋ณํํ๋ ํ๋ก์ธ์ค(process)์ด๋ค. ์ด ํ๋ก์ธ์ค๋ ์ผ๋ฐ์ ์ผ๋ก ๋์งํธ ์ปดํจํฐ์์ ์์น์ ํ๊ฐ ๋ฐ ๊ตฌํ์ ์ ํฉํ๋๋ก ํ๋ ์ฒซ ๋จ๊ณ๋ก ์ํ๋๋ค.
2. Recurrent Representation
๋ค์์ผ๋ก Recurrent Representation์ ์ํ ๊ณต๊ฐ ๋ชจ๋ธ์์ ์์ฐจ์ ์ผ๋ก ์ํ hkh_khkโ๋ฅผ ์ ๋ฐ์ดํธํ๋ ๊ตฌ์กฐ์ ๋๋ค. ์ฆ, kkk-๋ฒ์งธ ์๊ฐ ๋จ๊ณ์์์ ์ํ hkh_khkโ๋ ์ด์ ์ํ hkโ1h_{k-1}hkโ1โ์ ์์กดํฉ๋๋ค.
์ ๊ทธ๋ฆผ์ ๋ณด์๋ฉด ์ด์ ์ ์ค์ ์ผ๋ก ์ด์ด์ง ๊ทธ๋ํ์๋ ๋ค๋ฅด๊ฒ ์ง๊ธ์ ๊ทธ๋ํ๋ ์๊ฒ ์๊ฒ ๋ธ๋ก์ผ๋ก ๋๋ ๊ฒ์ ๋ณด์ค ์ ์์ฃ ? ์ด๊ฒ์ ๋ฐ์ดํฐ๊ฐ ์ด์ฐํ๋์๊ธฐ ๋๋ฌธ์ ๋๋ค.
์ด๋์ ๊ฐ ๋ง์ด ๋ณธ ๊ทธ๋ฆผ์๋๊ฐ์? ๋ฐ๋ก RNN์ ๋ชจ์๊ณผ ์ ์ฌํ ๊ฒ์ ํ์ธํ ์ ์์ต๋๋ค.
๋
ผ๋ฌธ์์๋ ์ด์ฐํ๋ฅผ ์ํด Zero-order hold (ZOH)
์ด๋ผ๋ ๊ธฐ๋ฒ์ ์ฌ์ฉํฉ๋๋ค. ZOH๋ ๋์งํธ ์ ํธ๋ฅผ ์๋ ๋ก๊ทธ ์ ํธ๋ก ๋ณํํ๋ ๊ณผ์ ์์ ์ฌ์ฉ๋๋ ์ค์ํ ๊ธฐ๋ฒ์ผ๋ก, ๊ฐ ์ํ๋ง ์ฃผ๊ธฐ ๋์ ์ ํธ ๊ฐ์ ์ผ์ ํ๊ฒ ์ ์งํ๋ ๋ฐฉ๋ฒ์
๋๋ค.
โ๏ธ ZOH์ ์ํ์ ํํ์ ๋ค์๊ณผ ๊ฐ์ต๋๋ค:
xZOH(t)=โn=โโโx[n]โ rect(tโT/2โnTT)x_{\text{ZOH}}(t) = \sum_{n=-\infty}^{\infty} x[n] \cdot \text{rect}\left(\frac{t-T/2-nT}{T}\right)xZOHโ(t)=โn=โโโโx[n]โ rect(TtโT/2โnTโ)
์ฌ๊ธฐ์:
- x[n]x[n]x[n]์ ์ด์ฐ ์๊ฐ ์ ๋ ฅ ์ ํธ
- TTT๋ ์ํ๋ง ์ฃผ๊ธฐ
- rect(โ )\text{rect}(\cdot)rect(โ )๋ ์ง์ฌ๊ฐํ ํจ์
(ZOH๊ธฐ๋ฐ) ์ฐ์ ์๊ฐ SSM์ ์ด์ฐ ์๊ฐ SSM์ผ๋ก ๋ณํ
1. ์ฐ์ ์๊ฐ SSM:
- hโฒ(t)=Ah(t)+Bx(t)hโ(t) = Ah(t) + Bx(t)hโฒ(t)=Ah(t)+Bx(t)
- y(t)=Ch(t)+Dx(t)y(t) = Ch(t) + Dx(t)y(t)=Ch(t)+Dx(t)
2. ZOH ๊ฐ์ :
- x(t)=x(kฮt)ย ,ย forkฮtโคt<(k+1)ฮtx(t) = x(k\Delta t) \quad \text{ , for} \quad k\Delta t \leq t < (k+1)\Delta tx(t)=x(kฮt)ย ,ย forkฮtโคt<(k+1)ฮt
3. ์ํ ๋ฐฉ์ ์ ํด๊ฒฐ:
- h(t)=eA(tโkฮt)h(kฮt)+โซkฮtteA(tโฯ)Bx(kฮt)dฯh(t) = e^{A(t-k\Delta t)}h(k\Delta t) + \int_{k\Delta t}^t e^{A(t-\tau)}Bx(k\Delta t)d\tauh(t)=eA(tโkฮt)h(kฮt)+โซkฮttโeA(tโฯ)Bx(kฮt)dฯ
4. ์ด์ฐ ์๊ฐ ๋ชจ๋ธ ๋์ถ:
- hk+1=Aหhk+Bหxkh_{k+1} = \bar{A}h_k + \bar{B}x_khk+1โ=Aหhkโ+Bหxkโ
-
yk=Chk+Dxky_k = Ch_k + Dx_kykโ=Chkโ+Dxkโ
์ฌ๊ธฐ์,
- Aห=eAฮt\bar{A} = e^{A\Delta t}Aห=eAฮt
- Bห=Aโ1(eAฮtโI)B\bar{B} = A^{-1}(e^{A\Delta t} - I)BBห=Aโ1(eAฮtโI)B
์ด์ ์ด์ฐํํ ๊ฒฐ๊ณผ๋ฅผ ์ดํด๋ณด๋ฉด ๋ค์๊ณผ ๊ฐ์ด ๊ฐ๊ฐ์ T=0, T=1, T=2์ ๋ํด์ ์ด์ time k-1์ hkโ1h_{k-1}hkโ1โ์ input๊ณผ ํ์์ xkx_kxkโ์ input์ ๋ฐ์์ hkh_khkโ๋ฅผ ๋์ถํ๊ณ ์ด๋ฅผ ํตํด yky_kykโ๋ฅผ ์ฌ๊ท์ ์ผ๋ก ํธ์ถํ๋ ๊ฒ์ ๋ณผ ์ ์์ต๋๋ค.
๋ฐ๋ก RNN๊ณผ ์ ์ฌํ ํํ๋ก ๋ง์ด์ฃ !!
3. Convolution Representation
Recurrent Representation์ ์์ฐจ์ ์ธ ์ํ ์ ๋ฐ์ดํธ๋ฅผ Convolution Representation์ผ๋ก ๋ฐ๊พธ๊ธฐ ์ ์ ๋จผ์ ์๊ฐ ์์ผ๋ก hkh_khkโ๋ฅผ ์ดํด๋ณด๊ฒ ์ต๋๋ค.
์๋ฅผ ๋ค์ด,
-
์ฒซ ๋ฒ์งธ ์ํ(k=1)๋ ๋ค์๊ณผ ๊ฐ์ต๋๋ค:
- ์ํ h1h_1h1โ:h1=Aหh0+Bหx0h_1 = \bar{A} h_0 + \bar{B} x_0h1โ=Aหh0โ+Bหx0โ
- ์ถ๋ ฅ y1y_1y1โ:y1=Ch1+Dx0=C(Aหh0+Bหx0)+Dx0y_1 = C h_1 + D x_0 = C(\bar{A} h_0 + \bar{B} x_0) + D x_0y1โ=Ch1โ+Dx0โ=C(Aหh0โ+Bหx0โ)+Dx0โ
-
๋ ๋ฒ์งธ ์ํ(k=2)๋ ๋ค์๊ณผ ๊ฐ์ต๋๋ค:
- ์ํ h2h_2h2โ:h2=Aหh1+Bหx1=Aห2h0+AหBหx0+Bหx1h_2 = \bar{A} h_1 + \bar{B} x_1 = \bar{A}^2 h_0 + \bar{A} \bar{B} x_0 + \bar{B} x_1h2โ=Aหh1โ+Bหx1โ=Aห2h0โ+AหBหx0โ+Bหx1โ
- ์ถ๋ ฅ y2y_2y2โ:y2=Ch2+Dx1=C(Aห2h0+AหBหx0+Bหx1)+Dx1y_2 = C h_2 + D x_1 = C(\bar{A}^2 h_0 + \bar{A} \bar{B} x_0 + \bar{B} x_1) + D x_1y2โ=Ch2โ+Dx1โ=C(Aห2h0โ+AหBหx0โ+Bหx1โ)+Dx1โ
-
์ธ ๋ฒ์งธ ์ํ(k=3)๋ ๋ค์๊ณผ ๊ฐ์ต๋๋ค:
- ์ํ h3h_3h3โ:h3=Aหh2+Bหx2=Aห3h0+Aห2Bหx0+AหBหx1+Bหx2h_3 = \bar{A} h_2 + \bar{B} x_2 = \bar{A}^3 h_0 + \bar{A}^2 \bar{B} x_0 + \bar{A} \bar{B} x_1 + \bar{B} x_2h3โ=Aหh2โ+Bหx2โ=Aห3h0โ+Aห2Bหx0โ+AหBหx1โ+Bหx2โ
- ์ถ๋ ฅ y3y_3y3โ:y3=Ch3+Dx2=C(Aห3h0+Aห2Bหx0+AหBหx1+Bหx2)+Dx2y_3 = C h_3 + D x_2 = C(\bar{A}^3 h_0 + \bar{A}^2 \bar{B} x_0 + \bar{A} \bar{B} x_1 + \bar{B} x_2) + D x_2y3โ=Ch3โ+Dx2โ=C(Aห3h0โ+Aห2Bหx0โ+AหBหx1โ+Bหx2โ)+Dx2โ
-
๋ค ๋ฒ์งธ ์ํ(k=4)๋ ๋ค์๊ณผ ๊ฐ์ต๋๋ค:
- ์ํ h4h_4h4โ:h4=Aหh3+Bหx3=Aห4h0+Aห3Bหx0+Aห2Bหx1+AหBหx2+Bหx3h_4 = \bar{A} h_3 + \bar{B} x_3 = \bar{A}^4 h_0 + \bar{A}^3 \bar{B} x_0 + \bar{A}^2 \bar{B} x_1 + \bar{A} \bar{B} x_2 + \bar{B} x_3h4โ=Aหh3โ+Bหx3โ=Aห4h0โ+Aห3Bหx0โ+Aห2Bหx1โ+AหBหx2โ+Bหx3โ
- ์ถ๋ ฅ y4y_4y4โ:y4=Ch4+Dx3=C(Aห4h0+Aห3Bหx0+Aห2Bหx1+AหBหx2+Bหx3)+Dx3y_4 = C h_4 + D x_3 = C(\bar{A}^4 h_0 + \bar{A}^3 \bar{B} x_0 + \bar{A}^2 \bar{B} x_1 + \bar{A} \bar{B} x_2 + \bar{B} x_3) + D x_3y4โ=Ch4โ+Dx3โ=C(Aห4h0โ+Aห3Bหx0โ+Aห2Bหx1โ+AหBหx2โ+Bหx3โ)+Dx3โ
-
๋ค์ฏ ๋ฒ์งธ ์ํ(k=5)๋ ๋ค์๊ณผ ๊ฐ์ต๋๋ค:
- ์ํ h5h_5h5โ:h5=Aหh4+Bหx4=Aห5h0+Aห4Bหx0+Aห3Bหx1+Aห2Bหx2+AหBหx3+Bหx4h_5 = \bar{A} h_4 + \bar{B} x_4 = \bar{A}^5 h_0 + \bar{A}^4 \bar{B} x_0 + \bar{A}^3 \bar{B} x_1 + \bar{A}^2 \bar{B} x_2 + \bar{A} \bar{B} x_3 + \bar{B} x_4h5โ=Aหh4โ+Bหx4โ=Aห5h0โ+Aห4Bหx0โ+Aห3Bหx1โ+Aห2Bหx2โ+AหBหx3โ+Bหx4โ
- ์ถ๋ ฅ y5y_5y5โ:y5=Ch5+Dx4=C(Aห5h0+Aห4Bหx0+Aห3Bหx1+Aห2Bหx2+AหBหx3+Bหx4)+Dx4y_5 = C h_5 + D x_4 = C(\bar{A}^5 h_0 + \bar{A}^4 \bar{B} x_0 + \bar{A}^3 \bar{B} x_1 + \bar{A}^2 \bar{B} x_2 + \bar{A} \bar{B} x_3 + \bar{B} x_4) + D x_4y5โ=Ch5โ+Dx4โ=C(Aห5h0โ+Aห4Bหx0โ+Aห3Bหx1โ+Aห2Bหx2โ+AหBหx3โ+Bหx4โ)+Dx4โ
๊ท์น์ด ์ข ๋ณด์ด์๋์?! ์ข ๋ ์ด์๊ฒ ์ ๊ฐ ๋ง๋ ๊ทธ๋ฆผ์ ๋ฐ์ ๋ณด์ฌ๋๋ฆฌ๊ฒ ์ต๋๋ค. (*D term์ ์๋ตํจ)
1. k=1k = 1k=1 (์ฒซ ๋ฒ์งธ ์ถ๋ ฅ)
- ์ํ: h1=CBหx0h_1 = C \bar{B} x_0h1โ=CBหx0โ
- ์ฌ๊ธฐ์ ์ปค๋์ ๋ง์ง๋ง ํญ๋ชฉ CBหC \bar{B}CBห๊ฐ ์ ๋ ฅ x0x_0x0โ์ ๊ณฑํด์ ธ ์ฒซ ๋ฒ์งธ ์ถ๋ ฅ y1y_1y1โ๊ฐ ๊ณ์ฐ๋ฉ๋๋ค.
- ํจ๋ฉ์ด ์๊ธฐ ๋๋ฌธ์ ์ปค๋์ ์ ๋ ํญ๋ชฉ์ ์ ๋ ฅ๊ณผ ์ํธ์์ฉํ์ง ์๊ณ ํจ๋ฉ(0)์ ํด๋นํฉ๋๋ค.
- ์ถ๋ ฅ: y1=CBหx0y_1 = C \bar{B} x_0y1โ=CBหx0โ
- ์ฒซ ๋ฒ์งธ ์ถ๋ ฅ์ CBหx0C \bar{B} x_0CBหx0โ๋ก ํํ๋ฉ๋๋ค.
2. k=2k = 2k=2 (๋ ๋ฒ์งธ ์ถ๋ ฅ)
- ์ํ: h2=CAหBหx0+CBหx1h_2 = C \bar{A} \bar{B} x_0 + C \bar{B} x_1h2โ=CAหBหx0โ+CBหx1โ
- ์ด์ ์ปค๋์ ๋ ๋ฒ์งธ ํญ๋ชฉ์ด x0x_0x0โ, ๋ง์ง๋ง ํญ๋ชฉ์ด x1x_1x1โ๊ณผ ๊ณฑํด์ง๋ฉด์ ๋ ๋ฒ์งธ ์ํ๊ฐ ๊ณ์ฐ๋ฉ๋๋ค.
- ํจ๋ฉ ๊ฐ์ด ํ๋ ๋จ์์๊ณ , ์ปค๋์ ์ฒซ ๋ฒ์งธ ํญ๋ชฉ์ ์ฌ์ ํ ํจ๋ฉ(0)๊ณผ ์ํธ์์ฉํฉ๋๋ค.
- ์ถ๋ ฅ: y2=CAหBหx0+CBหx1y_2 = C \bar{A} \bar{B} x_0 + C \bar{B} x_1y2โ=CAหBหx0โ+CBหx1โ
- ๋ ๋ฒ์งธ ์ถ๋ ฅ์ ์ด์ ์ ๋ ฅ๊ณผ ํ์ฌ ์ ๋ ฅ์ ํฉ์ผ๋ก ๊ณ์ฐ๋ฉ๋๋ค.
3. k=3k = 3k=3 (์ธ ๋ฒ์งธ ์ถ๋ ฅ)
- ์ํ: h3=CAห2Bหx0+CAหBหx1+CBหx2h_3 = C \bar{A}^2 \bar{B} x_0 + C \bar{A} \bar{B} x_1 + C \bar{B} x_2h3โ=CAห2Bหx0โ+CAหBหx1โ+CBหx2โ
- ์ธ ๋ฒ์งธ ์ํ์์๋ ์ปค๋์ ๋ชจ๋ ํญ๋ชฉ์ด ์ค์ ์ ๋ ฅ๊ณผ ์ํธ์์ฉํ๊ธฐ ์์ํฉ๋๋ค.
- ์ปค๋์ ์ฒซ ๋ฒ์งธ ํญ๋ชฉ์ x0x_0x0โ, ๋ ๋ฒ์งธ ํญ๋ชฉ์ x1x_1x1โ, ์ธ ๋ฒ์งธ ํญ๋ชฉ์ x2x_2x2โ์ ๊ณฑํด์ง๋๋ค.
- ์ถ๋ ฅ: y3=CAห2Bหx0+CAหBหx1+CBหx2y_3 = C \bar{A}^2 \bar{B} x_0 + C \bar{A} \bar{B} x_1 + C \bar{B} x_2y3โ=CAห2Bหx0โ+CAหBหx1โ+CBหx2โ
- ์ธ ๋ฒ์งธ ์ถ๋ ฅ์ x0x_0x0โ, x1x_1x1โ, x2x_2x2โ์ ๋ํ ์ปค๋ ๊ฐ์คํฉ์ผ๋ก ๊ณ์ฐ๋ฉ๋๋ค.
4. k=4k = 4k=4 (๋ค ๋ฒ์งธ ์ถ๋ ฅ)
- ์ํ: h4=CAห2Bหx1+CAหBหx2+CBหx3h_4 = C \bar{A}^2 \bar{B} x_1 + C \bar{A} \bar{B} x_2 + C \bar{B} x_3h4โ=CAห2Bหx1โ+CAหBหx2โ+CBหx3โ
- ๋ค ๋ฒ์งธ ์ํ์์๋ ์ปค๋์ด x1x_1x1โ, x2x_2x2โ, x3x_3x3โ๊ณผ ์ํธ์์ฉํฉ๋๋ค.
- ๋ ์ด์ ํจ๋ฉ์ด ์ ์ฉ๋์ง ์์ผ๋ฉฐ, ์ ๋ ฅ ์ํ์ค์ ์ปค๋ ๊ฐ์ ์์ ํ ์ํธ์์ฉ์ด ์ด๋ฃจ์ด์ง๋๋ค.
- ์ถ๋ ฅ: y4=CAห2Bหx1+CAหBหx2+CBหx3y_4 = C \bar{A}^2 \bar{B} x_1 + C \bar{A} \bar{B} x_2 + C \bar{B} x_3y4โ=CAห2Bหx1โ+CAหBหx2โ+CBหx3โ
- ๋ค ๋ฒ์งธ ์ถ๋ ฅ์ x1x_1x1โ, x2x_2x2โ, x3x_3x3โ์ ๋ํ ์ปค๋ ๊ฐ์คํฉ์ ๋๋ค.
Convolution Representation
๋ฐฉ์์ ์ฅ์ ์ Recurrent Representation์์ ๊ฐ ์๊ฐ ๋จ๊ณ๋ณ๋ก ์์ฐจ์ ์ผ๋ก ์ํ๋ฅผ ์
๋ฐ์ดํธํ๋ ๋์ , ๋ชจ๋ ์๊ฐ ๋จ๊ณ์ ์ถ๋ ฅ์ ํ ๋ฒ์ ๊ณ์ฐํ ์ ์๋ค๋ ์ ์
๋๋ค.
- ๋ณ๋ ฌ ์ฒ๋ฆฌ ๊ฐ๋ฅ: Recurrent Representation์์๋ ๊ฐ ์๊ฐ ๋จ๊ณ๋ณ๋ก ์์ฐจ์ ์ผ๋ก ์ํ๋ฅผ ์ ๋ฐ์ดํธํด์ผ ํ๋ฏ๋ก ๊ณ์ฐ์ด ์ง๋ ฌํ๋์ด ์์ต๋๋ค. ๊ทธ๋ฌ๋ Convolution Representation์์๋ ์ปค๋์ ์ด์ฉํ์ฌ ์ ๋ ฅ ์ํ์ค ์ ์ฒด์ ๊ฑธ์ณ ๋์์ ์ถ๋ ฅ์ ๊ณ์ฐํ ์ ์์ด, ๋ณ๋ ฌํ๊ฐ ๊ฐ๋ฅํด์ง๋๋ค. ์ด๋ ํนํ ๊ธด ์ํ์ค๋ฅผ ์ฒ๋ฆฌํ ๋ ํจ์จ์ ์ธ ๊ณ์ฐ์ ๊ฐ๋ฅํ๊ฒ ํฉ๋๋ค.
- ๋ ํฐ ์ปค๋ ์ ์ฉ ๊ฐ๋ฅ: ์์์์๋ ์ปค๋ ์ฌ์ด์ฆ๋ฅผ 3์ผ๋ก ์ค์ ํ์ง๋ง, ์ด๋ก ์ ์ผ๋ก๋ ๋ ํฐ ์ปค๋๋ ์ฌ์ฉํ ์ ์์ต๋๋ค. ๋ ํฐ ์ปค๋์ ๋ ๊ธด ๋ฒ์์ ๊ณผ๊ฑฐ ์ ๋ ฅ์ ํ ๋ฒ์ ์ฒ๋ฆฌํ ์ ์์ด, ๋ ๋์ ๋ฌธ๋งฅ ์ ๋ณด๋ฅผ ํ์ฉํ ์ ์๊ฒ ํฉ๋๋ค. ์ด๋ ์ํ์ค ๋ฐ์ดํฐ์์ ์ฅ๊ธฐ์ ์ธ ์ข ์์ฑ์ ๋ ์ ๋ฐ์ํ๋ ๋ฐ ๋์์ด ๋ฉ๋๋ค.
- ํจ์จ์ฑ: ํฉ์ฑ๊ณฑ ์ฐ์ฐ์ ์ผ๋ฐ์ ์ผ๋ก GPU์ ๊ฐ์ ๋ณ๋ ฌํ๊ฐ ๊ฐ๋ฅํ ํ๋์จ์ด์์ ๋งค์ฐ ๋น ๋ฅด๊ฒ ์ฒ๋ฆฌ๋ ์ ์์ต๋๋ค. ์ด๋ Recurrent Representation์ ๋นํด ๊ณ์ฐ ์๋์์ ํฐ ์ด์ ์ ์ ๊ณตํฉ๋๋ค.
๊ทธ๋ฌ๋, ์ด์์ ์ผ๋ก ์ด๋ฌํ deepSSM์ ๋ฐ๋ก ์ ์ฉํ๊ธฐ์๋ ๋ง์ ๋ฌธ์ ์ ๋ค์ด ์์๋๋ฐ์.
์๋ ์ฐ๊ตฌ๋ค์ ์ด๋ฐ Convolution Representation์ ์ด๋ป๊ฒ ํจ์จ์ ์ผ๋ก ๊ณ์ฐํ๊ณ ์ฒ๋ฆฌํ ์ ์๋๊ฐ์ ๋ํ ์ฐ๊ตฌ๋ค์ ๋๋ค.
- HiPPO : Recurrent Memory with Optimal Polynomial Projections (2020)
- LSSL : Combining Recurrent, Convolutional, and Continuous-time Models with Linear State-Space Layers (2021)
- S4 : Efficiently Modeling Long Sequences with Structured State Spaces (2022)
Research
์ด ๋ ผ๋ฌธ๋ค์ ๊ฐ๊ฐ ์๊ณ์ด ๋ฐ์ดํฐ๋ฅผ ๋ค๋ฃจ๋ ๊ธฐ์กด ๋ชจ๋ธ์ ํ๊ณ๋ฅผ ๊ทน๋ณตํ๋ ์ค์ํ ๊ธฐ์ ์ ๋ฐ์ ์ ๋ด๊ณ ์์ต๋๋ค.
-
HiPPO: Recurrent Memory with Optimal Polynomial Projections (NeurIPS, 2020)
- ๋ชฉ์ :
๊ธด ์ํ์ค์ ๋ํ ๋ฉ๋ชจ๋ฆฌ ๋ฌธ์ ๋ฅผ ํด๊ฒฐ
ํ๊ณ ,๋ฉ๋ชจ๋ฆฌ๋ฅผ ํจ์จ์ ์ผ๋ก ์ ์งํ๋ฉด์ ์ ๋ ฅ ์ ๋ณด๋ฅผ ๊ณ์ ์ ๋ฐ์ดํธ
ํ๋ ๋ฐฉ๋ฒ์ ์ ์ํฉ๋๋ค. - ํจ๊ณผ: ์ด ์ฐ๊ตฌ๋ ๋ฉ๋ชจ๋ฆฌ ํจ์จ์ฑ๊ณผ ์ ๋ณด ์ ์ง ๊ฐ์ ๊ท ํ์ ์ฐพ๋ ๋ฐ ์ด์ ์ ๋ง์ถฅ๋๋ค. ์ด๋ฅผ ํตํด ๊ธด ์ํ์ค์์๋ ์ ๋ณด๋ฅผ ํจ๊ณผ์ ์ผ๋ก ์ฒ๋ฆฌํ ์ ์๊ฒ ๋ฉ๋๋ค.
- ๋ชฉ์ :
-
LSSL: Combining Recurrent, Convolutional, and Continuous-time Models with Linear State-Space Layers (NeurIPS, 2021)
- ๋ชฉ์ : ์ด ์ฐ๊ตฌ๋
์ฐ์ ์๊ฐ ๋ชจ๋ธ๊ณผ ์ ํ ์ํ ๊ณต๊ฐ ๋ ์ด์ด(LSSL)๋ฅผ ๊ฒฐํฉ
ํ์ฌ,์๊ฐ์ ๋ฐ๋ฅธ ์ฐ์์ ์ธ ๋ณํ์ ๋น์ฐ์์ ์ธ ๋ณํ๋ฅผ ๋์์ ์ฒ๋ฆฌ
ํ ์ ์๊ฒ ๋ง๋ญ๋๋ค. - ํจ๊ณผ: LSSL์ ๋ชจ๋ธ์ ์ ์ฐ์ฑ์ ๋์ฌ์, ์๊ณ์ด ๋ฐ์ดํฐ๋ฟ ์๋๋ผ ๋ค์ํ ์ข ๋ฅ์ ์ฐ์์ ๋ฐ์ดํฐ๋ฅผ ์ฒ๋ฆฌํ ์ ์๋๋ก ๋์ต๋๋ค.
- ๋ชฉ์ : ์ด ์ฐ๊ตฌ๋
-
S4: Efficiently Modeling Long Sequences with Structured State Spaces (ICLR, 2022)
- ๋ชฉ์ : S4๋ Convolution Representation์ ํจ์จ์ฑ์ ๊ทน๋ํํ๋ฉด์๋,
์ฅ๊ธฐ์ ์ธ ์ข ์์ฑ์ ๋ ์ ์ฒ๋ฆฌํ ์ ์๊ฒ ์ต์ ํ
๋์์ต๋๋ค. - ํจ๊ณผ: S4๋ ํนํ ์ฅ๊ธฐ์ ์ธ ํจํด ํ์ต์ ๊ฐ์ ์ด ์์ด, ๊ธฐ์กด์ ๋ชจ๋ธ๋ณด๋ค ํจ์ฌ ๊ธด ์ํ์ค์์๋ ์ฐ์ํ ์ฑ๋ฅ์ ๋ณด์ ๋๋ค.
- ๋ชฉ์ : S4๋ Convolution Representation์ ํจ์จ์ฑ์ ๊ทน๋ํํ๋ฉด์๋,
-
HiPPO: Recurrent Memory with Optimal Polynomial Projections (Neurips 2020)
Preliminary
๋ณธ๊ฒฉ์ ์ผ๋ก HiPPO
๋ฅผ ์ดํด๋ณด๊ธฐ์ ์์, ๋ค์ ์ํ ๊ฐ๋
๋ค์ ์ด๋ ์ ๋ ์ดํดํ๊ณ ์์ด์ผ ๊ด๋ จ ๋ด์ฉ์ ๋ ์ ์ดํดํ ์ ์์ต๋๋ค: ๋ผ๊ฒ๋ฅด(Laguerre) ๋คํญ์, ๋ฅด์ฅ๋๋ฅด(Legendre) ๋คํญ์, ๊ทธ๋ฆฌ๊ณ ๋คํญ์ ํฌ์ ์ฐ์ฐ์์
๋๋ค. ์ด ๊ฐ๋
๋ค์ ์๊ณ์ด ๋ฐ์ดํฐ๋ฅผ ๋ถ์ํ๊ณ ๊ทผ์ฌํ๋ ๋ฐ ์ค์ํ ์ญํ ์ ํ๋ฉฐ, ๋ณต์กํ ๋ฐ์ดํฐ๋ฅผ ์ํ์ ์ผ๋ก ๋ค๋ฃจ๋ ๊ฐ๋ ฅํ ๋๊ตฌ์
๋๋ค. ๊ฐ ๊ฐ๋
์ ์ฐจ๋ก๋ก ์ค๋ช
ํ๊ฒ ์ต๋๋ค.
1. ๋คํญ์ ํฌ์ ์ฐ์ฐ์(Polynomial Projection Operator)
๋คํญ์ ํฌ์ ์ฐ์ฐ์
๋ ๋ณต์กํ ๋ฐ์ดํฐ๋ฅผ ํน์ ์ง๊ต ๋คํญ์ ๊ธฐ์ ๋ฅผ ์ฌ์ฉํ์ฌ ๋ ๊ฐ๋จํ ๋คํญ์์ผ๋ก ํํํ๋ ๊ณผ์ ์ ๋๋ค. ์๊ณ์ด ๋ฐ์ดํฐ๋ ํจ์๊ฐ ์ฃผ์ด์ก์ ๋, ์ด๋ฅผ ์ง๊ตํ๋ ๋คํญ์๋ค์ ์ ํ ๊ฒฐํฉ์ผ๋ก ๊ทผ์ฌํฉ๋๋ค.
๐ก ์ง๊ต ๋คํญ์ ๊ธฐ์ (Orthogonal Polynomial Basis)๋ ์ฌ๋ฌ ๋คํญ์ ์ค์์๋ ์๋ก ์ง๊ต(orthogonal)ํ๋ ์ฑ์ง์ ๊ฐ์ง ๋คํญ์๋ค์ ์งํฉ์ ์๋ฏธํฉ๋๋ค.
์ง๊ต์ฑ
์ ๋ ํจ์(๋๋ ๋ ๋คํญ์) ์ฌ์ด์ ๋ด์ (inner product)์ด 0์ด๋ผ๋ ์๋ฏธ์ ๋๋ค. ์ง๊ต์ฑ์ ๋ฐ์ดํฐ๋ ํจ์์ ์๋ก ๋ค๋ฅธ ์ฑ๋ถ์ด ์๋ก ์ํฅ์ ๋ฏธ์น์ง ์๋ ๋ ๋ฆฝ์ ์ธ ๊ด๊ณ๋ฅผ ๋ํ๋ ๋๋ค.๊ธฐ์
๋ ์ฃผ์ด์ง ๊ณต๊ฐ์ ๊ตฌ์ฑํ๋ โ๊ธฐ๋ณธโ ์์๋ค์ ์งํฉ์ ์๋ฏธํฉ๋๋ค. ๊ธฐ์ ๋ฒกํฐ์ ์ ํ ๊ฒฐํฉ์ ํตํด ๊ณต๊ฐ ๋ด์ ๋ชจ๋ ๋ฒกํฐ(๋๋ ํจ์)๋ฅผ ํํํ ์ ์๋ ๊ฒ์ฒ๋ผ, ๊ธฐ์ ๋คํญ์์ ์ฌ์ฉํ๋ฉด ์ฃผ์ด์ง ํจ์๋ ๋ฐ์ดํฐ๋ฅผ ๊ทธ ๊ธฐ์ ๋คํญ์๋ค์ ์ ํ ๊ฒฐํฉ์ผ๋ก ํํํ ์ ์์ต๋๋ค.
3์ฐจ ๋คํญ์ ๊ณต๊ฐ์์๋ ๋ค์๊ณผ ๊ฐ์ ๊ธฐ์ ๋ฅผ ์๊ฐํ ์ ์์ต๋๋ค:
1,x,x2,x31, x, x^2, x^31,x,x2,x3
3์ฐจ ์ดํ์ ๋ชจ๋ ๋คํญ์์ ์ด๋ค์ ์ ํ ๊ฒฐํฉ์ผ๋ก ํํ๋ ์ ์์ต๋๋ค:
f(x)=a0+a1x+a2x2+a3x3f(x) = a_0 + a_1 x + a_2 x^2 + a_3 x^3f(x)=a0โ+a1โx+a2โx2+a3โx3
๋คํญ์ ํฌ์์ ํต์ฌ
: ์ฃผ์ด์ง ํจ์๋ ๋ฐ์ดํฐ๋ฅผ ์ง๊ต ๋คํญ์ ๊ธฐ์ ์์ โํฌ์โํ์ฌ ๊ฐ์ฅ ์ ํฉํ ๊ทผ์ฌ๊ฐ์ ์ฐพ๋ ๊ฒ์ ๋๋ค. ์ง๊ต ๋คํญ์์ ์๋ก ๋ ๋ฆฝ์ด๊ธฐ ๋๋ฌธ์, ๋ฐ์ดํฐ๋ฅผ ์ฌ๋ฌ ๊ฐ์ ๋ ๋ฆฝ์ ์ธ ์ฑ๋ถ์ผ๋ก ๋ถํดํ์ฌ ๋ถ์ํ๋ ๊ฒ์ด ๊ฐ๋ฅํฉ๋๋ค.์ค์ฐจ ์ต์ํ
: ํฌ์ ์ฐ์ฐ์๋ ๋ณดํต ์ต์ ์ ๊ณฑ๋ฒ(least squares method)์ ์ฌ์ฉํ์ฌ ์ฃผ์ด์ง ๋ฐ์ดํฐ๋ฅผ ๋คํญ์ ๊ธฐ์ ๋ก ํํํ๋ ๊ณผ์ ์์ ์ค์ฐจ๋ฅผ ์ต์ํํฉ๋๋ค.
์ผ๋ฐ์ ์ผ๋ก โ๋ด์ !โ ํ๋ฉด ๊ณ ๋ฑํ๊ต์์ ๋ฐฐ์ด aโ b=โฃaโฃโฃbโฃcosโกฮธa \cdot b = | a | ย | b | \cos\thetaaโ b=โฃaโฃโฃbโฃcosฮธ ๊ฐ ์๊ฐ๋์ค๊ฒ๋๋ค. ํ์ง๋ง, ํจ์ ๊ฐ์ ๋ด์ ์ ์ด๋ฅผ ํ์ฅํ ๊ฐ๋ ์ผ๋ก ๋จ์ํ ๊ฐ๋๋ ํฌ๊ธฐ์ ๊ฐ์ ์ง๊ด์ ์ธ ๊ฐ๋ ์ผ๋ก ์ค๋ช ๋์ง ์์ต๋๋ค. |
๐ฌ (REVIEW) ๋ฒกํฐ ๋ด์ : aโ b=โฃaโฃโฃbโฃcosโกฮธa \cdot b = a ย b \cos\thetaaโ b=โฃaโฃโฃbโฃcosฮธ
- aaa์ bbb๋ ๋ฒกํฐ์ ๋๋ค.
โฃaโฃ a โฃaโฃ์ โฃbโฃ b โฃbโฃ๋ ๊ฐ ๋ฒกํฐ์ ๊ธธ์ด(ํฌ๊ธฐ, magnitude)์ ๋๋ค. - ฮธ\thetaฮธ๋ ๋ ๋ฒกํฐ ์ฌ์ด์ ๊ฐ๋์ ๋๋ค.
- ๋ ๋ฒกํฐ์ ๋ด์ ์ ๋ ๋ฒกํฐ ์ฌ์ด์ ์ ์ฌ๋๋ฅผ ์ธก์ ํ๋๋ฐ, ๋ฒกํฐ๊ฐ ํํํ ์๋ก ๋ด์ ์ ๊ฐ์ ํฌ๊ณ , ์ง๊ต(์ฆ, 90๋์ผ ๋)ํ ์๋ก ๋ด์ ์ 0์ด ๋ฉ๋๋ค.
โจ (NEW) ํจ์ ๋ด์ : โจf,gโฉ=โซabf(x)g(x)w(x)โdx\langle f, g \rangle = \int_a^b f(x) g(x) w(x) \, dxโจf,gโฉ=โซabโf(x)g(x)w(x)dx
- f(x)f(x)f(x)์ g(x)g(x)g(x)๋ ํจ์์ ๋๋ค.
- [a,b][a, b][a,b]๋ ํจ์๊ฐ ์ ์๋ ๊ตฌ๊ฐ์ ๋๋ค.
- w(x)w(x)w(x)๋ ๊ฐ์ค ํจ์๋ก, ๋ด์ ๊ณ์ฐ์์ ํน์ ๊ตฌ๊ฐ์ ๊ฐ์ค์น๋ฅผ ๋ถ์ฌํ๋ ์ญํ ์ ํฉ๋๋ค.
- ํจ์์ ๋ด์ ์ ๋ฒกํฐ ๋ด์ ์ฒ๋ผ ํจ์ ์ฌ์ด์ ์ ์ฌ๋๋ฅผ ์ธก์ ํ๋ ์ญํ ์ ํฉ๋๋ค.
-
ํฌ์ ์ฐ์ฐ์์ ์๋ ๋ฐฉ์:
-
๊ธฐ์ ๋คํญ์ ์ ํ: ํน์ ๊ตฌ๊ฐ์์ ์ง๊ตํ๋ ๋คํญ์์ ์ ํํฉ๋๋ค.
โณ ์๋ฅผ ๋ค์ด, ๊ตฌ๊ฐ
[-1, 1]
์์ ๋ฅด์ฅ๋๋ฅด ๋คํญ์, ๊ตฌ๊ฐ[0, โ)
์์ ๋ผ๊ฒ๋ฅด ๋คํญ์์ ์ฌ์ฉํ ์ ์์ต๋๋ค. (์๋ ์ฐธ๊ณ ) -
๊ณ์ ๊ฒฐ์ : ๋คํญ์ ํฌ์ ์ฐ์ฐ์๋ ์ฃผ์ด์ง ๋ฐ์ดํฐ์ ๊ฐ์ฅ ์ ํฉํ ๋คํญ์ ๊ธฐ์ ์ ๊ณ์๋ฅผ ์ฐพ์๋ ๋๋ค.
โณ ์ด๋ฅผ ํตํด ๋ฐ์ดํฐ๋ฅผ ํํํ๋ ํจ์๊ฐ ๊ฐ ๋คํญ์์ ์ ํ ๊ฒฐํฉ์ผ๋ก ๋ํ๋ฉ๋๋ค.
-
ํจ์ ๊ทผ์ฌ: ํฌ์๋ ๊ฒฐ๊ณผ๋ ์๋ ๋ฐ์ดํฐ์ ๋ํ โ์ต์ ์ ๊ทผ์ฌโ๋ฅผ ์ ๊ณตํฉ๋๋ค.
โณ ์ด๋ฅผ ํตํด ๋ฐ์ดํฐ๋ฅผ ๋จ์ํํ๊ฑฐ๋ ๋ถ์ํ ์ ์์ต๋๋ค.
-
2. ๋ฅด์ฅ๋๋ฅด ๋คํญ์(Legendre Polynomials)
- ๋ฅด์ฅ๋๋ฅด ๋คํญ์์ ๊ตฌ๊ฐ
[-1, 1]
์์ ๊ฐ์ค ํจ์ w(x)=1w(x) = 1w(x)=1์ ๋ํด ์ง๊ต์ฑ์ ๊ฐ๋ ๋คํญ์์ ๋๋ค. -
๋ฅด์ฅ๋๋ฅด ๋คํญ์์ ์ง๊ต์ฑ์ ๋ค์ ์์์ผ๋ก ํํ๋ฉ๋๋ค:
2n+12โซโ11Pn(x)Pm(x)โdx=ฮดnm\frac{2n+1}{2} \int_{-1}^{1} P_n(x) P_m(x) \, dx = \delta_{nm}22n+1โโซโ11โPnโ(x)Pmโ(x)dx=ฮดnmโ
-
์ด ์์์ ์๋ก ๋ค๋ฅธ ์ฐจ์์ ๋ฅด์ฅ๋๋ฅด ๋คํญ์๋ค์ด ๊ตฌ๊ฐ
[-1, 1]
์์ ์ง๊ตํจ์ ๋ํ๋ ๋๋ค. ์ฌ๊ธฐ์:- Pn(x)P_n(x)Pnโ(x)์ Pm(x)P_m(x)Pmโ(x)๋ ๊ฐ๊ฐ ์ฐจ์๊ฐ ๋ค๋ฅธ ๋ฅด์ฅ๋๋ฅด ๋คํญ์์ ๋๋ค.
- ฮดnm\delta_{nm}ฮดnmโ๋ ํฌ๋ก๋ค์ปค ๋ธํ๋ก, n=mn = mn=m์ผ ๋๋ 1, ๊ทธ๋ ์ง ์์ผ๋ฉด 0์ ์๋ฏธํฉ๋๋ค. ์ฆ, ๊ฐ์ ์ฐจ์์ผ ๊ฒฝ์ฐ ๋ด์ ์ด 1์ด ๋๊ณ , ๋ค๋ฅธ ์ฐจ์์ผ ๊ฒฝ์ฐ ๋ด์ ์ด 0์ด ๋ฉ๋๋ค.
-
๋ํ, ๋ฅด์ฅ๋๋ฅด ๋คํญ์์ ๋ค์๊ณผ ๊ฐ์ ๊ฒฝ๊ณ ์กฐ๊ฑด์ ๋ง์กฑํฉ๋๋ค:
Pn(1)=1,Pn(โ1)=(โ1)nP_n(1) = 1, \quad P_n(-1) = (-1)^nPnโ(1)=1,Pnโ(โ1)=(โ1)n
์ด๋ ๋ฅด์ฅ๋๋ฅด ๋คํญ์์ ๊ฐ์ด ๊ตฌ๊ฐ ๋์ ์์ ์ด๋ป๊ฒ ๋์ํ๋์ง๋ฅผ ๋ณด์ฌ์ค๋๋ค.
- (์ฐธ๊ณ ) HiPPO์์๋ Legendre ๋คํญ์์ด ์๊ฐ ์ถ์ ๋ฐ๋ฅธ ๋ฐ์ดํฐ๋ฅผ ์์ถํ๋ ๋ฐ ์ฌ์ฉ๋ฉ๋๋ค. ์ด๋ฅผ ํตํด ์ด์ ์์ ์์์ ๋ฐ์ดํฐ๋ฅผ ํจ๊ณผ์ ์ผ๋ก ๋ฉ๋ชจ๋ฆฌ์ ์ ์ฅํ๊ณ ๊ธฐ์ตํ ์ ์์ต๋๋ค.
3. ๋ผ๊ฒ๋ฅด ๋คํญ์(Laguerre Polynomials)
- ๋ผ๊ฒ๋ฅด ๋คํญ์์
[0, โ)
๊ตฌ๊ฐ์์ ๊ฐ์ค ํจ์ eโxe^{-x}eโx์ ๋ํด ์ง๊ต์ฑ์ ๊ฐ๋ ๋คํญ์์ ๋๋ค. -
๋ผ๊ฒ๋ฅด ๋คํญ์์ ์ง๊ต์ฑ์ ๋ค์ ์์์ผ๋ก ํํ๋ฉ๋๋ค:
โซ0โxฮฑeโxLn(ฮฑ)(x)Lm(ฮฑ)(x)โdx=(n+ฮฑ)!n!ฮดnm\int_0^{\infty} x^\alpha e^{-x} L_n^{(\alpha)}(x) L_m^{(\alpha)}(x) \, dx = \frac{(n + \alpha)!}{n!} \delta_{nm}โซ0โโxฮฑeโxLn(ฮฑ)โ(x)Lm(ฮฑ)โ(x)dx=n!(n+ฮฑ)!โฮดnmโ
-
์ด ์์์ ์๋ก ๋ค๋ฅธ ์ฐจ์์ ๋ผ๊ฒ๋ฅด ๋คํญ์๋ค์ด ๊ฐ์ค ํจ์ eโxe^{-x}eโx์ ๋ํด ์ง๊ตํจ์ ๋ํ๋ ๋๋ค. ์ฌ๊ธฐ์:
- Ln(ฮฑ)(x)L_n^{(\alpha)}(x)Ln(ฮฑ)โ(x)์ Lm(ฮฑ)(x)L_m^{(\alpha)}(x)Lm(ฮฑ)โ(x)๋ ๊ฐ๊ฐ ์ผ๋ฐํ๋ ๋ผ๊ฒ๋ฅด ๋คํญ์์ผ๋ก, ๋งค๊ฐ๋ณ์ ฮฑ\alphaฮฑ์ ๋ฐ๋ผ ๋ฌ๋ผ์ง๋๋ค.
- ฮดnm\delta_{nm}ฮดnmโ๋ ํฌ๋ก๋ค์ปค ๋ธํ๋ก, n=mn = mn=m์ผ ๋๋ 1, ๊ทธ๋ ์ง ์์ผ๋ฉด 0์ ๋๋ค.
-
ํ์ค ๋ผ๊ฒ๋ฅด ๋คํญ์์ ๋งค๊ฐ๋ณ์ ฮฑ=0\alpha = 0ฮฑ=0์ผ ๋์ ํน์ํ ๊ฒฝ์ฐ๋ก, ๋ค์๊ณผ ๊ฐ์ ์ง๊ต์ฑ์ ๊ฐ์ง๋๋ค:
โซ0โeโxLn(x)Lm(x)โdx=(n)!n!ฮดnm=ฮดnm\int_0^{\infty} e^{-x} L_n(x) L_m(x) \, dx = \frac{(n)!}{n!} \delta_{nm} = \delta_{nm}โซ0โโeโxLnโ(x)Lmโ(x)dx=n!(n)!โฮดnmโ=ฮดnmโ
- (์ฐธ๊ณ ) HiPPO์์๋ Laguerre ๋คํญ์์ ์ฌ์ฉํ์ฌ ๊ณผ๊ฑฐ ๋ฐ์ดํฐ๋ฅผ ํํํ๊ณ ๊ธฐ์ตํ๋ ๋ฐฉ์์ผ๋ก, ํนํ ๋ฉ๋ชจ๋ฆฌ ๊ด๋ฆฌ์ ์ฌ์ฉ๋ฉ๋๋ค. ์ด ๋คํญ์์ ์๊ฐ์ด ์ง๋จ์ ๋ฐ๋ผ ๋ฐ์ดํฐ๊ฐ ์ด๋ป๊ฒ ๋ณํ๋์ง ๋ชจ๋ธ๋งํ๋ ๋ฐ ์ ํฉํฉ๋๋ค.
HiPPO Preliminary
HiPPO์์๋ ๋คํญ์ ํฌ์ ์ฐ์ฐ์(Legendre ๋คํญ์๊ณผ Laguerre ๋คํญ์)์ ํตํด ์๊ฐ ์ถ์ ๋ฐ๋ฅธ ๋ฐ์ดํฐ๋ฅผ ์์ถํ๊ณ , ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋์ ์ค์ด๋ฉฐ, ์ค์ํ ์ ๋ณด๋ฅผ ์์ฝํ์ฌ ์ ์ฅํ๋ ๋ฐฉ์์ผ๋ก ์ฌ์ฉ๋ฉ๋๋ค.
-
์ด ๋คํญ์๋ค์ ์ ๋ ฅ ๋ฐ์ดํฐ๋ฅผ ๋คํญ์ ๊ณต๊ฐ์ ํฌ์ํ์ฌ, ์ด์ ์์ ์ ๋ฐ์ดํฐ๋ฅผ ํจ์จ์ ์ผ๋ก ๊ธฐ์ตํ๊ณ ์ ๋ฐ์ดํธํ๋ ๋ฐ ๋์์ ์ค๋๋ค.
- Legendre ๋คํญ์์ ๊ตฌ๊ฐ
[-1, 1]
๋ด์์ ์๊ณ์ด ๋ฐ์ดํฐ๋ฅผ ๋ชจ๋ธ๋งํ๊ณ , ์ง๊ต์ฑ์ ํตํด ๋ฉ๋ชจ๋ฆฌ์ ํจ์จ์ ์ธ ๊ด๋ฆฌ๊ฐ ๊ฐ๋ฅํฉ๋๋ค. - Laguerre ๋คํญ์์ ์ฃผ๋ก ์ ํธ ์ฒ๋ฆฌ์์ ๊ธด ์๊ฐ์ ๊ฑธ์ณ ๋ฐ์ดํฐ๋ฅผ ์ฒ๋ฆฌํ ๋ ์ฌ์ฉ๋๋ฉฐ, HiPPO์์๋ ๋ฐ์ดํฐ๋ฅผ ์์ฝํ๊ณ ์ ์ฅํ๋ ๋ฐ ์ฌ์ฉํฉ๋๋ค.
- Legendre ๋คํญ์์ ๊ตฌ๊ฐ
Introduction (์๋ก )
-
Introductin์์๋ ๋จผ์ Sequential ๋ฐ์ดํฐ์ ์ฒ๋ฆฌ๋ฅผ ์ํ ํ์กดํ๋ RNN ๋ชจ๋ธ์ ์ ์ฝ ์ฌํญ๋ค์ ์๋์ ๊ฐ์ด ์์ ํฉ๋๋ค:
- Limited Memory Horizon: RNN์ ๊ธด ์ํ์ค๋ฅผ ์ฒ๋ฆฌํ ๋ ์ด์ ์ ๋ณด์ ๊ธฐ์ต์ด ์ฝํด์ง๋ ๊ฒฝํฅ์ด ์์ต๋๋ค. ์ฆ, ๋ชจ๋ธ์ด ์ด์ ๋ฐ์ดํฐ์์ ์ค์ํ ์ ๋ณด๋ฅผ ์์ด๋ฒ๋ฆฌ๋ ๋ฌธ์ ์ ์ง๋ฉดํ๊ฒ ๋ฉ๋๋ค.
- Vanishing Gradients: RNN์ ์ญ์ ํ ๊ณผ์ ์์ ๊ธฐ์ธ๊ธฐ๊ฐ ๋งค์ฐ ์์์ ธ์ ๊ฐ์ค์น ์ ๋ฐ์ดํธ๊ฐ ๊ฑฐ์ ์ด๋ฃจ์ด์ง์ง ์๋ ๋ฌธ์ ์ ์ง๋ฉดํฉ๋๋ค. ์ด๋ก ์ธํด ๋ชจ๋ธ์ด ์ฅ๊ธฐ ์์กด์ฑ์ ํ์ตํ๊ธฐ๊ฐ ๋งค์ฐ ์ด๋ ค์์ง๋๋ค.
- ์ํ์ค ๊ธธ์ด ๋ฐ ์๊ฐ ์ฒ๋์ ๋ํ ์ ํ ์ ๋ณด ์๊ตฌ: ๊ธฐ์กด RNN ๋ฐ ๊ทธ ๋ณํ๋ค์ ํน์ ํ ์ํ์ค ๊ธธ์ด๋ ์๊ฐ ์ฒ๋์ ๋ํ ์ ํ ์ ๋ณด(prior)๋ฅผ ํ์๋ก ํฉ๋๋ค. ๊ทธ๋ฌ๋ ์ด๋ฌํ ์ ํ ์ ๋ณด๋ ๋ถํ์คํ ํ๊ฒฝ์ด๋ ๋ฐ์ดํฐ ๋ถํฌ ๋ณํ์ ๋ํด ์ผ๋ฐํํ๊ธฐ ์ด๋ ต์ต๋๋ค.
-
์ด๋ก ์ ๋ณด์ฅ ๊ฒฐ์ฌ(Theoretical Guarantees):
- ๊ธฐ์กด ๋ฐฉ๋ฒ๋ค์ ์ฅ๊ธฐ ์์กด์ฑ์ ์ผ๋ง๋ ์ ์บก์ฒํ ์ ์๋์ง์ ๋ํ ์ด๋ก ์ ๋ณด์ฅ์ด ๋ถ์กฑํฉ๋๋ค. ํนํ, ๊ธฐ์ธ๊ธฐ ๊ฒฝ๊ณ ๋ฑ๊ณผ ๊ฐ์ ์ฑ๋ฅ์ ๋ํ ์ด๋ก ์ ๊ทผ๊ฑฐ๊ฐ ๊ฒฐ์ฌ๋์ด ์์ด, ํจ๊ณผ์ ์ธ ์ฑ๋ฅ์ ๊ธฐ๋ํ๊ธฐ ์ด๋ ต์ต๋๋ค.
- ์ฅ๊ธฐ ๋ฐ ๋ณต์กํ ์๊ฐ ์์กด์ฑ ๋ชจ๋ธ๋ง์ ์ด๋ ค์: RNN์ ๋ณต์กํ ์๊ฐ ์์กด์ฑ์ ๋ชจ๋ธ๋งํ๋ ๋ฐ ํ๊ณ๊ฐ ์์ผ๋ฉฐ, ์ด๋ก ์ธํด ์๋ฃ ๋ฐ์ดํฐ์ ๊ฐ์ ๋ค์ํ ์ํ๋ง ์ฃผ๊ธฐ๋ฅผ ๊ฐ์ง ๋ฐ์ดํฐ์์ ํจ๊ณผ์ ์ผ๋ก ์๋ํ์ง ๋ชปํ ์ ์์ต๋๋ค.
- ๋ ผ๋ฌธ์์๋ ์ด๋ฌํ ํ๊ณ์ ์ ํด๊ฒฐํ๊ธฐ ์ํด HiPPO(High-order Polynomial Projection Operators)๋ผ๋ ์๋ก์ด ํ๋ ์์ํฌ๋ฅผ ์ ์ํฉ๋๋ค.
HiPPO
๋ ์ฐ์ ์ ํธ ๋ฐ ์ด์ฐ ์๊ณ์ด ๋ฐ์ดํฐ๋ฅผ ์ต์ ์ ๋ฐฉ๋ฒ์ผ๋ก ์์ถํ๊ณ ๊ณผ๊ฑฐ ๋ฐ์ดํฐ๋ฅผ ๋ชจ๋ธ๋งํ์ฌ ์ฅ๊ธฐ ์์กด์ฑ์ ์ ์ฒ๋ฆฌํ ์ ์๋๋ก ๋์์ค๋๋ค.
The HiPPO Framework: High-order Polynomial Projection Operators (HiPPO ํ๋ ์์ํฌ: ๊ณ ์ฐจ ๋คํญ์ ํฌ์ ์ฐ์ฐ์)
HiPPO ํ๋ ์์ํฌ
์ ๋ชฉํ๋ ์๊ฐ์ ๋ฐ๋ผ ๋ณํํ๋ ๋ฐ์ดํฐ๋ฅผ ์์ถ๋ ํํ๋ก ์ ์งํ๋ฉฐ, ๊ฐ ์๊ฐ t์์ ๊ณผ๊ฑฐ ๋ฐ์ดํฐ๋ฅผ ํจ์จ์ ์ผ๋ก ํํํ๋ ๊ฒ์
๋๋ค.
- ์ด ํ๋ ์์ํฌ๋ ์จ๋ผ์ธ ํจ์ ๊ทผ์ฌ๋ฅผ ํตํด ๋ฉ๋ชจ๋ฆฌ ๋ฉ์ปค๋์ฆ์ ๊ณ ์ํ๊ณ , ๊ณ ์ฐจ ๋คํญ์ ํฌ์ ์ฐ์ฐ์๋ฅผ ์ฌ์ฉํ๋ ๋ฐฉ์์ผ๋ก ์์ฐจ ๋ฐ์ดํฐ๋ฅผ ์ฒ๋ฆฌํฉ๋๋ค
๋ฌธ์ ์ ์
-
์ ๋ ฅ ํจ์ f(t)f(t)f(t)์ ๋์ ์ด๋ ฅ์ ์จ๋ผ์ธ์ผ๋ก ์์ถํ์ฌ ํํํ๋ ๋ฐฉ๋ฒ์ ๋ ผ์ํฉ๋๋ค.
- Online Approximation (์จ๋ผ์ธ ๊ทผ์ฌ):
- ๊ฐ ์๊ฐ ttt๋ง๋ค fโคtf_{\leq t}fโคtโ๋ฅผ ๊ทผ์ฌํ๊ธฐ ์ํด ์ธก๋ ฮผ(t)\mu(t)ฮผ(t)๊ฐ ๋ณํํฉ๋๋ค.
- ์ด ์ธก๋๋ ๋ค์ํ ๊ณผ๊ฑฐ ์ ๋ ฅ์ ์ค์๋๋ฅผ ์กฐ์ ํ๋ฉฐ, ์ต์ ์ ๋คํญ์ ๊ทผ์ฌ๋ฅผ ์ฐพ์๋ด๋ ๊ณผ์ ์์ ํ์ฉ๋ฉ๋๋ค.
- Online Approximation (์จ๋ผ์ธ ๊ทผ์ฌ):
-
ํจ์์ ์ญ์ฌ fโคtf_{\leq t}fโคtโ๋ฅผ ์ ์งํ๊ธฐ ์ํด์๋ ๋ ๊ฐ์ง ํ์ ์์๊ฐ ๋์ถ๋ฉ๋๋ค:
๊ทผ์ฌ ๋ฐฉ๋ฒ
๊ณผ์๋ธ์คํ์ด์ค
.-
Function Approximation with respect to a Measure (์ธก๋์ ๋ํ ํจ์ ๊ทผ์ฌ):
- ๊ทผ์ฌ ํ์ง์ ์ ๋ํํ๋ ๋ฐฉ๋ฒ์ ํ๋ฅ ์ธก๋ ฮผ\muฮผ๋ฅผ ํตํด ๋ด์ ์ ์ ์ํ๋ ๋ฐฉ์์ ๋๋ค.
- ๋ด์ ์ โจf,gโฉฮผ=โซ0โf(x)g(x)dฮผ(x)\langle f, g \rangle_\mu = \int_0^\infty f(x) g(x) d\mu(x)โจf,gโฉฮผโ=โซ0โโf(x)g(x)dฮผ(x)๋ก ํํ๋๋ฉฐ, ํจ์ fff์ ggg ์ฌ์ด์ ๊ฑฐ๋ฆฌ ๋๋ ์ค์ฐจ๋ฅผ ์ธก์ ํ ์ ์๋ ๊ธฐ์ค์ ์ ๊ณตํฉ๋๋ค.
-
Polynomial Basis Expansion (๋คํญ์ ๊ธฐ์ด ํ์ฅ):
- ๋คํญ์์ ๊ธฐ๋ฐ์ผ๋ก ํ ๋ถ๋ถ ๊ณต๊ฐ GGG๋ฅผ ์ฌ์ฉํ์ฌ ํจ์๋ฅผ ๊ทผ์ฌํฉ๋๋ค.
- ์ด ๋ถ๋ถ ๊ณต๊ฐ์ ์ฐจ์ NNN ๋ฏธ๋ง์ ๋คํญ์์ผ๋ก ๊ตฌ์ฑ๋๋ฉฐ, ์ด๋ ์ ๋ ฅ ํจ์์ ๊ทผ์ฌ๋ฅผ ์ํด ์ฌ์ฉํ ์ ์๋ ๊ธฐ์ค์ ์ ๊ณตํฉ๋๋ค.
- ์ด๋ฌํ ๊ธฐ์ด ํ์ฅ์ ๋ค์ํ ํจ์๋ค์ ํจ๊ณผ์ ์ผ๋ก ํํํ ์ ์๋ ๊ธฐ๋ฐ์ ์ ๊ณตํฉ๋๋ค.
-
HiPPO ํต์ฌ ์์ด๋์ด
1. Choose suitable basis (์ ์ ํ ๊ธฐ์ ์ ํ)
-
์๋ฏธ:
- ํน์ ํจ์ f(t)f(t)f(t)๋ฅผ ๊ทผ์ฌํ๊ธฐ ์ํด, ๊ทธ ํจ์์ ๊ณต๊ฐ์์ ์ ์ ํ ๋คํญ์ ๊ธฐ์ ๋ฅผ ์ ํํ๋ ๋จ๊ณ์ ๋๋ค.
- ์ด ๊ธฐ์ ๋ ํจ์์ ์ฑ์ง๊ณผ ์๊ฐ ๊ฐ๋ณ ์ธก์ ฮผ(t)\mu(t)ฮผ(t)์ ๋ฐ๋ผ ๋ฌ๋ผ์ง๋ฉฐ, ์ผ๋ฐ์ ์ผ๋ก๋ orthogonal ๋คํญ์์ด ์ฌ์ฉ๋ฉ๋๋ค.
-
์ธ๋ถ ์ฌํญ:
- ์ ํ๋ ๊ธฐ์ {gn}n<N{g_n}_{n < N}{gnโ}n<Nโ๋ NNN์ฐจ์์ ๋คํญ์ ๊ณต๊ฐ์ ๊ตฌ์ฑํ๋ฉฐ, ์ด ๊ธฐ์ ์ ๋ํด ํจ์ fโคtf_{\leq t}fโคtโ๋ฅผ projection ํฉ๋๋ค.
- ์ด๋ ์ฃผ์ด์ง ํจ์์ ๊ธฐ์ ์ ๊ด๊ณ๋ฅผ ์ ์ํ๊ธฐ ์ํ ๊ฒ์ ๋๋ค. ์ต์ ์ ๊ณ์ c(t)c(t)c(t)๋ ๋ค์๊ณผ ๊ฐ์ ๋ด์ ์ ํตํด ๊ณ์ฐ๋ฉ๋๋ค:c(t)n:=โจfโคt,gnโฉฮผ(t)c(t)_n := \langle f_{\leq t}, g_n \rangle_{\mu(t)}c(t)nโ:=โจfโคtโ,gnโโฉฮผ(t)โ
- ์ด ๋จ๊ณ์ ๋ชฉ์ ์ ์ ๋ ฅ ์ ํธ์ ์ค์ํ ํน์ฑ๋ค์ ๋ณด์กดํ๋ฉด์ ๋ณต์กํ ํจ์๋ฅผ ๊ทธ ๊ธฐ์ ์ ๋ง์ถฐ ๊ฐ๋จํ ๋คํญ์์ผ๋ก ํํํ๋ ๊ฒ์ ๋๋ค.
2. Differentiate the projection (ํ๋ก์ ์
๋ฏธ๋ถ)
-
์๋ฏธ:
- ์ ํํ ๊ธฐ์ ์ ๋ํด ์๊ฐ ttt์ ๋ฐ๋ผ projection์ ๋ฏธ๋ถํ๋ ๋จ๊ณ์ ๋๋ค.
- ์ด๋ ์ฃผ์ด์ง ํจ์์ ์๊ฐ์ ๋ณํ๋ฅผ ํฌ์ฐฉํ๊ณ , projection ๊ณ์์ ๋์ญํ์ ์ดํดํ๋ ๋ฐ ํ์ํฉ๋๋ค.
-
์ธ๋ถ ์ฌํญ:
- ๋ฏธ๋ถ์ ํตํด ์ป์ ๊ด๊ณ๋ projection์ ๋ณํ๋์ ์ค๋ช ํ๋ฉฐ, ์ผ๋ฐ์ ์ผ๋ก ์ด๋ฌํ ๋ฏธ๋ถ์ ์๊ธฐ ์ ์ฌ์ฑ์ ๊ฐ์ง๋ ๋ฐฉ์ ์ ํํ๋ก ํํ๋ฉ๋๋ค.ddtcn(t)=functionย ofย f(t)ย andย (ck(t))kโ[N]\frac{d}{dt}c_n(t) = \text{function of } f(t) \text{ and } (c_k(t))_{k \in [N]}dtdโcnโ(t)=functionย ofย f(t)ย andย (ckโ(t))kโ[N]โ
- ์ด ๋จ๊ณ๋ ํ๋ก์ ์ ๊ณ์๊ฐ ์๊ฐ์ ๋ฐ๋ผ ์ด๋ป๊ฒ ๋ณํํ๋์ง๋ฅผ ์ค๋ช ํ๋ ODE(์๋ฏธ๋ถ๋ฐฉ์ ์)๋ฅผ ์๋ฆฝํ๋ ๋ฐ ํ์์ ์ ๋๋ค. ์ด๋ฅผ ํตํด, c(t)c(t)c(t)์ ๋์ญํ์ด ์ ๋์ ์ผ๋ก ๋ถ์ ๊ฐ๋ฅํ๊ฒ ๋ฉ๋๋ค.
HiPPO ํ๋ ์์ํฌ
HiPPO๋ ํจ์ ๊ทผ์ฌ๋ฅผ ์ํ ์ผ์ข ์ ๋์ ์์คํ ๋ฐฉ๋ฒ๋ก ์ผ๋ก, ์ฃผ์ด์ง ํจ์ f(t)f(t)f(t)๋ฅผ ์๊ฐ์ ๋ฐ๋ผ ์์ถํ๊ณ ์ ์ฅํ๋ ๊ณผ์ ์ ๋ค๋ฃน๋๋ค. ์ด ๊ณผ์ ์ ์ธก๋์ ๊ธฐ๋ฐํ ์ง๊ต ๊ธฐ์ ๋ฅผ ์ฌ์ฉํ์ฌ ํจ์๋ฅผ ๋คํญ์ ๊ณต๊ฐ์ผ๋ก ํฌ์(projection)ํ๊ณ , ์๊ฐ์ ๋ฐ๋ผ ๋ณํํ๋ ํจ์์ ์ ๋ณด๋ฅผ ํจ์จ์ ์ผ๋ก ํํํ ์ ์๋๋ก ์ค๊ณ๋์์ต๋๋ค.
์๋์ ๊ทธ๋ฆผ์ผ๋ก ์ ๋ฆฌ๋ฅผ ํ๋๊น ์ดํด๊ฐ ๋๋๊ตฐ์! ๐ฅ (์ค๋๋ง์ ์์๋ณด๋๊น ๋จธ๋ฆฌ๊ฐ๐ฑ)
๊ธ๋ก ๋ค์ ํ๋ฒ ์ข ์ ๋ฆฌํด๋ณผ๊น์?
โ Projection ์ฐ์ฐ : ํจ์ f(t)f(t)f(t)๋ฅผ ๋คํญ์ ๊ณต๊ฐ์ผ๋ก ํฌ์
- ํฌ์ ์ฐ์ฐ์ proj\text{proj}proj๋ ํจ์ f(t)f(t)f(t)๋ฅผ ์ผ์ ์๊ฐ ttt๊น์ง์ ์ ๋ณด๋ก ์ ํํ์ฌ ๋คํญ์ ๊ณต๊ฐ GGG์ ํฌ์ํฉ๋๋ค. ์ฆ, ์ฃผ์ด์ง f(t)f(t)f(t)์ ์ ๋ณด๋ฅผ ๋คํญ์ g(t)g(t)g(t)๋ก ๊ทผ์ฌํ์ฌ ๋ํ๋ ๋๋ค.
- ์ด ๊ณผ์ ์์ ์ค์ํ ๊ฒ์, ํฌ์์ ํตํด ์ป์ ๋คํญ์์ด ์๊ฐ ttt ์ด์ ์ ํจ์ ์ ๋ณด fโคt(x)f_{\leq t}(x)fโคtโ(x)๋ฅผ ์ต๋ํ ์ ํํ๊ฒ ํํํ๋ ๊ฒ์ ๋๋ค. ํฌ์ ์ฐ์ฐ์ ๋ชฉํ๋, ์ฃผ์ด์ง ์ธก๋ ฮผ(t)\mu(t)ฮผ(t) ํ์์ ์ค์ฐจ๊ฐ ์ต์ํ๋๋๋ก ๋คํญ์ g(t)g(t)g(t)๋ก ํจ์๋ฅผ ๊ทผ์ฌํ๋ ๊ฒ์ ๋๋ค.
โก Coefficients ๊ณ์ฐ: ๊ณ์ c(t)c(t)c(t) ๊ตฌํ๊ธฐ
- ํฌ์๋ ๋คํญ์ g(t)g(t)g(t)๋ ๋คํญ์ ๊ธฐ์ ํจ์๋ค์ ์ ํ ๊ฒฐํฉ์ผ๋ก ํํ๋๋ฉฐ, ๊ฐ ๊ธฐ์ ํจ์์ ๊ณฑํด์ง๋ ๊ณ์ c(t)c(t)c(t)๋ ์๊ฐ์ ๋ฐ๋ผ ๋ณํํฉ๋๋ค.
- HiPPO๋ ์ด ๊ณ์ c(t)c(t)c(t)๋ฅผ ํจ์จ์ ์ผ๋ก ๊ณ์ฐํ์ฌ, ํจ์ f(t)f(t)f(t)์ ๊ณผ๊ฑฐ ๊ธฐ๋ก์ ์์ถํ๋ ๋ฐฉ์์ผ๋ก ํํํฉ๋๋ค. c(t)c(t)c(t)๋ RN\mathbb{R}^NRN์ ๋ฒกํฐ๋ก, ์ด๋ ์ ํ๋ NNN๊ฐ์ ๊ธฐ์ ํจ์์ ๋ํ ๊ณ์๋ฅผ ์๋ฏธํฉ๋๋ค.
โข ๋ฏธ๋ถ ๋ฐฉ์ ์ (ODE)์ผ๋ก ๊ณ์์ ์งํ ๋ชจ๋ธ๋ง
- ํฌ์๋ ํจ์์ ๊ณ์ c(t)c(t)c(t)๋ ์๊ฐ์ ๋ฐ๋ผ ์งํํ๋ฉฐ, ์ด ๋ณํ๋ ์๋ฏธ๋ถ ๋ฐฉ์ ์(ODE)์ผ๋ก ํํ๋ฉ๋๋ค:ddtc(t)=A(t)c(t)+B(t)f(t)\frac{d}{dt}c(t) = A(t)c(t) + B(t)f(t)dtdโc(t)=A(t)c(t)+B(t)f(t)
- ์ด ๋ฐฉ์ ์์ ๊ณ์ c(t)c(t)c(t)๊ฐ ์๊ฐ ttt์ ๋ฐ๋ผ ์ด๋ป๊ฒ ๋ณํํ๋์ง๋ฅผ ์ค๋ช ํฉ๋๋ค. A(t)A(t)A(t)์ B(t)B(t)B(t)๋ ๊ฐ๊ฐ ๊ณ์์ ํจ์์ ๋ณํ์จ์ ๋ํ๋ด๋ ํ๋ ฌ์ ๋๋ค.
- ์ค์ํ ์ ์, HiPPO๊ฐ ์ด ODE๋ฅผ ํตํด ํจ์๋ฅผ ์๊ฐ์ ๋ฐ๋ผ ์จ๋ผ์ธ ๋ฐฉ์์ผ๋ก ์์ถํ๋ค๋ ๊ฒ์ ๋๋ค. ์ฆ, ์ค์๊ฐ์ผ๋ก ํจ์์ ์ ๋ณด๋ฅผ ์ ์ฅํ๊ณ ์งํ์ํต๋๋ค.
๐ก High Order Projection: Measure Families and HiPPO ODEs
- HiPPO ํ๋ ์์ํฌ์์
๊ณ ์ฐจ ๋คํญ์ ํฌ์(High Order Projection)
์ ํตํด ๊ณผ๊ฑฐ ๋ฐ์ดํฐ๋ฅผ ๋คํญ์ ํํ๋ก ํจ์จ์ ์ผ๋ก ์์ถํ๊ณ ์ด๋ฅผ ์ค์๊ฐ์ผ๋ก ์ ๋ฐ์ดํธํ๋ ๊ฒ์ ๋๋ค.
- ํนํ, HiPPO์์๋ LagT(Translated Laguerre Measure)์ LegT(Translated Legendre Measure) ๋ ๊ฐ์ง ์ธก์ (Measure) ๋ฐฉ๋ฒ์ ์ ์ํ๊ณ , ์ด๋ฅผ ๋ฐํ์ผ๋ก ๋ฏธ๋ถ ๋ฐฉ์ ์(ODE)์ ์ฌ์ฉํ์ฌ ๋ฉ๋ชจ๋ฆฌ๋ฅผ ์ ๋ฐ์ดํธํ๋ ๋ฐฉ์์ ์ ์ํฉ๋๋ค.
๐ฌ Translated Laguerre Measure (LagT)
LagT
๋ ์ต๊ทผ์ ๋ฐ์ดํฐ๊ฐ ๋ ์ค์ํ๋ค๋ ๊ฐ์ ์ ๋ฐ์ํฉ๋๋ค.
- ๊ณผ๊ฑฐ๋ก ๊ฐ์๋ก ๋ฐ์ดํฐ์ ์ค์๋๊ฐ ์ง์์ ์ผ๋ก ๊ฐ์ํฉ๋๋ค.
Measure ์ ์:
ฮผ(t)(x)=eโ(tโx)(ifย xโคt)\mu(t)(x) = e^{-(t-x)} \quad \text{(if } x \leq t \text{)}ฮผ(t)(x)=eโ(tโx)(ifย xโคt)
- ์ด ์์์ xโคtx \leq txโคt์ผ ๋๋ง ์ ์๋๋ฉฐ, ๊ณผ๊ฑฐ๋ก ๊ฐ์๋ก eโ(tโx)e^{-(t-x)}eโ(tโx)๋ผ๋ ํจ์๊ฐ ์ง์์ ์ผ๋ก ๊ฐ์ํจ์ ๋ํ๋ ๋๋ค.
- ์ด๋ ์ต๊ทผ์ ๋ฐ์ดํฐ๊ฐ ๊ณผ๊ฑฐ์ ๋ฐ์ดํฐ๋ณด๋ค ์ค์ํ๋ค๋ ์๋ฏธ์ ๋๋ค.
ODE ํํ:
ddtc(t)=โAc(t)+Bf(t)\frac{d}{dt} c(t) = -Ac(t) + Bf(t)dtdโc(t)=โAc(t)+Bf(t)
- ์ฌ๊ธฐ์ c(t)c(t)c(t)๋ ํฌ์๋ ๋คํญ์์ ๊ณ์ ๋ฒกํฐ๋ฅผ ์๋ฏธํฉ๋๋ค.
- ์ด ์์์ ์ฃผ์ด์ง ๋ฐ์ดํฐ f(t)f(t)f(t)๋ LagT๊ฐ ์ต๊ทผ ๋ฐ์ดํฐ๋ฅผ ์ค์ํ๊ฒ ๋ฐ์ํ๋๋ก ์ค๊ณ๋ ๋ฐฉ์์ผ๋ก ๋คํญ์ ๊ธฐ์ ์ ํฌ์๋ฉ๋๋ค.
ํ๋ ฌ A์ B ์ ์:
Ank={1ifย nโฅk0ifย n<k,Bn=1A_{nk} = \begin{cases} 1 & \text{if } n \geq k \ 0 & \text{if } n < k \end{cases} \quad, \quad B_n = 1Ankโ={10โifย nโฅkifย n<kโ,Bnโ=1
- ์ด๋ ์ง์์ ๊ฐ์๋ฅผ ๋ฐ์ํ ๋ฉ์ปค๋์ฆ์ผ๋ก, ์ต๊ทผ์ ๋ฐ์ดํฐ๊ฐ ๋ ์ค์ํ ๋ฐฉ์์ผ๋ก ๋คํญ์์ ๊ณ์๋ค์ ์ ๋ฐ์ดํธํฉ๋๋ค.
๐ฌ Translated Legendre Measure (LegT)
LegT
๋ ๊ณ ์ ๋ ์๊ฐ ๋ฒ์ ๋ด์ ๋ฐ์ดํฐ๋ง ์ค์ํ๋ค๊ณ ๊ฐ์ ํฉ๋๋ค.
- ์ฆ, ์ผ์ ๊ธธ์ด์ ์ฌ๋ผ์ด๋ฉ ์๋์ฐ(Sliding Window) ๋ฐฉ์์ผ๋ก ๋ฐ์ดํฐ๋ฅผ ์ฒ๋ฆฌํฉ๋๋ค.
Measure ์ ์:
ฮผ(t)(x)=1ฮธItโฮธ,t\mu(t)(x) = \frac{1}{\theta} I_{[t-\theta, t]}(x)ฮผ(t)(x)=ฮธ1โI[tโฮธ,t]โ(x)
- ์ฌ๊ธฐ์ I[tโฮธ,t]I_{[t-\theta, t]}I[tโฮธ,t]โ๋ ์ฌ๋ผ์ด๋ฉ ์๋์ฐ๋ฅผ ๋ํ๋ด๋ฉฐ, ๊ธธ์ด ฮธ\thetaฮธ๋งํผ์ ์๊ฐ ์ฐฝ์์ ๋ฐ์ดํฐ์ ๊ท ๋ฑํ ๊ฐ์ค์น๋ฅผ ๋ถ์ฌํฉ๋๋ค.
- ์ฆ, ์๊ฐ ์ฐฝ [tโฮธ,t][t-\theta, t][tโฮธ,t] ์ฌ์ด์ ๋ฐ์ดํฐ๋ฅผ ์ค์ํ๊ฒ ๋ค๋ฃน๋๋ค.
ODE ํํ:
ddtc(t)=โ1ฮธAc(t)+1ฮธBf(t)\frac{d}{dt} c(t) = -\frac{1}{\theta} Ac(t) + \frac{1}{\theta} Bf(t)dtdโc(t)=โฮธ1โAc(t)+ฮธ1โBf(t)
- ์ฌ๊ธฐ์๋ ์ญ์ c(t)c(t)c(t)๋ ๋คํญ์์ ๊ณ์ ๋ฒกํฐ๋ฅผ ๋ํ๋ ๋๋ค.
ํ๋ ฌ A์ B ์ ์:
Ank={(โ1)nโk(2n+1)ifย nโฅk2n+1ifย n<k,Bn=(โ1)n(2n+1)A_{nk} = \begin{cases} (-1)^{n-k}(2n + 1) & \text{if } n \geq k \ 2n + 1 & \text{if } n < k \end{cases} \quad, \quad B_n = (-1)^n (2n + 1)Ankโ={(โ1)nโk(2n+1)2n+1โifย nโฅkifย n<kโ,Bnโ=(โ1)n(2n+1)
- ์ด๋ ์ผ์ ํ ์๊ฐ ์ฐฝ ๋ด์์ ๋ฐ์ดํฐ๋ฅผ ํฌ์ํ์ฌ ์ ์งํ๋ฉฐ, ์ผ์ ์๊ฐ ๋ฒ์ ๋ด์ ๋ฐ์ดํฐ์๋ง ์ค์์ฑ์ ๋ถ์ฌํฉ๋๋ค.
โฃ Discrete-time HiPPO Recurrence (์ด์ฐ ์๊ฐ ์ฌ๊ท ๊ด๊ณ)
- HiPPO ํ๋ ์์ํฌ๋ฅผ ์ฐ์ ์๊ฐ(Continuous Time)์์ ์ด์ฐ ์๊ฐ(Discrete Time)์ผ๋ก ๋ณํํ๋ ๋ฐฉ๋ฒ์ ๋ํด ์ค๋ช ํฉ๋๋ค.
-
์ด๋ ์ค์ง์ ์ธ ์๊ณ์ด ๋ฐ์ดํฐ๋ ์ด์ฐ์ ์ธ ์ํ์ค ๋ฐ์ดํฐ์ ์ ์ฉํ๊ธฐ ์ํด ODE๋ฅผ ์ด์ฐํํ๋ ๊ณผ์ ์ ๋๋ค.
- ODE๋ฅผ ์ด์ฐํํ์ฌ ์ค์ง์ ์ผ๋ก ๊ณ์ฐ ๊ฐ๋ฅํ ํํ๋ก ๋ง๋ค๋ฉด, ์๋์ ๊ฐ์ ์ฌ๊ท ๊ด๊ณ๋ฅผ ์ป๊ฒ ๋ฉ๋๋ค:ck+1=Akck+Bkfkc_{k+1} = A_k c_k + B_k f_kck+1โ=Akโckโ+Bkโfkโ
- ์ด ์์ ์ด์ ์๊ฐ์ ๊ณ์ ckc_kckโ์ ์๋ก์ด ํจ์ ๊ฐ fkf_kfkโ์ ์ฌ์ฉํ์ฌ ๋ค์ ์๊ฐ k+1k+1k+1์์์ ๊ณ์ ck+1c_{k+1}ck+1โ๋ฅผ ๊ณ์ฐํฉ๋๋ค. ์ฆ, ์ด ์์ ํจ์์ ์ ๋ณด๋ฅผ ์ด์ฐ์ ์ธ ์๊ฐ ๋จ๊ณ์์ ์ฌ๊ท์ ์ผ๋ก ์ ๋ฐ์ดํธํ๋ ๋ฐฉ์์ผ๋ก ๊ตฌํ๋ฉ๋๋ค.
- ์ด ๊ณผ์ ์ ํตํด, HiPPO๋ ํจ์์ ๊ณผ๊ฑฐ ๊ธฐ๋ก์ ์ ํ ๊ฒฐํฉ์ ํํ๋ก ์์ถํ์ฌ ์ ์ฅํ๊ณ , ์ค์๊ฐ์ผ๋ก ์ ๋ฐ์ดํธํ๋ ํจ๊ณผ์ ์ธ ๋ฐฉ๋ฒ์ ์ ๊ณตํฉ๋๋ค.
๐ HiPPO-LegS: Scaled Measures for Timescale Robustness (HiPPO-LegS: ์๊ณ์ด ๊ฒฌ๊ณ ์ฑ์ ์ํ ํ์ฅ๋ ์ธก์ ๋ฐฉ๋ฒ)
HiPPO-LegS
๋ ์๊ฐ ์ฒ๋์ ๊ฐ๊ฑดํ ๋ฉ๋ชจ๋ฆฌ ๋ฉ์ปค๋์ฆ์ ์ ๊ณตํ๋ ์๋ก์ด ์ ๊ทผ ๋ฐฉ์์ ๋๋ค. ์ด ๋ฉ์ปค๋์ฆ์ ๊ณผ๊ฑฐ ๋ชจ๋ ์๊ฐ์ ๋ํด ๊ท ๋ฑํ ๊ฐ์ค์น๋ฅผ ๋ถ์ฌํ์ฌ ๋ฉ๋ชจ๋ฆฌ๋ฅผ ๊ตฌ์ฑํฉ๋๋ค.
- ์ ์ฒด ์ด๋ ฅ ๊ณ ๋ ค: LegS๋ ์์ ํ ๊ณผ๊ฑฐ ์ด๋ ฅ์ ๊ณ ๋ คํ์ฌ ๋ฉ๋ชจ๋ฆฌ๋ฅผ ๊ตฌ์ฑํ๋ฉฐ, ์ด๋ ํน์ ์ฌ๋ผ์ด๋ฉ ์๋์ฐ๋ฅผ ์ฌ์ฉํ๋ ๊ธฐ๋ฒ๊ณผ ๋ฌ๋ฆฌ ๋ชจ๋ ๊ณผ๊ฑฐ ๋ฐ์ดํฐ๋ฅผ ๊ท ๋ฑํ๊ฒ ํ๊ฐํฉ๋๋ค.
- ๋ฐ๋ฉด, LagT์ LegT๋ ํน์ ์๊ฐ ๋ฒ์ ๋ด์์ ๋ฐ์ดํฐ๋ฅผ ์ฒ๋ฆฌํ๋ฏ๋ก ์ฅ๊ธฐ์ ์์กด์ฑ์ ํฌ์ฐฉํ๋ ๋ฐ ํ๊ณ๊ฐ ์์ ์ ์์ต๋๋ค.
- ํ์ดํผํ๋ผ๋ฏธํฐ ํ์ ์์: LegS๋ ๋ฉ๋ชจ๋ฆฌ ๊ตฌ์ฑ์ ํ์ํ ํ์ดํผํ๋ผ๋ฏธํฐ ์์ด ๋์ํฉ๋๋ค.
- ๋ฐ๋ฉด, LagT์ LegT๋ ํ์ดํผํ๋ผ๋ฏธํฐ๋ฅผ ์กฐ์ ํด์ผ ํ ์ ์์ต๋๋ค.
- ์๊ฐ ์ค์ผ์ผ์ ๋ํ ๊ฐ๊ฑด์ฑ: LegS๋ ์ ๋ ฅ ์ ํธ์ ์๊ฐ ์ฒ๋๊ฐ ๋ฐ๋์ด๋ ์์ ์ ์ผ๋ก ๋์ํ ์ ์์ต๋๋ค.
- ๋ฐ๋ฉด LagT๋ LegT๋ ํน์ ์๊ฐ ์ฒ๋์ ๋ํด ์ต์ ํ๋ ๊ฒฐ๊ณผ๋ฅผ ๋ํ๋ผ ์ ์์ง๋ง, ๋ค๋ฅธ ์๊ฐ ์ฒ๋์์๋ ์ฑ๋ฅ์ด ์ ํ๋ ์ ์์ต๋๋ค.
- ์๊ฐ์ ๋ฐ๋ฅธ ๊ณ์ฐ ํจ์จ์ฑ: LegS๋ ๋ฉ๋ชจ๋ฆฌ ์ ๋ฐ์ดํธ ๊ณผ์ ์ ๊ฐ์ํํ์ฌ ๊ฐ ์๊ฐ ๋จ๊ณ์์ ๋ ๋น ๋ฅด๊ฒ ์ฒ๋ฆฌํ ์ ์์ต๋๋ค.
- LagT๋ LegT๋ ์๋์ ์ผ๋ก ๋ณต์กํ ์ ๋ฐ์ดํธ ๊ท์น์ ์ฌ์ฉํด์ผ ํ ์ ์์ต๋๋ค.
- Gradient ๋ฐ ์ญ์ ํ ๋ฌธ์ ํด๊ฒฐ: LegS๋ ๊ธฐ์ธ๊ธฐ ํฌ๊ธฐ๊ฐ ๋ณด์กด๋ ์ ์๋ ๋ฉ์ปค๋์ฆ์ ์ ๊ณตํ์ฌ, ๊ธด ์ํ์ค์ ๊ฑธ์ณ ํ์ต ์์ ์ฑ์ ๋์ ๋๋ค.
- LagT์ LegT๋ ๋๋๋ก ๊ทธ๋๋์ธํธ๊ฐ ์์ค๋๋ ๋ฌธ์ ๊ฐ ๋ฐ์ํ ์ ์์ต๋๋ค.
Empirical Validation (์ค์ฆ์ ๊ฒ์ฆ)
- 4.1 Long-range Memory Benchmark Tasks (์ฅ๊ธฐ ๋ฉ๋ชจ๋ฆฌ ๋ฒค์น๋งํฌ ๊ณผ์ ): ์ฅ๊ธฐ ๋ฉ๋ชจ๋ฆฌ ์์กด์ฑ์ ํ๊ฐํ๋ ๋ค์ํ ๋ฒค์น๋งํฌ ๊ณผ์ ์์ HiPPO-LegS์ ์ฑ๋ฅ์ ๊ฒ์ฆํฉ๋๋ค.
- 4.2 Timescale Robustness of HiPPO-LegS (HiPPO-LegS์ ์๊ณ์ด ๊ฒฌ๊ณ ์ฑ): HiPPO-LegS๊ฐ ๋ค์ํ ์๊ฐ ์ฒ๋์์ ์ผ๋ง๋ ๊ฒฌ๊ณ ํ๊ฒ ์ฑ๋ฅ์ ๋ฐํํ๋์ง ๊ฒ์ฆํฉ๋๋ค.
- 4.3 Theoretical Validation and Scalability (์ด๋ก ์ ๊ฒ์ฆ ๋ฐ ํ์ฅ์ฑ): HiPPO ํ๋ ์์ํฌ๊ฐ ์ด๋ก ์ ์ผ๋ก ์ด๋ป๊ฒ ์ฑ๋ฅ์ด ๋ณด์ฅ๋๋์ง์ ๊ทธ ํ์ฅ์ฑ์ ์ค๋ช ํฉ๋๋ค.
- 4.4 Additional Experiments (์ถ๊ฐ ์คํ): ์ถ๊ฐ ์คํ์ ํตํด HiPPO ๋ฉ๋ชจ๋ฆฌ ๋ฉ์ปค๋์ฆ์ ์ ์ฉ์ฑ์ ๊ฒ์ฆํฉ๋๋ค.
Conclusion (๊ฒฐ๋ก )
- HiPPO ํ๋ ์์ํฌ๊ฐ ๋ฉ๋ชจ๋ฆฌ ๋ฌธ์ ์ ๋ํ ๊ทผ๋ณธ์ ์ธ ํด๊ฒฐ์ฑ ์ ์ ์ํ๋ฉฐ, ๊ธฐ์กด์ ๋ฉ๋ชจ๋ฆฌ ๋ฉ์ปค๋์ฆ์ ํตํฉํ๊ณ ํ์ฅํ์ฌ ๋ ๋์ ์ฑ๋ฅ์ ๋ฐํํ ์ ์์์ ๊ฒฐ๋ก ์ผ๋ก ์ ์ํฉ๋๋ค
-
LSSL: Combining Recurrent, Convolutional, and Continuous-time Models with Linear State-Space Layers (NeurIPS, 2021)
์ด ๋ ผ๋ฌธ์ ๋ชฉ์ฐจ์ ๋ฐ๋ฅธ ๊ฐ๋ ๋ค์ ๋ค์๊ณผ ๊ฐ์ด ์ค๋ช ํ๊ฒ ์ต๋๋ค:
Introduction
LSSL(Linear State-Space Layer)
๋ ์ํ(Recurrent), ํฉ์ฑ๊ณฑ(Convolutional), ์ฐ์ ์๊ฐ ๋ชจ๋ธ(Continuous-time)์ ์ฅ์ ์ ๊ฒฐํฉํ ์๋ก์ด ๋ชจ๋ธ ํจ๋ฌ๋ค์์ผ๋ก, ์๊ฐ์ ๋ฐ๋ฅธ ์์ฐจ ๋ฐ์ดํฐ ์ฒ๋ฆฌ๋ฅผ ๋์ฑ ํจ์จ์ ์ผ๋ก ํ ์ ์๋๋ก ์ค๊ณ๋ ๊ตฌ์กฐ์ ๋๋ค.-
๋ฐฐ๊ฒฝ ๋ฐ ๋ฌธ์ ์ ์:
-
๋จธ์ ๋ฌ๋์์ ์ํ์ค ๋ฐ์ดํฐ(Sequential Data)๋ฅผ ์ฒ๋ฆฌํ๋ ์ผ๋ฐ์ ์ธ ๋ฐฉ์์ RNN(Recurrent Neural Network), CNN(Convolutional Neural Network), NeuralODE(Neural Differential Equation) ๋ฑ์ ๊ตฌ์กฐ๋ฅผ ์ฌ์ฉํ๋ ๊ฒ์ ๋๋ค. ์ด๋ค์ ๊ฐ๊ฐ ์ฅ๋จ์ ์ด ์์ต๋๋ค.
RNN
์ ์ํ์ค ๋ฐ์ดํฐ์ ๋ํ ์ํ ์ ์ฅ(Stateful) ์ฑ์ง์ ๊ฐ๊ณ ์์ผ๋, ๋งค ์คํ ๋ง๋ค ์ ์ฅ๊ณผ ๊ณ์ฐ์ด ํ์ํ๋ฏ๋ก ๋งค์ฐ ๋นํจ์จ์ ์ ๋๋ค. ๋ํ์ ์ธ ๋ฌธ์ ๋ก๋ Vanishing Gradient Problem์ด ์์ต๋๋ค.CNN
์ ๋ณ๋ ฌ ์ฒ๋ฆฌ์ ๋น ๋ฅธ ํ๋ จ์ด ๊ฐ๋ฅํ๋, ๊ธด ์ํ์ค๋ฅผ ์ฒ๋ฆฌํ๋ ๋ฐ ํ๊ณ๊ฐ ์์ต๋๋ค. ์ฆ, ๋ก์ปฌ ์ ๋ณด์ ๊ตญํ๋์ด ์์ผ๋ฉฐ ๊ธด ๋ฌธ๋งฅ(long-term dependency)์ ๊ธฐ์ตํ๋ ๋ฅ๋ ฅ์ด ๋ถ์กฑํฉ๋๋ค.NeuralODE
๋ ์ฐ์์ ์๊ฐ ๋ชจ๋ธ์ ์ฌ์ฉํ์ฌ ์ํ์ ์ผ๋ก ์ํ์ค๋ฅผ ์ฒ๋ฆฌํ์ง๋ง ๊ณ์ฐ ๋น์ฉ์ด ๋ง์ด ๋ค๊ณ , ํนํ ๊ธด ์ํ์ค๋ฅผ ์ฒ๋ฆฌํ ๋ ๋งค์ฐ ๋นํจ์จ์ ์ ๋๋ค.
-
- LSSL์ ์ ์ ๋ฐ ๋ชฉ์ :
- ๋ณธ ์ฐ๊ตฌ์์๋ ์ด๋ฌํ
RNN
,CNN
,NeuralODE
๊ฐ๊ฐ์ ์ฅ์ ์ ์ด๋ฆฌ๋ฉด์๋ ๊ฐ ๋ชจ๋ธ์ ๋จ์ ์ ๊ทน๋ณตํ๋ ์๋ก์ด ๊ตฌ์กฐ์ธ Linear State-Space Layer(LSSL)๋ฅผ ์ ์ํฉ๋๋ค. - ์ฃผ์ ๋ชฉํ๋
CNN
์ ๋ณ๋ ฌ ์ฒ๋ฆฌ ์ฅ์ ,RNN
์ ์ํ ์ถ๋ก ๋ฅ๋ ฅ,NeuralODE
์ ์๊ฐ ์ฒ๋(Time-scale) ์ ์๋ ฅ์ ๋์์ ์ ๊ณตํ๋ ๋ชจ๋ธ์ ๊ฐ๋ฐํ๋ ๊ฒ์ ๋๋ค.- ์ฌ๊ท์ฑ(Recurrent): ํน์ ์๊ฐ ๊ฐ๊ฒฉ ฮt\Delta tฮt๋ฅผ ์ฌ์ฉํ์ฌ ์ํ ๊ณต๊ฐ ๋ชจ๋ธ์ ๋ถ์ฐ์ํ(Discretization)ํ๋ฉด, ์ฌ๊ท์ ์ธ ๋ฐฉ์์ผ๋ก ์ํ์ค๋ฅผ ์ฒ๋ฆฌํ ์ ์์ต๋๋ค. ์ด๋ฅผ ํตํด RNN์ฒ๋ผ ์ํ๋ฅผ ์ถ์ ํ ์ ์์ต๋๋ค.
- ํฉ์ฑ๊ณฑ์ฑ(Convolutional): ์ ํ ์๊ฐ ๋ถ๋ณ ์์คํ (Linear Time-Invariant System, LTI)์ผ๋ก์, ์ฐ์์ ์ธ ํฉ์ฑ๊ณฑ์ผ๋ก ํํ์ด ๊ฐ๋ฅํฉ๋๋ค. ์ด๋ฅผ ํตํด CNN๊ณผ ๊ฐ์ด ๋ณ๋ ฌ ์ฒ๋ฆฌ ๋ฐ ํจ์จ์ ์ธ ํ๋ จ์ด ๊ฐ๋ฅํฉ๋๋ค.
- ์ฐ์ ์๊ฐ ๋ชจ๋ธ(Continuous-time): LSSL์ ๋ฏธ๋ถ ๋ฐฉ์ ์์ผ๋ก ํํ๋๋ฏ๋ก ์ฐ์ ์๊ฐ ๋ชจ๋ธ๋ก์์ ์ฅ์ ์ ๊ฐ์ง๋ฉฐ, ๋ค์ํ ์๊ฐ ์ฒ๋์ ์ ์ํ ์ ์๋ ์ ์ฐ์ฑ์ ์ ๊ณตํฉ๋๋ค.
- ๋ณธ ์ฐ๊ตฌ์์๋ ์ด๋ฌํ
์๋ ๊ทธ๋ฆผ์ ๋ ผ๋ฌธ์์ ๋์จ Figure1๋ก ์์์ ์ค๋ช ํ๋ LSSL์ 3๊ฐ์ง View๋ฅผ ์ค๋ช ํฉ๋๋ค.
-
View 1
. Continuous-time ๊ด์ :- ์ด ๋ชจ๋์์๋ ์ํ ๊ณต๊ฐ ๋ชจ๋ธ์ด ์ฐ์์ ์๊ฐ ttt์ ๋ฐ๋ผ ๋ณํ๋ฉฐ, ๋ถ๊ท์นํ ์ํ๋ง ๋ฐ์ดํฐ๋ ์ฒ๋ฆฌํ ์ ์์ต๋๋ค. (๋ฏธ๋ถ ๋ฐฉ์ ์ ํํ)
- ์ xห(t)=Ax(t)+Bu(t)\dot{x}(t) = A x(t) + B u(t)xห(t)=Ax(t)+Bu(t)๋ ์ํ๊ฐ ์๊ฐ์ ๋ฐ๋ผ ์ด๋ป๊ฒ ๋ณํํ๋์ง ๋ํ๋ด๋ฉฐ, ์ถ๋ ฅ์ y(t)=Cx(t)+Du(t)y(t) = C x(t) + D u(t)y(t)=Cx(t)+Du(t)๋ก ์ ์๋ฉ๋๋ค.
-
View 2
. Recurrent ๊ด์ :- ์ด์ฐํ(Discretization)๋ฅผ ํตํด RNN๊ณผ ๊ฐ์ ํํ๋ก ์ฌ์ฉํ ์ ์์ผ๋ฉฐ ์๊ฐ ๊ฐ๊ฒฉ ฮt\Delta tฮt์ ๋ฐ๋ผ ์ํ๊ฐ ๋ณํํ๊ณ , ์ด์ ์ํ ์ ๋ณด xkโ1x_{k-1}xkโ1โ๋ฅผ ์ฌ์ฉํ์ฌ ํ์ฌ ์ํ xkx_kxkโ์ ์ถ๋ ฅ์ ๊ณ์ฐํฉ๋๋ค.
- ์ด๋ฅผ ํตํด ๋ฌดํํ ๋ฌธ๋งฅ(Unbounded Context)์ ์ฒ๋ฆฌํ ์ ์์ผ๋ฉฐ, ํจ์จ์ ์ธ ์ถ๋ก ์ด ๊ฐ๋ฅํฉ๋๋ค.
-
View 3
. Convolutional ๊ด์ :- ํฉ์ฑ๊ณฑ์ ๋ฐฉ์์ผ๋ก๋ ํํ์ด ๊ฐ๋ฅํฉ๋๋ค. ํฉ์ฑ๊ณฑ ์ปค๋ KKK๋ ์ ํ ์์คํ ์ ๊ธฐ๋ฐ์ผ๋ก ๊ณ์ฐ๋๋ฉฐ, ์ด๋ฅผ ํตํด ์ ๋ ฅ ์ํ์ค์ ๋ํด ๋ณ๋ ฌ๋ก ์ฒ๋ฆฌํ ์ ์์ต๋๋ค.
- CNN๊ณผ ๊ฐ์ด ๋ก์ปฌ ์ ๋ณด(Local Information)๋ฅผ ์ฌ์ฉํ๋ฉด์๋, ๋ณ๋ ฌํ๋ ํ๋ จ์ด ๊ฐ๋ฅํ๋ค๋ ์ฅ์ ์ด ์์ต๋๋ค.
Linear State-Space Layers (LSSL)
-
3.1 LSSL์ ๋ค์ํ ๋ทฐ (Different Views of the LSSL)
- LSSL์ ๊ธฐ๋ณธ ์์์
์ํ ๊ณต๊ฐ ํํ(state-space representation)
์ธA, B, C, Dํ๋ ฌ
์ ์ฌ์ฉํ์ฌ ์ ์๋ฉ๋๋ค. ์์์ผ๋ก๋ ์๋์ ๊ฐ์ด ํํ๋ฉ๋๋ค.xห(t)=Ax(t)+Bu(t)\dot{x}(t) = A x(t) + B u(t)xห(t)=Ax(t)+Bu(t) y(t)=Cx(t)+Du(t)y(t) = C x(t) + D u(t)y(t)=Cx(t)+Du(t) - LSSL์ ์ด ๋ชจ๋ธ์
์ด์ฐํ(discretization)
ํ์ฌ ฮt\Delta tฮt๋ผ๋ ํ์์คํ ์ ๊ธฐ๋ฐ์ผ๋ก ์ ๋ ฅ ์ํ์ค u(t)u(t)u(t)๋ฅผ ์ถ๋ ฅ ์ํ์ค y(t)y(t)y(t)๋ก ๋ณํํ๋ ์ํ์ค ํฌ ์ํ์ค ๋งตํ์ ์ ๊ณตํฉ๋๋ค. ์ด๋, ๊ฐ ํ์์คํ ์ H-dim feature ๋ฒกํฐ๋ฅผ ํฌํจํ ์ํ์ค๋ฅผ ์ฒ๋ฆฌํฉ๋๋ค. - LSSL์ ์ฌ๋ฌ ๊ฐ์ง ๋ฐฉ์์ผ๋ก ๊ณ์ฐ๋ ์ ์์ผ๋ฉฐ, ๊ทธ ๋ฐฉ์๋ค์ ํฌ๊ฒ
์ฌ๊ท์ ๋ชจ๋ธ(Recurrent Model)
,ํฉ์ฑ๊ณฑ ๋ชจ๋ธ(Convolutional Model)
,์ฐ์ ์๊ฐ ๋ชจ๋ธ(Continuous-Time Model)
๋ก ๋๋ฉ๋๋ค. ๋ ผ๋ฌธ์ ํด๋น ํํธ์์ ์ด๋ฅผ ๋์์ ์ผ๋ก ํํํ๋ฉด์ ๊ฐ ๋ฐฉ์์ด ์ด๋ป๊ฒ ๋ค๋ฅด๊ฒ ์๋ํ๋์ง๋ฅผ ๋ณด์ฌ์ค๋๋ค. (์ด๋ ์์์ 3. State Space Model(SSM) ์๊ฐ์์๋ ๋ค๋ค์ผ๋ ๋๋ฌด ๊น๊ฒ ๊ฐ์ง๋ ์๊ฒ ์ต๋๋ค)
- LSSL์ ๊ธฐ๋ณธ ์์์
โ Recurrent View (์ฌ๊ท์ ๊ด์ )
-
์ฌ๊ท์ ๊ด์
์์๋ ์ํ ๋ฒกํฐ xtโ1x_{t-1}xtโ1โ์ด ์ด์ ์ ๋ ฅ ์ ๋ณด์ ํ์ฌ ์ ๋ ฅ ์ ๋ณด ๊ฐ์ ๋ฌธ๋งฅ์ ์ ์งํฉ๋๋ค.- ์ด๋ฅผ ํตํด ํจ์จ์ ์ธ ์ํ ์ถ๋ก ์ ํ ์ ์์ผ๋ฉฐ, ์ํ ์ ๊ฒฝ๋ง(RNN)์ฒ๋ผ ์๋ํฉ๋๋ค.
โก Convolutional View (ํฉ์ฑ๊ณฑ ๊ด์ )
-
ํฉ์ฑ๊ณฑ ๊ด์
์์, LSSL์ state ๋ฒกํฐ๋ฅผ ํตํด ํํฐ๋ง๋ ์ถ๋ ฅ์ ์ ๊ณตํฉ๋๋ค.- ํฉ์ฑ๊ณฑ ๊ด์ ์์ ๊ณ์ฐ ํจ์จ์ฑ์ ๋์ด๊ธฐ ์ํด FFT(๋น ๋ฅธ ํธ๋ฆฌ์ ๋ณํ)๋ฅผ ์ฌ์ฉํ ์ ์์ต๋๋ค.
-
3.2 LSSL์ ํํ๋ ฅ (Expressivity of LSSLs)
- ์ด ์ ์์๋ LSSL์ด ์ค์ ๋ก ์ด๋ ์ ๋๊น์ง ๋ค์ํ ์ฌ๊ท์ ํน์ฑ๊ณผ ํฉ์ฑ๊ณฑ์ ํน์ฑ ํํ์ ์ผ๋ง๋ ์ ํ ์ ์๋์ง๋ฅผ ๋ถ์ํฉ๋๋ค.
โ ํฉ์ฑ๊ณฑ์ด ๊ฐ๋ฅํ LSSL
- ์ํ ๊ณต๊ฐ ์์คํ ๊ณผ ์ํ์ค ์๋ต(Impulse Response) : ์ํ ๊ณต๊ฐ ์์คํ ์ ๊ธฐ๋ณธ์ ์ผ๋ก ์ฐ์ ์๊ฐ ๋๋ ๋ถ์ฐ์ ์๊ฐ ์์คํ ์ ์ ๋ ฅ๊ณผ ์ถ๋ ฅ์ ์ํ ๋ณ์๋ก ํํํ๋ ๋ฐฉ์์ ๋๋ค. LSSL๋ ์ด๋ฌํ ์ํ ๊ณต๊ฐ ์์คํ ์ ๊ธฐ๋ฐ์ผ๋ก ํ์ฌ ์ ๋ ฅ u(t)u(t)u(t)๋ฅผ ์๊ฐ์ ๋ฐ๋ผ ์ํ x(t)x(t)x(t)์ ์ถ๋ ฅ y(t)y(t)y(t)๋ก ๋ณํํฉ๋๋ค. ์ํ์ ์ผ๋ก๋ ๋ค์๊ณผ ๊ฐ์ ํํ์ ๋๋ค:xห(t)=Ax(t)+Bu(t)\dot{x}(t) = A x(t) + B u(t)xห(t)=Ax(t)+Bu(t) y(t)=Cx(t)+Du(t)y(t) = C x(t) + D u(t)y(t)=Cx(t)+Du(t)
- ์ฌ๊ธฐ์ ์
๋ ฅ u(t)u(t)u(t)๊ฐ ์ฃผ์ด์ก์ ๋ ์์คํ
์ด ์๊ฐ์ ๋ฐ๋ผ ์ด๋ป๊ฒ ๋ณํ๋์ง๋ฅผ ๋ํ๋ด๋ ํจ์๊ฐ ์ํ์ค ์๋ต ํจ์์
๋๋ค. ์ํ์ค ์๋ต ํจ์๋ ์์คํ
์ด ํน์ ์
๋ ฅ(์ฆ, ์ํ์ค)์ ๋ํด ์ด๋ป๊ฒ ๋ฐ์ํ๋์ง๋ฅผ ๋ณด์ฌ์ค๋๋ค.
๐ฌ ์ํ์ค ์๋ต ํจ์(Impulse Response Function, IRF)๋ ?
์ํ์ค ์๋ต ํจ์(IRF)
๋ ์์คํ ์ด๋ ์ ํธ ์ฒ๋ฆฌ์์ ์ค์ํ ๊ฐ๋ ์ ๋๋ค. ์ด๋ ์์คํ ์ด ๋จ์ ์ํ์ค ์ ๋ ฅ์ ๋ํด ์ด๋ป๊ฒ ๋ฐ์ํ๋์ง๋ฅผ ๋ํ๋ด๋ ํจ์์ ๋๋ค.์ํ์ค(impulse)
๋ ๋ฌผ๋ฆฌํ์์ ๋ฌผ์ฒด์ ์์ฉํ๋ ํ์ด ์๊ฐ์ ๊ฑธ์ณ ๋ณํํ๋ ๊ณผ์ ์ ์ค๋ช ํ๋ ๊ฐ๋ ์ ๋๋ค. ์ผ๋ฐ์ ์ผ๋ก ์ํ์ค๋ ํ๊ณผ ์๊ฐ์ ๊ณฑ์ผ๋ก ์ ์๋๋ฉฐ, ๋ฌผ์ฒด์ ์ด๋๋ ๋ณํ์ ๊ด๋ จ์ด ์์ต๋๋ค. ์์์ผ๋ก ํํํ๋ฉด ๋ค์๊ณผ ๊ฐ์ต๋๋ค: J=Fโ ฮtJ = F \cdot \Delta tJ=Fโ ฮt
-
์ํ์ค ์๋ต๊ณผ ํฉ์ฑ๊ณฑ ์ฐ์ฐ : ์ํ์ค ์๋ต ํจ์ h(t)h(t)h(t)๋ ์ํ ๊ณต๊ฐ ๋ชจ๋ธ์์ ์ถ๋ ฅ์ ๊ณ์ฐํ๋๋ฐ ๋งค์ฐ ์ค์ํฉ๋๋ค. ์ํ์ค ์๋ต์ ์๋ฉด ์ ๋ ฅ ์ ํธ u(t)u(t)u(t)๊ฐ ์ฃผ์ด์ก์ ๋ ์์คํ ์ ์ถ๋ ฅ์ ๋ค์๊ณผ ๊ฐ์ ํฉ์ฑ๊ณฑ ์ฐ์ฐ์ผ๋ก ํํํ ์ ์์ต๋๋ค:
y(t)=(hโu)(t)=โซh(ฯ)u(tโฯ)dฯy(t) = (h * u)(t) = \int h(\tau) u(t - \tau) d\tauy(t)=(hโu)(t)=โซh(ฯ)u(tโฯ)dฯ
- ์ฆ, ์์คํ ์ ์ถ๋ ฅ์ ์ ๋ ฅ ์ ํธ u(t)u(t)u(t)์ ์์คํ ์ ์ํ์ค ์๋ต h(t)h(t)h(t)์ ํฉ์ฑ๊ณฑ์ผ๋ก ๊ณ์ฐ๋ฉ๋๋ค. ์ฌ๊ธฐ์ ํฉ์ฑ๊ณฑ ์ฐ์ฐ์ด ์ค์ํ ์ด์ ๋, ์ํ์ค ์๋ต ํจ์๊ฐ ์์คํ ์ ์๊ฐ์ ํน์ฑ์ ๊ฒฐ์ ํ๋ฉฐ, ์ด๋ฅผ ํตํด ๊ณผ๊ฑฐ์ ์ ๋ ฅ๋ค์ด ํ์ฌ์ ์ถ๋ ฅ์ ์ด๋ป๊ฒ ๊ฒฐ์ ํ๋์ง๋ฅผ ์ค๋ช ํ ์ ์๊ธฐ ๋๋ฌธ์ ๋๋ค.
-
LSSL์์ ํฉ์ฑ๊ณฑ์ ์ญํ : LSSL์ ์ํ ๊ณต๊ฐ ์์คํ ์ ๊ธฐ๋ฐ์ผ๋ก ํ์ง๋ง, ์ด๋ฅผ ์ด์ฐํ(Discretization)ํ์ฌ ํฉ์ฑ๊ณฑ์ผ๋ก ์ฒ๋ฆฌํ ์ ์์ต๋๋ค.
- ์ด์ฐํ๋ ์์คํ ์ ์ค์ ๋ก ์๊ฐ์ ๋ฐ๋ผ ์ ๋ ฅ์ด ์ด๋ป๊ฒ ๋ณํ๋์ง๋ฅผ ๊ณ์ฐํ ๋, ์ํ์ค ์๋ต ํจ์๋ฅผ ์ฌ์ฉํ์ฌ ํฉ์ฑ๊ณฑ ํํฐ๋ก ๋ณํํ ์ ์์ต๋๋ค.
- ์ฆ, ์ํ ๊ณต๊ฐ ์์คํ ์ ํตํด ์ ์๋ ์์คํ ์ ์๋ต์ ํฉ์ฑ๊ณฑ ํํฐ๋ก ํํํ ์ ์๋ค๋ ์๋ฏธ์ ๋๋ค.
โก LSSL์ RNN๊ณผ์ ๊ด๊ณ
-
RNN
์ ์ ๋ ฅ ์ํ์ค๋ฅผ ์ฒ๋ฆฌํ ๋, ์ด์ ์๊ฐ์ ์ํ htโ1h_{t-1}htโ1โ๋ฅผ ํ์ฌ ์ํ hth_thtโ์ ์ ๋ฌํจ์ผ๋ก์จ ์๊ฐ์ ์ข ์์ฑ์ ์ ์งํฉ๋๋ค.- ์ฆ, RNN์ ์ด์ ํ์์คํ ์ ์ ๋ณด๋ฅผ ๋ค์ ํ์์คํ ์ผ๋ก ์ ๋ฌํ๋ฉด์ ์ํ๋ฅผ ๊ฐฑ์ ํ๊ณ , ์ด๋ฅผ ํตํด ๊ธด ์ํ์ค์ ์ ๋ณด๋ฅผ ์ถ์ ํ ์ ์์ต๋๋ค. ์ํ์ ์ผ๋ก RNN์ ์ํ ๊ฐฑ์ ๋ฐฉ์ ์์ ๋ค์๊ณผ ๊ฐ์ด ํํ๋ฉ๋๋ค:ht=ฯ(Whhtโ1+Wxxt)h_t = \sigma(W_h h_{t-1} + W_x x_t)htโ=ฯ(Whโhtโ1โ+Wxโxtโ)
LSSL
๋ RNN์ฒ๋ผ ์๊ฐ์ ๋ฐ๋ฅธ ์ํ ๊ฐฑ์ ์ ์ํํฉ๋๋ค. LSSL์ ์ํ ๊ฐฑ์ ๋ฐฉ์ ์์ ์ํ ๊ณต๊ฐ ๋ชจ๋ธ์ ๊ธฐ๋ฐํ ๋ฏธ๋ถ ๋ฐฉ์ ์์ผ๋ก ์ ์๋๋๋ฐ, ์ด๋ฅผ ์ด์ฐํํ๋ฉด RNN๊ณผ ์ ์ฌํ ๊ตฌ์กฐ๊ฐ ๋ฉ๋๋ค.
hโฒ(t)=Ah(t)+Bx(t)hโ(t) = Ah(t) + Bx(t)hโฒ(t)=Ah(t)+Bx(t)
y(t)=Ch(t)+Dx(t)y(t) = Ch(t) + Dx(t)y(t)=Ch(t)+Dx(t)
โ
์ด์ฐํ ์ํ
hk+1=Aหhk+Bหxkh_{k+1} = \bar{A}h_k + \bar{B}x_khk+1โ=Aหhkโ+Bหxkโ
yk=Chk+Dxky_k = Ch_k + Dx_kykโ=Chkโ+Dxkโ
- ๋ํ, RNN์ ์์ ๊ฐ์ ์ํ ๊ฐฑ์ ๊ณผ์ ์์ ๊ฒ์ดํ ๋ฉ์ปค๋์ฆ(Gating Mechanism)์ ํตํด ๊ฐ ํ์์คํ ์์ ์ ๋ณด๋ฅผ ์ผ๋ง๋ ์ ๋ฌํ ์ง ์กฐ์ ํฉ๋๋ค. LSTM์ด๋ GRU์์์ ๊ฒ์ดํ ๋ฉ์ปค๋์ฆ์ RNN์ด ๊ฐ ํ์์คํ ์์ ์ ๋ณด์ ํ๋ฆ์ ์กฐ์ ํ๋ ์ค์ํ ์์์ ๋๋ค.
-
์ด ๊ฒ์ดํ ๋ฉ์ปค๋์ฆ์ ์ฌ์ค์ ์๊ฐ ์ฒ๋(Time-scale)๋ฅผ ๋ถ๋๋ฝ๊ฒ ํ์ฌ ๊ฐ ์คํ ์์์ ์ํ ๋ณํ๊ฐ ๋๋ฌด ๊ธ๊ฒฉํ์ง ์๊ฒ ๋ง๋๋ ์ญํ ์ ํฉ๋๋ค. ์๋ฅผ ๋ค์ด, LSTM์ Forget Gate๋ ์ด์ ์ํ๋ฅผ ์ผ๋ง๋ ๊ธฐ์ตํ ์ง ์กฐ์ ํ๋๋ฐ, ์ด๋ ์ผ์ ํ ์๊ฐ ์ฒ๋์์์ ๋ณํ๋ฅผ ๋ถ๋๋ฝ๊ฒ ํ๋ ์ญํ ์ ํฉ๋๋ค.
-
LSSL์์๋ ฮt\Delta tฮt๋ผ๋ ์๊ฐ ๊ฐ๊ฒฉ(Time-step)์ด ์ค์ํ ์ญํ ์ ํฉ๋๋ค. ์ด ์๊ฐ ๊ฐ๊ฒฉ์ ๊ฐ ํ์์คํ ๊ฐ์ ์ํ ๋ณํ๋ฅผ ๊ฒฐ์ ํ๋ฉฐ, ์ด๋ RNN์์ ๊ฒ์ดํ ๋ฉ์ปค๋์ฆ๊ณผ ๋งค์ฐ ์ ์ฌํ ์ญํ ์ ํฉ๋๋ค.
- ์ฆ, LSSL์ ์๊ฐ ์ฒ๋๋ RNN์ ๊ฒ์ดํ ๋ฉ์ปค๋์ฆ๊ณผ ๋ณธ์ง์ ์ผ๋ก ๊ฐ์ ๊ฐ๋ ์ผ๋ก, ์ ๋ ฅ ๋ฐ์ดํฐ๋ฅผ ์ฒ๋ฆฌํ ๋ ์๊ฐ์ ๋ฐ๋ฅธ ๋ณํ๋์ ๋ถ๋๋ฝ๊ฒ ์กฐ์ ํ๋ ์ญํ ์ ํฉ๋๋ค.
โข Deep LSSL
-
LSSL์ ํ๋์ ๋ ์ด์ด๋ก ์ฌ์ฉํ์ง ์๊ณ ์ฌ๋ฌ ๋ ์ด์ด๋ก ์์์ ๋ณด๋ค ๊น์ ๋คํธ์ํฌ๋ก ํ์ฅํ ์ ์์ต๋๋ค. ์ด ๊ตฌ์กฐ๋ ํนํ ๋น์ ํ ์ํ์ค ๋ฐ์ดํฐ๋ฅผ ์ฒ๋ฆฌํ๋ ๋ฐ ์ ํฉํฉ๋๋ค.
๊ธฐ๋ณธ LSSL ๊ตฌ์กฐ
: LSSL์ RLโRL\mathbb{R}^L \to \mathbb{R}^LRLโRL seq-to-seq ๋งคํ์ ์ํํ๋ฉฐ, ๊ฐ๊ฐ์ LSSL ๋ ์ด์ด๋ ํ๋ผ๋ฏธํฐ A,B,C,DA, B, C, DA,B,C,D์ ์๊ฐ ๊ฐ๊ฒฉ ฮt\Delta tฮt๋ก ์ ์๋ฉ๋๋ค. ์ ๋ ฅ ์ํ์ค๋ H ์ฐจ์์ ํผ์ฒ๋ก ์ฒ๋ฆฌ๋๋ฉฐ, ๊ฐ ํผ์ฒ๊ฐ ๋ ๋ฆฝ์ ์ผ๋ก ํ์ต๋ฉ๋๋ค.Layer Stacking
: Deep LSSL์ ์ฌ๋ฌ LSSL ๋ ์ด์ด๋ฅผ ์์์ ๋ ๋ณต์กํ ์ํ์ค ๋ฐ์ดํฐ๋ฅผ ์ฒ๋ฆฌํ ์ ์์ต๋๋ค. ๊ฐ ๋ ์ด์ด๋ ์๋ก ๋ค๋ฅธ ์ํ ๊ณต๊ฐ ํ๋ผ๋ฏธํฐ์ ์๊ฐ ๊ฐ๊ฒฉ์ ํ์ตํ์ฌ, ๋ค์ฐจ์์ ์ธ ์๊ฐ ์ฒ๋์์ ๋ฐ์ดํฐ๋ฅผ ์ฒ๋ฆฌํฉ๋๋ค.Residual Connections
: ResNet๊ณผ ๊ฐ์ Residual Connections์ ์ฌ์ฉํ์ฌ ๋ฅ๋ฌ๋ ๋คํธ์ํฌ์์ ๋ฐ์ํ ์ ์๋ ๊ธฐ์ธ๊ธฐ ์์ค ๋ฌธ์ ๋ฅผ ํด๊ฒฐํฉ๋๋ค. ๊ฐ ๋ ์ด์ด์ ์ถ๋ ฅ์ ๋ค์ ๋ ์ด์ด๋ก ์ง์ ์ ๋ฌํจ์ผ๋ก์จ ์ ๋ณด๊ฐ ์ฌ๋ผ์ง์ง ์๊ฒ ์ ์งํ๋ ๋ฐฉ์์ ๋๋ค.Normalization
: LSSL์ ๋ ์ด์ด๊ฐ ๊น์ด์ง์๋ก Layer Normalization์ด ํ์ํฉ๋๋ค. ์ด๋ ๋ ์ด์ด๊ฐ ์์ผ ๋ ๋ฐ์ํ๋ ๋ด๋ถ ๊ณต๋ณ๋ ๋ณํ(Internal Covariate Shift)๋ฅผ ์ค์ฌ์ฃผ์ด, ํ์ต ์๋๋ฅผ ๋์ด๊ณ ์ฑ๋ฅ์ ๊ฐ์ ํฉ๋๋ค.
Appendix B.1 (M) LSSL Computation
- LSSL์ ๊ณ์ฐ์ ์๊ฐ์ด ๋ง์ด ๊ฑธ๋ฆด ์ ์์ง๋ง, ์ผ๋ถ ๊ณ์ฐ์ ์บ์ฑํจ์ผ๋ก์จ ํจ์จ์ฑ์ ๋์ผ ์ ์์ต๋๋ค. ํนํ, ํ๋ จ๋์ง ์์ AAA์ ฮt\Delta tฮt ํ๋ผ๋ฏธํฐ๋ฅผ ๊ณ ์ ํ ๊ฒฝ์ฐ ์บ์ฑ์ ํตํด ๊ณ์ฐ ํจ์จ์ ๊ทน๋ํํ ์ ์์ต๋๋ค.
- ์ ์ด ํ๋ ฌ(Transition Matrix): ์ํ ์ ์ด ํ๋ ฌ Aห\bar{A}Aห๋ ๋ธ๋๋ฐ์ค ๋งคํธ๋ฆญ์ค-๋ฒกํฐ ๊ณฑ ์๊ณ ๋ฆฌ์ฆ์ ์ฌ์ฉํ์ฌ ๊ณ์ฐ๋๋ฉฐ, ์ด๋ฅผ ์บ์ฑํด ๋ ์ผ๋ก์จ ์ฐ์ฐ์ ๋ฐ๋ณตํ์ง ์์๋ ๋ฉ๋๋ค.
- ํฌ๋ฆด๋กํ ํ๋ ฌ(Krylov Matrix): ํฌ๋ฆด๋กํ ํ๋ ฌ์ ์ ๋ ฅ๊ณผ ์ํ ์ ์ด ํ๋ ฌ AAA, ๊ทธ๋ฆฌ๊ณ BBB ํ๋ ฌ์ ํตํด ๊ณ์ฐ๋๋ฉฐ, ์ด ๊ณ์ฐ์ ๋ณ๋ ฌํ๋ ์ ์์ต๋๋ค. ์ ๊ณฑ ์ฐ์ฐ ๋ฐ ์ง์ํ๋ฅผ ํตํด ํจ์จ์ ์ผ๋ก ๊ณ์ฐํ ์ ์์ต๋๋ค. ์ต์ข ์ ์ผ๋ก ์ด ํฌ๋ฆด๋กํ ํ๋ ฌ์ (AkB)(A^k B)(AkB)์ ํํ๋ก ์บ์ฑ๋์ด ํฉ์ฑ๊ณฑ ์ฐ์ฐ ์ ์ ์ ์ฅ๋ฉ๋๋ค.
- ๋ณต์ก๋: ์บ์ฑ์ ์ฌ์ฉํ ์ด ์๊ณ ๋ฆฌ์ฆ์ ๊ณ์ฐ ๋ณต์ก๋๊ฐ O(NL)O(NL)O(NL)๋ก ์ค์ด๋ค์ง๋ง, ์ด๋ฅผ ์บ์ฑํ๊ธฐ ์ํ ์ถ๊ฐ์ ์ธ ๋ฉ๋ชจ๋ฆฌ ๊ณต๊ฐ์ด ํ์ํฉ๋๋ค. ์ด ๋ถ๋ถ์ ํ๋ จ ์ ๋ชจ๋ธ์ ํจ์จ์ฑ์ ๊ทน๋ํํ ์ ์์ง๋ง, inference ์์๋ ๋ ๋ง์ ๊ณ์ฐ์ด ์๊ตฌ๋ ์ ์์ต๋๋ค.
Appendix B.2 Initialization of AAA
- ํ๋ผ๋ฏธํฐ AAA๋ HiPPO-LegS ์ฐ์ฐ์๋ฅผ ์ฌ์ฉํ์ฌ ์ด๊ธฐํ๋ฉ๋๋ค.
HiPPO-LegS
๋ ์ฐ์ ์๊ฐ ๋ฉ๋ชจ๋ฆฌํ ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ๊ธฐ ์ํด ์ค๊ณ๋ ์ฐ์ฐ์๋ก, ์ํ ๊ณต๊ฐ ์์คํ ์์ ๊ธด ์ํ์ค ๋ฐ์ดํฐ๋ฅผ ํจ์จ์ ์ผ๋ก ์ฒ๋ฆฌํ๋ ๋ฐ ๋์์ ์ค๋๋ค.AAA๋ ํน์ ๊ท์น์ ๋ฐ๋ผ ๋๊ฐ ํ๋ ฌ์ ๊ตฌ์ฑํ๋๋ฐ, AAA์ ์ด๊ธฐ๊ฐ์ ์๋์ ๊ฐ์ด ์ฃผ์ด์ง๋๋ค:
Ank={(2n+1)1/2/(2k+1)1/2ifย nโฅkn+1ifย n=k0ifย n<kA_{nk} = \begin{cases} (2n + 1)^{1/2}/(2k + 1)^{1/2} & \text{if } n \geq k \ n + 1 & \text{if } n = k \ 0 & \text{if } n < k \end{cases}Ankโ=โฉโชโชโจโชโชโงโ(2n+1)1/2/(2k+1)1/2n+10โifย nโฅkifย n=kifย n<kโ
- ์ด ์ด๊ธฐํ ๋ฐฉ์์ LSSL์ ์ํ ์ ์ด๊ฐ HiPPO ์ฐ์ฐ์ ๋ง์ถ์ด ์ต์ ํ๋๋๋ก ํ๋ฉฐ, ๊ธด ์ํ์ค ๋ฉ๋ชจ๋ฆฌํ ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ๋ ๋ฐ ๋์์ ์ค๋๋ค.
Appendix B.3 Initialization of ฮt\Delta tฮt
- LSSL์์ ฮt\Delta tฮt๋ ๊ฐ ๋ ์ด์ด์์ ์ํ ๊ณต๊ฐ ์์คํ ์ ์๊ฐ ๊ฐ๊ฒฉ(Time-step)์ ์กฐ์ ํ๋ ์ค์ํ ํ๋ผ๋ฏธํฐ์ ๋๋ค. ฮt\Delta tฮt๋ ๋ก๊ทธ ๊ท ๋ฑ ๋ถํฌ(log-uniform distribution)๋ฅผ ์ฌ์ฉํ์ฌ ์ด๊ธฐํ๋๋ฉฐ, ์ด๋ ์๊ฐ ์ฒ๋ ฮt\Delta tฮt๋ฅผ ๋ค์ํ๊ฒ ์ค์ ํจ์ผ๋ก์จ ์ฌ๋ฌ ์๊ฐ ์ฒ๋๋ฅผ ํ์ตํ ์ ์๋๋ก ํฉ๋๋ค.
- ์ต์ ์๊ฐ ๊ฐ๊ฒฉ ฮtmin\Delta t_{min}ฮtminโ์ ์ต๋ ์๊ฐ ๊ฐ๊ฒฉ ฮtmax\Delta t_{max}ฮtmaxโ๋ฅผ ์ค์ ํ์ฌ, ๋ฐ์ดํฐ์ ์ํ์ค ๊ธธ์ด์ ๋ง๊ฒ ์๊ฐ ๊ฐ๊ฒฉ์ ์ด๊ธฐํํฉ๋๋ค.
- ์ด ํ๋ผ๋ฏธํฐ๋ ์ํ์ค ๋ฐ์ดํฐ์ ๊ธธ์ด์ ๋ฐ์ดํฐ์ ๋ง๋ค ๋ค๋ฅด๊ฒ ์ค์ ๋ ์ ์์ผ๋ฉฐ, ๋ค์ํ ์๊ฐ ์ฒ๋์ ๋ํด ๋ชจ๋ธ์ด ์ ์ํ ์ ์๊ฒ ํฉ๋๋ค.
Combining LSSLs with Continuous-time Memorization
๊ธฐ๋ณธ LSSL
์ ๊ธด ์ํ์ค๋ฅผ ์ฒ๋ฆฌํ๋ ๋ฐ ์์ด ๋ ๊ฐ์ง ๋ฌธ์ ๊ฐ ์์์ต๋๋ค: (1) ๊ธฐ์ธ๊ธฐ ์์ค ๋ฌธ์ ์ (2) ์ฐ์ฐ ๋ณต์ก๋ ๋ฌธ์
- 4.1 Incorporating Long Dependencies into LSSLs (๊ธฐ์ธ๊ธฐ ์์ค ๋ฌธ์ ):
- ๋ฌธ์ : ์ํ ์ ์ด ํ๋ ฌ AAA๋ฅผ ๋ฌด์์๋ก ์ค์ ํ๊ฑฐ๋ ์ ์ ํ๊ฒ ์ค๊ณํ์ง ์์ผ๋ฉด, ๊ธด ์ํ์ค๋ฅผ ์ฒ๋ฆฌํ ๋ ๊ธฐ์ธ๊ธฐ ์์ค ๋ฌธ์ ๊ฐ ๋ฐ์ํฉ๋๋ค. ์ด๋ ๋คํธ์ํฌ๊ฐ ๊ธด ์๊ฐ ๋์ ์ค์ํ ์ ๋ณด๋ฅผ ์ ์งํ์ง ๋ชปํ๋ ๋ฌธ์ ๋ก, ํนํ LSSL์ด RNN๊ณผ ๊ฐ์ ์ํ ๊ตฌ์กฐ๋ฅผ ๊ฐ์ง๊ณ ์๊ธฐ ๋๋ฌธ์ ๋ฐ์ํ ์ ์์ต๋๋ค.
- LSSL์ ๊ธด ์ํ์ค ๋ฐ์ดํฐ๋ฅผ ์ฒ๋ฆฌํ ์ ์๋ ๊ตฌ์กฐ๋ฅผ ๊ฐ์ถ๊ณ ์์ง๋ง, ๋ฌด์์(random initialized) ์ํ ํ๋ ฌ AAA๋ฅผ ์ฌ์ฉํ ๊ฒฝ์ฐ ํจ๊ณผ๊ฐ ํฌ์ง ์์์ ๊ฒฝํ์ ์ผ๋ก ํ์ธํ์์ต๋๋ค. (์คํ์ ์ผ๋ก ํ์ธํจ)
- ํด๊ฒฐ์ฑ : ์ด๋ฅผ ํด๊ฒฐํ๊ธฐ ์ํด HiPPO ํ๋ ์์ํฌ๋ฅผ ์ ์ฉํ์ฌ, ์ ์ ํ ์ํ ์ ์ด ํ๋ ฌ AAA๋ฅผ ์ค๊ณํฉ๋๋ค. HiPPO๋ ์๊ฐ์ ๋ฐ๋ฅธ ์ค์ํ ์ ๋ณด๋ฅผ ์ ์ ์งํ ์ ์๋๋ก ์ํ ๊ณต๊ฐ ๋ชจ๋ธ์ ์ต์ ํํ์ฌ ๊ธฐ์ธ๊ธฐ ์์ค ๋ฌธ์ ๋ฅผ ์ํํฉ๋๋ค.
- ๋ฌธ์ : ์ํ ์ ์ด ํ๋ ฌ AAA๋ฅผ ๋ฌด์์๋ก ์ค์ ํ๊ฑฐ๋ ์ ์ ํ๊ฒ ์ค๊ณํ์ง ์์ผ๋ฉด, ๊ธด ์ํ์ค๋ฅผ ์ฒ๋ฆฌํ ๋ ๊ธฐ์ธ๊ธฐ ์์ค ๋ฌธ์ ๊ฐ ๋ฐ์ํฉ๋๋ค. ์ด๋ ๋คํธ์ํฌ๊ฐ ๊ธด ์๊ฐ ๋์ ์ค์ํ ์ ๋ณด๋ฅผ ์ ์งํ์ง ๋ชปํ๋ ๋ฌธ์ ๋ก, ํนํ LSSL์ด RNN๊ณผ ๊ฐ์ ์ํ ๊ตฌ์กฐ๋ฅผ ๊ฐ์ง๊ณ ์๊ธฐ ๋๋ฌธ์ ๋ฐ์ํ ์ ์์ต๋๋ค.
- 4.2 Theoretically Efficient Algorithms for the LSSL (์ฐ์ฐ ๋ณต์ก๋ ๋ฌธ์ ):
- ๋ฌธ์ : LSSL์ ์ํ ์ ์ด ํ๋ ฌ AAA์ ๋ฒกํฐ์ ๊ณฑ์ (Matrix-Vector Multiplication, MVM)์ด๋ Krylov ๊ณต๊ฐ์์์ ํฉ์ฑ๊ณฑ ์ฐ์ฐ์ด ํฌํจ๋๋๋ฐ, ์ด ์ฐ์ฐ๋ค์ด ๋งค์ฐ ๋ณต์กํ๊ณ ์๊ฐ์ด ๋ง์ด ๊ฑธ๋ฆด ์ ์์ต๋๋ค. ํนํ, ๊ธด ์ํ์ค๋ฅผ ์ฒ๋ฆฌํ ๋ ์ฐ์ฐ ๋ณต์ก๋๊ฐ ์ปค์ง๋ ๋ฌธ์ ๊ฐ ์์ต๋๋ค.
- ํด๊ฒฐ์ฑ : ์ด๋ฅผ ํด๊ฒฐํ๊ธฐ ์ํด Quasiseparable ํ๋ ฌ์ ์ฌ์ฉํ์ฌ, ์ํ ์ ์ด ํ๋ ฌ์ ํน์ฑ์ ํ์ฉํ ํจ์จ์ ์ธ ๊ณ์ฐ ๋ฐฉ๋ฒ์ ์ ์ํฉ๋๋ค. Quasiseparable ํ๋ ฌ์ ์ ํ ์๊ฐ ๋ณต์ก๋๋ก ๊ณ์ฐํ ์ ์์ผ๋ฉฐ, Krylov ๊ณต๊ฐ์์์ ์ฐ์ฐ์ ๋ ๋น ๋ฅด๊ณ ํจ์จ์ ์ผ๋ก ์ํํ ์ ์๊ฒ ํด์ค๋๋ค.
Empirical Evaluation
- 5.1 Image and Time Series Benchmarks: ์๊ณ์ด ์ด๋ฏธ์ง์ ๊ฐ์ ๋ฐ์ดํฐ์ ์์ LSSL์ ์ฑ๋ฅ์ ํ๊ฐํ ์คํ ๊ฒฐ๊ณผ๋ฅผ ์ค๋ช ํฉ๋๋ค. ์ฌ๊ธฐ์๋ sMNIST, pMNIST, sCIFAR์ ๊ฐ์ ์ ๋ช ํ ๋ฒค์น๋งํฌ์์์ ์ฑ๋ฅ ๋น๊ต๊ฐ ํฌํจ๋ฉ๋๋ค.
- 5.2 Speech and Image Classification for Very Long Time Series: ๋งค์ฐ ๊ธด ์ํ์ค ๋ฐ์ดํฐ๋ฅผ ์ฒ๋ฆฌํ๋ ์์ฑ ๋ฐ ์ด๋ฏธ์ง ๋ถ๋ฅ ๋ฌธ์ ์์ LSSL์ด ๊ธฐ์กด ๋ชจ๋ธ๋ณด๋ค ๋ฐ์ด๋ ์ฑ๋ฅ์ ๋ณด์๋ค๋ ์ ์ ์ค๋ช ํฉ๋๋ค.
- 5.3 Advantages of Recurrent, Convolutional, and Continuous-time Models: ์ฌ๊ท์ , ์ปจ๋ณผ๋ฃจ์ , ์ฐ์-์๊ฐ ๋ชจ๋ธ์ ์ฅ์ ์ ๋ชจ๋ ๊ฐ์ถ LSSL์ ์ฅ์ ์ ๊ฐ์กฐํฉ๋๋ค.
- 5.4 LSSL Ablations: Learning the Memory Dynamics and Timescale: LSSL์ด ์ํ์ค์ ์๊ฐ ์ค์ผ์ผ์ ์๋์ผ๋ก ํ์ตํ ์ ์๋ ๋ฅ๋ ฅ์ ๋ถ์ํ๊ณ , ๋ฉ๋ชจ๋ฆฌ ๋๋ ฅํ์ ํ์ตํ๋ ๋ฐฉ๋ฒ์ ๋ํ ์คํ ๊ฒฐ๊ณผ๋ฅผ ๋ณด์ฌ์ค๋๋ค.
-
S4: Efficiently Modeling Long Sequences with Structured State Spaces (ICLR, 2022)
Introduction
-
์ด ์น์ ์์๋ ์์ฐจ ๋ฐ์ดํฐ(sequence data) ๋ชจ๋ธ๋ง์ ์ฃผ์ ๊ณผ์ ์ธ ์ฅ๊ธฐ ์ข ์์ฑ(long-range dependencies) ๋ฌธ์ ๋ฅผ ๋ค๋ฃจ๋ฉฐ ๊ธฐ์กด์ ๋ชจ๋ธ(RNN, CNN, Transformer ๋ฑ)์ด ๊ธด ์ํ์ค๋ฅผ ์ฒ๋ฆฌํ๋ ๋ฐ ์์ด ๊ฒช๋ ๋ฌธ์ ์ ์ ์ ์ํฉ๋๋ค.
- RNNs (Recurrent Neural Networks): RNN ๊ณ์ด ๋ชจ๋ธ์ ๋ณธ๋ ์์ฐจ ๋ฐ์ดํฐ ์ฒ๋ฆฌ๋ฅผ ์ํด ๊ฐ๋ฐ๋์์ผ๋, vanishing gradient(๊ธฐ์ธ๊ธฐ ์์ค) ๋ฌธ์ ๋ก ์ธํด ๊ธด ์ํ์ค๋ฅผ ์ฒ๋ฆฌํ๋ ๋ฐ ํ๊ณ๊ฐ ์์ต๋๋ค.
- CNNs (Convolutional Neural Networks): CNN์ ์ํ์ค ๊ธธ์ด๋ฅผ ํ์ฅํ๊ธฐ ์ํด dilated convolutions(ํ์ฅ๋ ์ปจ๋ณผ๋ฃจ์ ) ๋ฑ์ ๋์ ํ์ผ๋ ์ฌ์ ํ ๊ธด ์ํ์ค ์ฒ๋ฆฌ์์ ์ฑ๋ฅ์ด ์ ํ๋ฉ๋๋ค.
- Transformers: Transformers ๋ชจ๋ธ์ ๋๊ท๋ชจ ์ํ์ค ์ฒ๋ฆฌ์ ๋๋ฆฌ ์ฌ์ฉ๋์ง๋ง, quadratic scaling(์ํ์ค ๊ธธ์ด์ ๋ฐ๋ฅธ ์ฐ์ฐ ๋ณต์ก๋๊ฐ ์ ๊ณฑ์ ๋น๋ก) ๋ฌธ์ ๋ก ์ธํด ๋งค์ฐ ๊ธด ์ํ์ค์์๋ ๋นํจ์จ์ ์ ๋๋ค.
- ๋์์ ์ ๊ทผ๋ฒ์ผ๋ก ์ต๊ทผ ์ฐ๊ตฌ์์๋ ์ํ ๊ณต๊ฐ ๋ชจ๋ธ(SSM)์ ๊ธฐ๋ฐ์ผ๋ก ํ ์ ๊ทผ๋ฒ์ด ์ ์๋์์ต๋๋ค. SSM์ ์ ์ด ์ด๋ก ๋ฑ ๋ค์ํ ๋ถ์ผ์์ ์ค๋์ ๋ถํฐ ์ฌ์ฉ๋์ด ์จ ๋ชจ๋ธ๋ก ์๊ฐ์ ๋ฐ๋ผ ๋ณํํ๋ ์ํ๋ฅผ ํํํ๊ณ , ์ด๋ฅผ ํตํด ์ฅ๊ธฐ์ ์ธ ์๊ณ์ด ๋ฐ์ดํฐ๋ฅผ ์ฒ๋ฆฌํ ์ ์์ต๋๋ค. ๊ทธ๋ฌ๋ ๊ธฐ์กด SSM์ ๋ฅ๋ฌ๋์ ์ ์ฉํ๋ ๋ฐ๋ ๊ณ์ฐ ๋น์ฉ๊ณผ ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋์ด ๋งค์ฐ ํฌ๋ค๋ ํ๊ณ์ ๋ด์ฐฉํ์ต๋๋ค.
-
๋ณธ ๋ ผ๋ฌธ์์๋ ์ด๋ฌํ ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ๊ธฐ ์ํด S4(Structured State Spaces) ๋ชจ๋ธ์ ์ ์ํฉ๋๋ค. ์ด ๋ชจ๋ธ์ ์ํ ๊ณต๊ฐ ๋ชจ๋ธ์ ์ํ์ ๊ฐ์ ์ ์ ์งํ๋ฉด์๋, ์ด๋ฅผ ๋ ํจ์จ์ ์ผ๋ก ๊ณ์ฐํ ์ ์๋ ๋ฐฉ๋ฒ์ ์ ๊ณตํฉ๋๋ค.
- S4๋ ์ํ ํ๋ ฌ A๋ฅผ ์ ๋ญํฌ(low-rank)์ ์ ๊ท ํ๋ ฌ(normal matrix)๋ก ๋ถํดํ์ฌ ๊ณ์ฐ์ ์์ ์ฑ๊ณผ ํจ์จ์ฑ์ ๋์ ๋๋ค.
- ํนํ S4๋ Cauchy kernel์ ์ฌ์ฉํ์ฌ ํจ์จ์ ์ธ ๊ณ์ฐ์ ๊ฐ๋ฅํ๊ฒ ํ๋ฉฐ, ์ด๋ก ์ธํด ๊ธด ์ํ์ค๋ฅผ ์ฒ๋ฆฌํ๋ ๋ฐ ํ์ํ ์ฐ์ฐ๋๊ณผ ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋์ ํฌ๊ฒ ์ค์ผ ์ ์์ต๋๋ค.
-
Figure 1 ์ค๋ช
-
(์ผ์ชฝ) ์ํ ๊ณต๊ฐ ๋ชจ๋ธ: ์ํ ๊ณต๊ฐ ๋ชจ๋ธ์ ์ ๋ ฅ ์ ํธ u(t)u(t)u(t)๊ฐ ์ฃผ์ด์ก์ ๋, ์ด๋ฅผ ์๋ ์ํ x(t)x(t)x(t)๋ก ๋ณํํ ๋ค, ์ต์ข ์ ์ผ๋ก ์ถ๋ ฅ y(t)y(t)y(t)๋ฅผ ์์ฑํ๋ ์์คํ ์ ๋๋ค.
- ์ํ ๋ณํ์ ์ํ ํ๋ ฌ AAA, ์ ๋ ฅ ํ๋ ฌ BBB, ์ถ๋ ฅ ํ๋ ฌ CCC, ๊ทธ๋ฆฌ๊ณ ์คํต ์ฐ๊ฒฐ์ ๋ด๋นํ๋ ํ๋ ฌ DDD์ ์ํด ์ ์๋ฉ๋๋ค.
- ์ด ๋ชจ๋ธ์ ์ ์ด ์ด๋ก ๊ณผ ๊ณ์ฐ ์ ๊ฒฝ๊ณผํ์์ ๊ด๋ฒ์ํ๊ฒ ์ฌ์ฉ๋๋ฉฐ, ํนํ ์ฐ์ ์๊ฐ ์์คํ ์ ๋ชจ๋ธ๋งํ๋ ๋ฐ ์ ํฉํฉ๋๋ค.
-
(์ค์) ์ฐ์ ์๊ฐ ๋ฉ๋ชจ๋ฆฌ ์ด๋ก : ์ต๊ทผ ์ฐ๊ตฌ์์๋ ํน์ ํ๋ ฌ AAA๋ฅผ ์ฌ์ฉํ๋ฉด SSM์ด ์ฅ๊ธฐ ์ข ์์ฑ(Long-Range Dependencies, LRDs)์ ์ํ์ ์ผ๋ก๋ ์คํ์ ์ผ๋ก ํจ๊ณผ์ ์ผ๋ก ์ฒ๋ฆฌํ ์ ์์์ ์ ์ฆํ์ต๋๋ค. (
์ด์ ์ฐ๊ตฌ
)- ์ด๋ฌํ ํ๋ ฌ์ HiPPO๋ผ๋ ์ด๋ก ์์ ์ ๋๋ ํน๋ณํ ํ๋ ฌ๋ก, ์ ๋ ฅ์ ๊ธด ์ด๋ ฅ์ ๊ธฐ์ตํ๋ ๋ฐ ์ต์ ํ๋์ด ์์ต๋๋ค.
-
(์ค๋ฅธ์ชฝ) ์ฌ๊ท ๋ฐ ์ปจ๋ณผ๋ฃจ์ ํํ: SSM์ ๋ ๊ฐ์ง ๋ฐฉ์์ผ๋ก ๊ณ์ฐํ ์ ์์ต๋๋ค.
์ฌ๊ท์ ๋ฐฉ์
๊ณผ์ปจ๋ณผ๋ฃจ์ ๋ฐฉ์
.- ์ฌ๊ท์ ๋ฐฉ์์ RNN์ฒ๋ผ ์์ฐจ์ ์ผ๋ก ๊ณ์ฐ๋๋ฉฐ, ์ปจ๋ณผ๋ฃจ์ ๋ฐฉ์์ ๋ณ๋ ฌํ๊ฐ ๊ฐ๋ฅํด ๋ ๋น ๋ฅธ ์ฐ์ฐ์ด ๊ฐ๋ฅํฉ๋๋ค.
- S4๋ ์ด๋ฌํ ์๋ก ๋ค๋ฅธ ํํ ๊ฐ์ ๋ณํ์ ํจ์จ์ ์ผ๋ก ์ํํ ์ ์๋๋ก ์ค๊ณ๋์์ผ๋ฉฐ, ๋ค์ํ ์์ ์ ์ ํฉํ ๋ฐฉ์์ผ๋ก ํจ์จ์ ์ธ ํ์ต๊ณผ ์ถ๋ก ์ ์ง์ํฉ๋๋ค.
-
Method: Structured State Spaces (S4)
-
3.1 ๋๊ธฐ: ๋๊ฐํ (Motivation: Diagonalization)
๋ฌธ์ ์ ์
: ์ํ ๊ณต๊ฐ ๋ชจ๋ธ(SSM)์ ์ค์ํ ๋ฌธ์ ๋, ์ํ ๊ณต๊ฐ์ ํฌ๊ธฐ๊ฐ ์ปค์ง์ ๋ฐ๋ผ ์ฐ์ฐ ๋ณต์ก๋๊ฐ ์ฆ๊ฐํ๋ค๋ ๊ฒ์ ๋๋ค. ๊ตฌ์ฒด์ ์ผ๋ก, HiPPO ํ๋ ฌ AAA๋ฅผ ์ฌ๋ฌ ๋ฒ ๊ณฑํ๋ ์ฐ์ฐ์ด ๋ณต์ก๋๋ฅผ ์ฆ๊ฐ์ํค๋ ์ฃผ ์์ธ์ ๋๋ค. (โต์ํ ์ ๋ฐ์ดํธ๋ฅผ ์ํด์๋ A๋ฅผ ์ฌ๋ฌ๋ฒ ๊ณฑํด์ผํจ)-
์ํ ๊ณต๊ฐ ๋ชจ๋ธ์์ AAA๋ ์ํ ๊ฐฑ์ ์ ๋ด๋นํ๋ ํต์ฌ ํ๋ ฌ์ด๋ฉฐ, ์ด๋ฅผ ์ฌ์ฉํ๋ ์ฐ์ฐ์ ๋ฐ๋ณต์ ์ผ๋ก ์ผ์ด๋ฉ๋๋ค. AAA๋ฅผ ์ง์ ๊ณ์ฐํ๋ฉด O(N2L)O(N^2L)O(N2L)์ ๋ฌํ๋ ์ฐ์ฐ๋๊ณผ O(NL)O(NL)O(NL)์ ๋ฉ๋ชจ๋ฆฌ ๊ณต๊ฐ์ด ํ์ํฉ๋๋ค. ์ด๋ ํนํ ๋๊ท๋ชจ ์ํ์ค ๋ชจ๋ธ๋ง์์ ๋ณ๋ชฉ์ด ๋ฉ๋๋ค.
์ด๋ฅผ ํด๊ฒฐํ๊ธฐ ์ํด, ์ผค๋ (conjugation)๋ผ๋ ์ํ์ ๊ธฐ๋ฒ์ ๋์ ํ์ฌ ์ฐ์ฐ์ ๋จ์ํํ ์ ์์ต๋๋ค.
-
-
Lemma 3.1: ์ผค๋ ๊ด๊ณ : ์ด ๋ ๋ง์์๋ ์ํ ๊ณต๊ฐ ๋ชจ๋ธ(SSM)์ ํ๋ ฌ AAA, BBB, CCC์ ์ผค๋ ๋ณํ ์ ์ ์ฉํ๋ฉด ๋์ผํ ๋ชจ๋ธ์ ์ป์ ์ ์์์ ๋ณด์ฌ์ค๋๋ค. ์ด ๋ง์, ๋ ์ํ ๊ณต๊ฐ ๋ชจ๋ธ์ด ์๋ก ๋์ผํ ์ ๋ณด๋ฅผ ํํํ๊ณ ์์ง๋ง ๋ค๋ฅธ ์ขํ๊ณ์์ ํํ๋ ์ ์๋ค๋ ์๋ฏธ์ ๋๋ค. ์ด๋ฅผ ํตํด ์์คํ ์ ๋ณต์กํ ๊ณ์ฐ์ ๋ ๋จ์ํํ ์ ์์ต๋๋ค.
์ผค๋ ๋ณํ์ด๋?
- ์ ํ๋์ํ์์
์ผค๋ ๋ณํ(conjugate transformation)
์ ๋ณต์์ ํ๋ ฌ์ด๋ ๋ฒกํฐ์ ์ ์ฉ๋๋ ์ค์ํ ์ฐ์ฐ์ ๋๋ค. ์ด ๋ณํ์ ๋ณต์์ ํ๋ ฌ์ ๋ํด ๋ ๊ฐ์ง ์ฐ์ฐ์ ์์ฐจ์ ์ผ๋ก ์ํํฉ๋๋ค: โ ํ๋ ฌ์ ์ ์น(transpose)ํฉ๋๋ค. โก ๊ฐ ์์๋ฅผ ์ผค๋ ๋ณต์์๋ก ๋ณํํฉ๋๋ค
์ผค๋ ๋ณํ์ ์์
์์คํ ๋ถ์
: ์ผค๋ ๋ณํ์ ํตํด ์์คํ ์ ๋ ์ฝ๊ฒ ๋ถ์ํ ์ ์๋ ํํ๋ก ๋ณํํ ์ ์์ต๋๋ค.- ์๋ฅผ ๋ค์ด, ๋๊ฐํ๋ ์ ๊ทํ์ผ๋ก์ ๋ณํ์ด ๊ฐ๋ฅํฉ๋๋ค.
์ ์ด ์ค๊ณ
: ์ํ ํผ๋๋ฐฑ ์ ์ด๋ ๊ด์ธก๊ธฐ ์ค๊ณ ์, ์ผค๋ ๋ณํ์ ํตํด ๋ ๊ฐ๋จํ ํํ์ ์์คํ ์ผ๋ก ๋ณํํ์ฌ ์ค๊ณ๋ฅผ ์ํํ ์ ์์ต๋๋ค.๊ณ์ฐ ํจ์จ
: ํน์ ํํ๋ก์ ๋ณํ์ ํตํด ๊ณ์ฐ ํจ์จ์ ๋์ผ ์ ์์ต๋๋ค. ์๋ฅผ ๋ค์ด, ๋๊ฐ ํ๋ ฌ์ ๊ณ์ฐ์ด ๋งค์ฐ ๊ฐ๋จํฉ๋๋ค.
์ผค๋ ๊ด๊ณ์ ์๋ฏธ : ์ผค๋ ๊ด๊ณ๋ ์ฃผ๋ก ๋ณต์์๋ ํ๋ ฌ์์ ์ฌ์ฉ๋๋ ๊ฐ๋ ์ ๋๋ค.
๋ณต์์์์์ ์ผค๋
: ๋ณต์์a + bi
์ ์ผค๋ ๋a - bi
์ ๋๋ค. ์ผค๋ ๋ณต์์๋ ์ค์๋ถ๋ ๊ฐ๊ณ ํ์๋ถ์ ๋ถํธ๋ง ๋ฐ๋์ ๋๋ค.ํ๋ ฌ์์์ ์ผค๋ ์ ์น
: ํ๋ ฌ A์ ์ผค๋ ์ ์น(conjugate transpose)๋A*
๋ก ํ๊ธฐํ๋ฉฐ, ํ๋ ฌ์ ์ ์นํ ํ ๊ฐ ์์๋ฅผ ์ผค๋ ๋ณต์์๋ก ๋ฐ๊พผ ๊ฒ์ ๋๋ค
-
๋ ๊ฐ์ ์ํ ๊ณต๊ฐ ๋ชจ๋ธ์ ์๊ฐํด ๋ด ์๋ค. ํ๋๋ ์๋์ ์ํ ๋ฒกํฐ xxx๋ฅผ ์ฌ์ฉํ๊ณ , ๋ค๋ฅธ ํ๋๋ ๋ณํ๋ ์ํ ๋ฒกํฐ x~=Vx\tilde{x} = Vxx~=Vx๋ฅผ ์ฌ์ฉํฉ๋๋ค. ์ฌ๊ธฐ์ VVV๋ ๋ณํ์ ์ํํ๋ ํ๋ ฌ์ ๋๋ค.
-
๊ฐ๊ฐ์ ์ํ ๊ณต๊ฐ ๋ฐฉ์ ์์ ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
-
์๋ ์ํ ๊ณต๊ฐ ๋ชจ๋ธ:
xโฒ=Ax+Buxโ = Ax + Buxโฒ=Ax+Bu y=Cxy = Cxy=Cx
-
๋ณํ๋ ์ํ ๊ณต๊ฐ ๋ชจ๋ธ:
x~โฒ=Vโ1AVx~+Vโ1Bu\tilde{x}โ = V^{-1}AV\tilde{x} + V^{-1}Bux~โฒ=Vโ1AVx~+Vโ1Bu y=CVx~y = CV\tilde{x}y=CVx~
-
-
-
์ด ๋ ๋ชจ๋ธ์ ๋์ผํ ์์คํ ์ ๋ํ๋ด๋ฉฐ, ์ด๋ ์ผค๋ ๋ณํ ์ด ์ํ ๊ณต๊ฐ ๋ชจ๋ธ์์ ๋๋ฑ ๊ด๊ณ์์ ์๋ฏธํฉ๋๋ค. ์ด๋ฅผ ํตํด ์ฐ๋ฆฌ๋ AAA, BBB, CCC ํ๋ ฌ์ ๋ณํํ์ฌ ๋์ผํ ์ฐ์ฐ์ ๋ค๋ฅธ ํํ๋ก ๊ณ์ฐํ ์ ์๊ฒ ๋ฉ๋๋ค. ์ผค๋ ๊ด๊ณ๋ ์๋ ์์ผ๋ก ์ ์๋ฉ๋๋ค.
(A,B,C)โผ(Vโ1AV,Vโ1B,CV)(A, B, C) \sim (V^{-1} A V, V^{-1} B, C V)(A,B,C)โผ(Vโ1AV,Vโ1B,CV)
- ์ฆ, ํ๋ ฌ VVV๋ฅผ ์ฌ์ฉํ์ฌ ์ํ ๋ฒกํฐ xxx๋ฅผ ๋ณํํ๋ฉด, ์๋ก์ด ์ํ ๋ฒกํฐ x~=Vx\tilde{x} = Vxx~=Vx๋ก ๋ณํ๋ ์์คํ ์์ ๋ ํจ์จ์ ์ธ ์ฐ์ฐ์ด ๊ฐ๋ฅํฉ๋๋ค.
- ํ๋ ฌ AAA๋ฅผ Vโ1AVV^{-1} A VVโ1AV๋ก ๋ณํํ์ฌ ๋๊ฐํํ๋ฉด, AAA๊ฐ ๋๊ฐ ํ๋ ฌ์ผ ๋ ๊ณ์ฐ์ด ๋จ์ํด์ง๋๋ค.
- ์ ํ๋์ํ์์
-
Lemma 3.2: HiPPO ํ๋ ฌ์ ๋๊ฐํ: ์ด ๋ ๋ง๋ HiPPO ํ๋ ฌ AAA๊ฐ ๋๊ฐํ๋ ์ ์์์ ๋ณด์ฌ์ค๋๋ค. ์ฌ๊ธฐ์ ๋๊ฐํ๋ ๋ณต์กํ ํ๋ ฌ ์ฐ์ฐ์ ๋ ๋จ์ํ๊ฒ ๋ง๋ค์ด์ฃผ๋ ์ค์ํ ์ํ์ ๊ธฐ๋ฒ์ ๋๋ค.
- HiPPO ํ๋ ฌ์ ์๊ณ์ด ๋ฐ์ดํฐ๋ฅผ ์ฒ๋ฆฌํ๊ธฐ ์ํ ํน์ ์ ํ์ ํ๋ ฌ์ธ๋ฐ, ์ด ํ๋ ฌ์ ๋๊ฐํ๋ ๊ณ์ฐ ํจ์จ์ฑ์ ๋์ด๋ ๋ฐ ์ค์ํ ์ญํ ์ ํฉ๋๋ค.
-
HiPPO ํ๋ ฌ AAA๋ ๋๊ฐํ๋ ์ ์์ผ๋ฉฐ, ๋๊ฐํ์ ์ฌ์ฉ๋๋ ๋ณํ ํ๋ ฌ VVV์ ํ๋ ฌ VVV์ ๊ฐ ํญ๋ชฉ V3k,iV_{3k,i}V3k,iโ๋ ์๋์ ๊ฐ์ด ์ ์๋ฉ๋๋ค.
Vij=((i+jiโj))V_{ij} = \left( \binom{i+j}{i-j} \right)Vijโ=((iโji+jโ)) V3k,i=((ki))2iโkV_{3k,i} = \left(\binom{k}{i}\right) 2^{i-k}V3k,iโ=((ikโ))2iโk
- ์ด ์์ ํตํด, VVV์ ํญ๋ชฉ์ 2iN/32^{iN/3}2iN/3 ์ ๋์ ํฌ๊ธฐ๋ฅผ ๊ฐ์ง๋๋ค. ์ด ์์์ ํตํด HiPPO ํ๋ ฌ์ ๋๊ฐํํ์ฌ ์ฐ์ฐ์ ๊ฐ์ํํ ์ ์์ต๋๋ค.
-
3.2 S4 ํ๋ผ๋ฏธํฐํ: Normal Plus Low-Rank Parameterization (NLPR)
- ๊ธฐ๋ณธ์ ์ธ HiPPO ํ๋ ฌ AAA๋ ๋๊ฐ ํ๋ ฌ์ด ์๋๋ฉฐ, ์ผ๋ฐ์ ์ผ๋ก ๊ณ์ฐ์ด ๋ณต์กํฉ๋๋ค. ์ด๋ฅผ ํด๊ฒฐํ๊ธฐ ์ํด, ๋
ผ๋ฌธ์์๋ ์ ๊ท ํ๋ ฌ(normal matrix)๊ณผ ์ ๋ญํฌ ํ๋ ฌ(low-rank matrix)์ ํฉ์ผ๋ก ๋ถํดํ๋ ๊ธฐ๋ฒ์ ์ ์ํฉ๋๋ค.
์ ๊ท ํ๋ ฌ (Normal Matrix) : ์ ๊ท ํ๋ ฌ์ ํน๋ณํ ์ฑ์ง์ ๊ฐ์ง ์ ์ฌ๊ฐ ํ๋ ฌ์ ๋๋ค.
์ ์
: ํ๋ ฌ์ ๋ค์ง๊ณ ๋ณต์์ ๋ถ๋ถ์ ๋ถํธ๋ฅผ ๋ฐ๊พผ ๊ฒ(์ผค๋ ์ ์น)๊ณผ ์๋ ํ๋ ฌ์ ๊ณฑํ์ ๋, ์์๋ฅผ ๋ฐ๊ฟ๋ ๊ฐ์ ๊ฒฐ๊ณผ๊ฐ ๋์ค๋ ํ๋ ฌ์ ๋๋ค.
์ ๋ญํฌ ํ๋ ฌ (Low-rank Matrix) : ์ ๋ญํฌ ํ๋ ฌ์ ๋ณต์กํ ์ ๋ณด๋ฅผ ๊ฐ๋จํ๊ฒ ํํํ ์ ์๋ ํ๋ ฌ์ ๋๋ค.
์ ์
: ํ๋ ฌ์ ๋ญํฌ(๋ ๋ฆฝ์ ์ธ ํ ๋๋ ์ด์ ์)๊ฐ ์์ ํ๋ ฌ์ ๋งํฉ๋๋ค.
- ์ ๊ท ํ๋ ฌ์ ๋๊ฐํ๊ฐ ๊ฐ๋ฅํ์ง๋ง, HiPPO ํ๋ ฌ ์์ฒด๋ ์ด ์์ฑ์ ๋ง์กฑํ์ง ์์ผ๋ฏ๋ก ์ด๋ฅผ ํ์ฉํ ์ ์์ต๋๋ค.
-
๋์ , HiPPO ํ๋ ฌ์ ์ ๊ท ํ๋ ฌ๊ณผ ์ ๋ญํฌ ํ๋ ฌ์ ํฉ์ผ๋ก ๊ทผ์ฌํ ์ ์์ต๋๋ค. ์ฆ, AAA๋ ์๋์ ๊ฐ์ด ๋ถํด๋ฉ๋๋ค.
A=VฮVโ1โPQTA = V \Lambda V^{-1} - PQ^TA=VฮVโ1โPQT
- ฮ\Lambdaฮ: ๋๊ฐ ํ๋ ฌ
- PPP, QQQ: ์ ๋ญํฌ ํ๋ ฌ
- ์ ๋ญํฌ ํ๋ ฌ์ ํญ๋ชฉ ์๊ฐ ์ ๊ธฐ ๋๋ฌธ์ ๊ณ์ฐ์ด ํจ์จ์ ์ผ๋ก ์ด๋ฃจ์ด์ง ์ ์์ผ๋ฉฐ, ์ด๋ฌํ ๋ถํด๋ NPLR (Normal Plus Low-Rank) ๊ธฐ๋ฒ์ผ๋ก ์๋ ค์ ธ ์์ต๋๋ค. ์ด ๋ฐฉ๋ฒ์ ์ฌ์ฉํ๋ฉด, AAA๋ฅผ ์ฌ๋ฌ ๋ฒ ๊ณฑํ๋ ์ฐ์ฐ์ ๋ณต์ก๋๋ฅผ ๋ํญ ์ค์ผ ์ ์๋ค๊ณ ํฉ๋๋ค.
-
(Theorem 1) ๋ชจ๋ HiPPO ํ๋ ฌ์ NPLR ํํ : ๋ชจ๋ HiPPO ํ๋ ฌ์ด NPLR ํํ์ ๊ฐ์ง ์ ์์์ ์ฆ๋ช ํฉ๋๋ค. ์ด๋ฅผ ํตํด, S4 ๋ชจ๋ธ์์ ์ฌ์ฉ๋๋ ํ๋ ฌ AAA๋ ์๋์ ๊ฐ์ด ํํ๋ฉ๋๋ค.
A=VฮVโ1โPQT=V(ฮโ(Vโ1P)(Vโ1Q)T)Vโ1A = V\Lambda V^{-1} - PQ^T = V\left(\Lambda - (V^{-1}P)(V^{-1}Q)^T\right)V^{-1}A=VฮVโ1โPQT=V(ฮโ(Vโ1P)(Vโ1Q)T)Vโ1
- ฮ\Lambdaฮ๋ ๋๊ฐ ํ๋ ฌ
- PPP์ QQQ๋ ์ ๋ญํฌ ํ๋ ฌ
- ๊ธฐ๋ณธ์ ์ธ HiPPO ํ๋ ฌ AAA๋ ๋๊ฐ ํ๋ ฌ์ด ์๋๋ฉฐ, ์ผ๋ฐ์ ์ผ๋ก ๊ณ์ฐ์ด ๋ณต์กํฉ๋๋ค. ์ด๋ฅผ ํด๊ฒฐํ๊ธฐ ์ํด, ๋
ผ๋ฌธ์์๋ ์ ๊ท ํ๋ ฌ(normal matrix)๊ณผ ์ ๋ญํฌ ํ๋ ฌ(low-rank matrix)์ ํฉ์ผ๋ก ๋ถํดํ๋ ๊ธฐ๋ฒ์ ์ ์ํฉ๋๋ค.
-
3.3 S4 Algorithms and Computational Complexity : ์ด ์น์ ์์๋ S4 ๋ชจ๋ธ์์ ์ ์ํ๋ ์ฃผ์ ์๊ณ ๋ฆฌ์ฆ๊ณผ ๊ทธ ๋ณต์ก๋์ ๋ํด ์ค๋ช ํฉ๋๋ค. ์ ๋ฆฌ 2 ์ ์ ๋ฆฌ 3 ์ ๊ฐ๊ฐ ์ฌ๊ท ์ฐ์ฐ๊ณผ ์ปจ๋ณผ๋ฃจ์ ์ฐ์ฐ์ ๋ณต์ก๋๋ฅผ ๋ค๋ฃน๋๋ค.
Theorem 2
: S4 Recurrence์์๋ ์ฌ๊ท ์ฐ์ฐ์ ๋ณต์ก๋๋ฅผ O(N)O(N)O(N)์ผ๋ก ์ค์ผ ์ ์์์ ์ค๋ช ํฉ๋๋ค. ์ฌ๊ท ์ฐ์ฐ์ ์ํ ๊ณต๊ฐ ๋ชจ๋ธ์์ ์ค์ํ ์ฐ์ฐ์ด๋ฉฐ, ์ด๋ฅผ ํจ์จ์ ์ผ๋ก ์ํํ๋ ๋ฐฉ๋ฒ์ ์ ์ํฉ๋๋ค.-
Theorem 3
: S4 Convolution์์๋ SSM์ ์ปจ๋ณผ๋ฃจ์ ํํฐ KKK๋ฅผ ๊ณ์ฐํ๋ ์ฐ์ฐ์ O(N+L)O(N + L)O(N+L)๋ก ์ค์ผ ์ ์์์ ์ค๋ช ํฉ๋๋ค. ์ด ํํฐ๋ ์ํ์ค ๋ชจ๋ธ์์ ์ฃผ๋ก ์ฌ์ฉ๋๋ ํต์ฌ ์ฐ์ฐ ์ค ํ๋์ ๋๋ค.- ์ปจ๋ณผ๋ฃจ์
ํํฐ์ ๊ณ์ฐ์ 4๊ฐ์ ์ผ์ฐ์ ๊ณฑ์
(Cauchy multiplies)์ผ๋ก ์ด๋ฃจ์ด์ง๋ฉฐ, O(N+L)O(N + L)O(N+L) ์ฐ์ฐ๋ง ํ์ํฉ๋๋ค. ์ด๋ก ์ธํด S4 ๋ชจ๋ธ์ ๋๊ท๋ชจ ์ํ์ค ๋ฐ์ดํฐ๋ฅผ ์ฒ๋ฆฌํ๋ ๋ฐ ๋งค์ฐ ํจ์จ์ ์
๋๋ค.
์ผ์ฐ์ ํ๋ ฌ(Cauchy Matrix): ๋ ๋ฒกํฐ์ ์์ ์ฐจ์ด์ ์ญ์๋ก ์ด๋ฃจ์ด์ง ํน์ํ ํ๋ ฌ.
Cij=1xiโyjC_{ij} = \frac{1}{x_i - y_j}Cijโ=xiโโyjโ1โ
์ผ์ฐ์ ๊ณฑ์ (Cauchy Multiplication): ํจ์๋ ์์ด์ ๊ณฑ์ ์ ํจ์จ์ ์ผ๋ก ๊ณ์ฐํ๋ ๋ฐฉ๋ฒ์ผ๋ก, ๋คํญ์์ ๊ณฑ์ด๋ ์ปจ๋ณผ๋ฃจ์ ์ฐ์ฐ์์ ์ฌ์ฉ๋จ.
- ์ปจ๋ณผ๋ฃจ์
ํํฐ์ ๊ณ์ฐ์ 4๊ฐ์ ์ผ์ฐ์ ๊ณฑ์
(Cauchy multiplies)์ผ๋ก ์ด๋ฃจ์ด์ง๋ฉฐ, O(N+L)O(N + L)O(N+L) ์ฐ์ฐ๋ง ํ์ํฉ๋๋ค. ์ด๋ก ์ธํด S4 ๋ชจ๋ธ์ ๋๊ท๋ชจ ์ํ์ค ๋ฐ์ดํฐ๋ฅผ ์ฒ๋ฆฌํ๋ ๋ฐ ๋งค์ฐ ํจ์จ์ ์
๋๋ค.
Algorithm 1: S4 Convolution Kernel
์๊ณ ๋ฆฌ์ฆ 1
์ S4 ์ปจ๋ณผ๋ฃจ์ ์ปค๋(S4 Convolution Kernel)์ ๊ณ์ฐํ๋ ์ ์ฐจ๋ฅผ ์ค๋ช ํ๊ณ ์์ต๋๋ค. ์ด ์๊ณ ๋ฆฌ์ฆ์์ํ ๊ณต๊ฐ ๋ชจ๋ธ(SSM)
์ ๊ธฐ๋ฐ์ผ๋ก ์ํ์ค ๋ฐ์ดํฐ์์ ์ปจ๋ณผ๋ฃจ์ ํํฐ๋ฅผ ํจ์จ์ ์ผ๋ก ๊ณ์ฐํ๋ ๋ฐฉ๋ฒ์ ์ ์ํฉ๋๋ค. ์๋ ๊ทธ๋ฆผ์ ๊ธฐ์ค์ผ๋ก ์ค๋ช ํฉ๋๋ค.
์ ๋ ฅ
- ฮ,P,Q,B,C\Lambda, P, Q, B, Cฮ,P,Q,B,C: ์ํ ๊ณต๊ฐ ๋ชจ๋ธ์์ ์ํ ์ ๋ฐ์ดํธ, ์ ๋ ฅ ๋ฐ ์ถ๋ ฅ์ ๊ด๋ จ๋ S4 ๋ชจ๋ธ์ ํ๋ผ๋ฏธํฐ.
- ฮ\Deltaฮ: ์๊ฐ ๊ฐ๊ฒฉ ๋๋ ๋จ๊ณ ํฌ๊ธฐ(step size).
์ถ๋ ฅ
- KKK: S4 ๋ชจ๋ธ์ ์ปจ๋ณผ๋ฃจ์ ์ปค๋ (SSM ์ต์ข ์ปจ๋ณผ๋ฃจ์ ํํฐ)
๋จ๊ณ๋ณ ์ค๋ช
-
โ SSM ์์ฑ ํจ์ C~\tilde{C}C~ ๊ณ์ฐ
-
์ฌ๊ธฐ์, SSM ์์ฑ ํจ์(Generating Function) C~\tilde{C}C~๋ ์๋์ ๊ฐ์ด ์ ์๋ฉ๋๋ค.
C~โ(IโAL)โ1C\tilde{C} \leftarrow \left( I - A^L \right)^{-1} CC~โ(IโAL)โ1C
-
ALA^LAL๋ ํ๋ ฌ AAA๋ฅผ ์๊ฐ ๋จ๊ณ LLL์ ๋ํด ์ ๊ณฑํ ํ๋ ฌ์ ์๋ฏธํฉ๋๋ค.
- ์ด ์ฐ์ฐ์ ์ํ ๊ณต๊ฐ ๋ชจ๋ธ์์ ์ํ ๊ฐฑ์ ์ ๋ํ๋ด๋ฉฐ, CCC์ ๊ฒฐํฉํ์ฌ ์ํ ๊ณต๊ฐ์ ์ถ๋ ฅ์ ๊ณ์ฐํ ์ ์์ต๋๋ค. (์ฐธ๊ณ ๋ก, C๋ ์ํ ๊ณต๊ฐ ๋ชจ๋ธ(SSM)์ ์ถ๋ ฅ ํ๋ ฌ)
- IโALI - A^LIโAL์ ๋จ์ ํ๋ ฌ III์์ ํ๋ ฌ ALA^LAL์ ๋บ ๊ฒ์ ๋๋ค. ์ด๋ AAA๊ฐ ์์คํ ์ ๋ฏธ์น๋ ์ํฅ์ ๋ฐ์ํ์ฌ ๋จ์ ํ๋ ฌ์์ ์ผ์ ๋ถ๋ถ์ ์กฐ์ ํ๋ ์ญํ ์ ํฉ๋๋ค.
- (IโAL)โ1(I - A^L)^{-1}(IโAL)โ1์ ์ํ ๊ณต๊ฐ ๋ชจ๋ธ์์ ์ฌ๋ฌ ์๊ฐ ์คํ ์ ๊ฑธ์น ์ํ ๋ณํ๋ฅผ ๊ณ ๋ คํฉ๋๋ค. ์ด๋ฅผ ํตํด ์์คํ ์ ํ์ฌ ์ํ์ ๋ํ ์ ๋ณด๋ฅผ ์ถ์ถํ๊ณ , ์๊ฐ์ด ์ง๋จ์ ๋ฐ๋ผ ์ํ๊ฐ ์ด๋ป๊ฒ ๋ณํํ๋์ง ๊ณ์ฐํ ํ, ๊ทธ ์ํฅ์ ์ญ์ผ๋ก ๊ณ์ฐํ๋ ์ญํ ์ ํฉ๋๋ค.
- ์ด ๋จ๊ณ์์ ์ต์ข ์ ์ผ๋ก C~\tilde{C}C~๋ฅผ ๊ณ์ฐํ์ฌ SSM์ ๋ํ๋ด๋ ๋ฒกํฐ๋ฅผ ์ป์ต๋๋ค. ์ด๋ฅผ ํตํด ์์ฑ๋ ์ํ ๋ฒกํฐ๋ ์ดํ์ ์ปจ๋ณผ๋ฃจ์ ์ปค๋ ๊ณ์ฐ์ ์ฌ์ฉ๋ฉ๋๋ค.
-
-
โก SSM ์ผ์ฐ์ ๊ณฑ์ (Cauchy Multiplication)
-
KKK์ ๊ฐ ์ฑ๋ถ์ ์๋์ ๊ฐ์ด ์ผ์ฐ์ ๊ณฑ์ ์ ํตํด ๊ณ์ฐํฉ๋๋ค.
[k00(ฯ)k01(ฯ)k10(ฯ)k11(ฯ)]โC~โ [(ฮ1โฯ1+ฯโฮ)โ1โ BP]\begin{bmatrix} k_{00}(\omega) & k_{01}(\omega) \ k_{10}(\omega) & k_{11}(\omega) \end{bmatrix} \leftarrow \tilde{C} \cdot \left[ \left( \Delta \frac{1 - \omega}{1 + \omega} - \Lambda \right)^{-1} \cdot B P \right][k00โ(ฯ)k10โ(ฯ)โk01โ(ฯ)k11โ(ฯ)โ]โC~โ [(ฮ1+ฯ1โฯโโฮ)โ1โ BP]
- ์ฌ๊ธฐ์ ์ผ์ฐ์ ๊ณฑ์ (Cauchy Multiplication)์ ์ฌ์ฉํ์ฌ ๋ณต์กํ ํ๋ ฌ ์ฐ์ฐ์ ํจ์จ์ ์ผ๋ก ์ํํฉ๋๋ค. Cauchy ํ๋ ฌ์ ํน์ํ ํํ์ ํ๋ ฌ๋ก, ์ฌ๊ธฐ์์๋ PPP์ BBB ํ๋ ฌ์ ๊ณฑํ์ฌ ์ต์ข ์ปจ๋ณผ๋ฃจ์ ํํฐ KKK๋ฅผ ๊ณ์ฐํ๋ ๋ฐ ์ฌ์ฉ๋ฉ๋๋ค.
- ฮ\Lambdaฮ๋ ๋๊ฐ ํ๋ ฌ์ด๊ณ , BBB์ PPP๋ S4 ๋ชจ๋ธ์์ ์ํ์ ์ ๋ ฅ์ ๋ํ ์ ๋ญํฌ ํ๋ ฌ์ ๋๋ค.
-
-
โข Woodbury Identity ์ ์ฉ
- Woodbury Identity๋ ๋๊ท๋ชจ ํ๋ ฌ์ ์ญํ๋ ฌ์ ๊ณ์ฐํ ๋ ์ฌ์ฉํ๋ ํจ์จ์ ์ธ ๋ฐฉ๋ฒ์ ๋๋ค. ์ด๋ฅผ ์ ์ฉํ์ฌ ์ปจ๋ณผ๋ฃจ์ ํํฐ์ ๊ณ์ฐ์ ๋์ฑ ๊ฐ์ํํ ์ ์์ต๋๋ค.
- Woodbury Identity๋ ์ ๋ญํฌ ํ๋ ฌ์ ํฌํจํ๋ ์ญํ๋ ฌ์ ๋น ๋ฅด๊ฒ ๊ณ์ฐํ ์ ์๊ฒ ํด์ฃผ๋ฉฐ, A+PQโA + PQ^*A+PQโ ํํ์ ํ๋ ฌ์ Aโ1A^{-1}Aโ1๋ก ๋ฐ๊ฟ์ค๋๋ค. ์ด๋ก ์ธํด ํ๋ ฌ ์ฐ์ฐ์ด ๋ํญ ๋จ์ํด์ง๋๋ค.
-
โฃ K~(ฯ)\tilde{K}(\omega)K~(ฯ) Evaluate(ํ๊ฐ)
-
K(ฯ)K(\omega)K(ฯ)๋ ๋ชจ๋ ๊ทผ(roots of unity) ฯโฮฉL\omega \in \Omega_LฯโฮฉLโ์์ ํ๊ฐ๋ฉ๋๋ค.
K~(ฯ)โ21+ฯ[k00(ฯ)โk01(ฯ)(1+k11(ฯ))โ1k10(ฯ)]\tilde{K}(\omega) \leftarrow \frac{2}{1 + \omega} \left[ k_{00}(\omega) - k_{01}(\omega)(1 + k_{11}(\omega))^{-1}k_{10}(\omega) \right]K~(ฯ)โ1+ฯ2โ[k00โ(ฯ)โk01โ(ฯ)(1+k11โ(ฯ))โ1k10โ(ฯ)]
-
์ด ๋จ๊ณ์์ ๊ฐ ํํฐ์ ์์๊ฐ ๊ทผ์ ํตํด ํ๊ฐ๋๊ณ , ์ด๋ฅผ ํตํด ํํฐ์ ๊ฐ ์ฃผํ์ ์ฑ๋ถ์ด ๊ณ์ฐ๋ฉ๋๋ค. (๊ณ์ฐ๊ณผ์ ์๋ ์ฐธ๊ณ )
-
ฯ\omegaฯ ์ค์ : LLL๊ฐ์ ๋จ์ ๊ทผ ฯk\omega_kฯkโ๋ฅผ ๊ณ์ฐํฉ๋๋ค.
ฯk=eโ2ฯik/L,k=0,1,โฆ,Lโ1\omega_k = e^{-2\pi i k / L}, \quad k = 0, 1, \dots, L-1ฯkโ=eโ2ฯik/L,k=0,1,โฆ,Lโ1
-
K(ฯk)K(\omega_k)K(ฯkโ) ๊ณ์ฐ: ๊ฐ ฯk\omega_kฯkโ์ ๋ํด K(ฯk)K(\omega_k)K(ฯkโ)๋ฅผ ๊ณ์ฐํฉ๋๋ค. ์ด๋ ํํฐ์ ์ฃผํ์ ์๋ต์ ๋ํ๋ ๋๋ค.
K~(ฯk)โ21+ฯk[k00(ฯk)โk01(ฯk)(1+k11(ฯk))โ1k10(ฯk)]\tilde{K}(\omega_k) \leftarrow \frac{2}{1 + \omega_k} \left[ k_{00}(\omega_k) - k_{01}(\omega_k)(1 + k_{11}(\omega_k))^{-1}k_{10}(\omega_k) \right]K~(ฯkโ)โ1+ฯkโ2โ[k00โ(ฯkโ)โk01โ(ฯkโ)(1+k11โ(ฯkโ))โ1k10โ(ฯkโ)]
-
์ฃผํ์ ๋๋ฉ์ธ์์์ ์ฐ์ฐ: ๊ณ์ฐ๋ K(ฯk)K(\omega_k)K(ฯkโ)๋ฅผ ์ฌ์ฉํ์ฌ ์ ๋ ฅ ์ ํธ์ ์ฃผํ์ ์ฑ๋ถ๊ณผ ๊ณฑ์ ์ ์ํํฉ๋๋ค.
-
-
-
โค ์ญ FFT(Inverse FFT) ์ ์ฉ
- ๋ง์ง๋ง ๋จ๊ณ์์ ์ญ Fourier ๋ณํ(iFFT)์ ์ฌ์ฉํ์ฌ ํํฐ์ ์ฃผํ์ ๋๋ฉ์ธ ํํ์ ์๊ฐ ๋๋ฉ์ธ์ผ๋ก ๋ณํํฉ๋๋ค. ์ด ๊ณผ์ ์์ ์ต์ข ์ ์ธ ์ปจ๋ณผ๋ฃจ์ ์ปค๋ KKK๊ฐ ๊ณ์ฐ๋ฉ๋๋ค.KโIFFT(K~(ฯk))K \leftarrow \text{IFFT}(\tilde{K}(\omega_k))KโIFFT(K~(ฯkโ))
Experiments
- 4.1 S4 Efficiency Benchmarks : S4๋ ๊ธฐ์กด์ ์ํ ๊ณต๊ฐ ๋ชจ๋ธ ๋ฐ Transformer ๋ชจ๋ธ์ ๋นํด ๋งค์ฐ ๋น ๋ฅธ ํ์ต ์๋์ ์ ์ ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋์ ์๋ํฉ๋๋ค. ์คํ ๊ฒฐ๊ณผ, S4๋ ์๋์ ๋ฉ๋ชจ๋ฆฌ ํจ์จ์ฑ ๋ชจ๋์์ ์ฐ์ํ ์ฑ๋ฅ์ ๋ณด์์ต๋๋ค.
- 4.2 Learning Long Range Dependencies : Long Range Arena (LRA)** ๋ฒค์น๋งํฌ์์ S4๋ ์ฅ๊ธฐ ์ข ์์ฑ์ ์ฒ๋ฆฌํ๋ ๋ฐ ์์ด ํ์ํ ์ฑ๋ฅ์ ๋ณด์ฌ์ค๋๋ค. ํนํ, ๊ธฐ์กด ๋ชจ๋ธ๋ค์ด ํด๊ฒฐํ์ง ๋ชปํ ์ด๋ ค์ด ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ๋ ๋ฐ ์ฑ๊ณตํ์์ต๋๋ค.
- 4.3 S4 as a General Sequence Model : S4๋ ์ด๋ฏธ์ง, ํ ์คํธ, ์ค๋์ค ๋ฑ ๋ค์ํ ๋ฐ์ดํฐ ์ ํ์์ ์ฌ์ฉ๋ ์ ์๋ ์ผ๋ฐ์ ์ธ ์ํ์ค ๋ชจ๋ธ๋ก ์ ์๋ฉ๋๋ค. ์คํ์ ํตํด S4๊ฐ ๋ค์ํ ๋ฐ์ดํฐ ์ ํ์์ ๋ฐ์ด๋ ์ฑ๋ฅ์ ๋ฐํํ๋ค๋ ๊ฒ์ ํ์ธํ ์ ์์ต๋๋ค.
- 4.4 SSM Ablations: the Importance of HiPPO : HiPPO ์ด๊ธฐํ๋ฅผ ์ฌ์ฉํ ์ํ ๊ณต๊ฐ ๋ชจ๋ธ์ด ์ฑ๋ฅ ํฅ์์ ๋งค์ฐ ์ค์ํ๋ค๋ ๊ฒ์ ์คํ์ ์ผ๋ก ํ์ธํ์์ต๋๋ค.
Reference
Lectures
- Efficiently Modeling Long Sequences with Structured State Spaces (๋งํฌ)
- Structured State Space Models for Deep Sequence Modeling (๋งํฌ)
Blogs
- github.com/dhkim0225/1day_1paper
- A Visual Guide to Mamba and State Space Models (๋งํฌ)
Books
- Dive into Deep Learning (๋งํฌ)
Papers
-
HiPPO: Recurrent Memory with Optimal Polynomial Projections (2020)
- ๋ ผ๋ฌธ ๋งํฌ: https://arxiv.org/pdf/2008.07669
-
LSSL: Combining Recurrent, Convolutional, and Continuous-time Models with Linear State-Space Layers (2021)
- ๋ ผ๋ฌธ ๋งํฌ: https://arxiv.org/pdf/2110.13985
-
S4: Efficiently Modeling Long Sequences with Structured State Spaces (2022)
- ๋ ผ๋ฌธ ๋งํฌ: https://arxiv.org/pdf/2111.00396