[Paper Review] Mamba: Linear-Time Sequence Modeling with Selective State Spaces
์๋ณธ ๊ฒ์๊ธ: https://velog.io/@euisuk-chung/Paper-Review-Mamba-Linear-Time-Sequence-Modeling-with-Selective-State-Spaces
์ต๊ทผ ๋ฅ๋ฌ๋ ์ํคํ
์ฒ์ ์ค์ฌ์๋ ํธ๋์คํฌ๋จธ
๊ฐ ์๋ฆฌ ์ก๊ณ ์์ต๋๋ค. ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ(LLM)๋ฟ๋ง ์๋๋ผ, ๊ทธ๋ฆผ์ ์์ฑํ๋ ๋ฐ ์ฐ์ด๋ ๋ํจ์ ๋ชจ๋ธ ๋ํ ํธ๋์คํฌ๋จธ ๊ตฌ์กฐ๋ฅผ ํ์ฉํ๊ณ ์์ต๋๋ค. ์ด์ธ์๋ ์๊ณ์ด ๋ถ์์ด๋ ์ถ์ฒ ์์คํ
๊ณผ ๊ฐ์ ๋ค์ํ ๋ถ์ผ์์ ํธ๋์คํฌ๋จธ๊ฐ ํต์ฌ์ ์ธ ์ญํ ์ ํ๊ณ ์์ต๋๋ค.
๊ทธ๋ฌ๋ ํธ๋์คํฌ๋จธ๋ฅผ ๋์ฒดํ ์ ์๋ ์๋ก์ด ์ํคํ ์ฒ๋ฅผ ๋ชจ์ํ๋ ค๋ ์ฐ๊ตฌ๋ ๊ณ์๋๊ณ ์์ผ๋ฉฐ, ๊ทธ ์ค์์๋ ํนํ ์ฃผ๋ชฉ๋ฐ๊ณ ์๋ ๊ฒ์ด State Space Model(SSM)์ ๋๋ค. ์ต๊ทผ โMamba: Linear-Time Sequence Modeling with Selective State Spacesโ๋ผ๋ ๋ ผ๋ฌธ๊ณผ ๊ทธ ๋ชจ๋ธ์ด ๊ณต๊ฐ๋๋ฉด์, SSM์ด ํธ๋์คํฌ๋จธ์ ๋์์ผ๋ก์ ๋์ฑ ๊ด์ฌ์ ๋๊ณ ์์ต๋๋ค.
- Mamba ๋ ผ๋ฌธ ๋งํฌ : https://arxiv.org/pdf/2312.00752
์ถ๊ฐ์ ์ผ๋ก ๋ฒ์จ Survey ๋ ผ๋ฌธ๋ ๋ฒ์จ ๋์๋๋ฐ์!! ํฅ๋ฏธ๋ก์ด ์ด๋ฏธ์ง๋ค๋ง ์ข reference์ฉ์ผ๋ก ๊ฐ์ ธ์ค์๋ฉด ์๋์ ๊ฐ์ต๋๋ค. Mamba ์ดํ๋ก ๋ง์ ํ์ ์ฐ๊ตฌ ๋ฐ Variation๋ค์ด ๋น ๋ฅด๊ฒ ์ฐ๊ตฌ๋๊ณ ์๋ ๊ฒ๋ค์ ํ์ธํ ์ ์์ต๋๋ค.
-
๋ง์ ๊ด์ฌ์ ๋ฐ๊ณ ์๋ SSM ๋ชจ๋ธ
์ด๋ฏธ์ง ์ถ์ฒ. State Space Model for New-Generation Network Alternative to Transformers: A Survey
-
์๋ง์ SSM Variation ๋ชจ๋ธ : new paradigm shift?!
์ด๋ฏธ์ง ์ถ์ฒ. State Space Model for New-Generation Network Alternative to Transformers: A Survey
1. Introduction (์๋ก )
์๋ก ์์๋ Mamba ๋ชจ๋ธ์ ํ์์ฑ
๊ณผ ๊ธฐ์กด Transformer ๋ชจ๋ธ์ ํ๊ณ์
์ ์ค๋ช
ํ๊ณ , Mamba๊ฐ ์ด ๋ฌธ์ ๋ฅผ ์ด๋ป๊ฒ ํด๊ฒฐํ๋์ง๋ฅผ ์๊ฐํฉ๋๋ค.
-
๊ธฐ์กด Transformer์ ๋ฌธ์ ์ : Transformer๋ Attention ๋ฉ์ปค๋์ฆ์ ๊ธฐ๋ฐ์ผ๋ก ๊ธด ์ํ์ค๋ฅผ ์ฒ๋ฆฌํ๋ ๋ฐ ๋งค์ฐ ๋ฐ์ด๋์ง๋ง, ์ํ์ค ๊ธธ์ด์ ๋ฐ๋ผ ๊ณ์ฐ ๋ณต์ก๋๊ฐ 2์ฐจ ํจ์(Quadratic)๋ก ์ฆ๊ฐํ์ฌ, ๊ธด ์ํ์ค ๋ฐ์ดํฐ๋ฅผ ๋ค๋ฃฐ ๋ ๋งค์ฐ ๋นํจ์จ์ ์ ๋๋ค.
- ์ด๋ฌํ ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ๊ธฐ ์ํด ๋ง์ ์ฐ๊ตฌ์๋ค์ด Linear Attention๊ณผ ๊ฐ์ ๋ค์ํ ๋ฐฉ๋ฒ์ ์ ์ํ์ง๋ง, ๋๋ถ๋ถ์ ์ ๋ณด ๋ฐ๋๊ฐ ๋์ ๋ฐ์ดํฐ(์: ํ ์คํธ)์์ Transformer๋งํผ์ ์ฑ๋ฅ์ ๋ด์ง ๋ชปํ์ต๋๋ค.
-
SSM(Structured State Space Models)์ ๋ฑ์ฅ: ์ด๋ฌํ ํ๊ณ๋ฅผ ํด๊ฒฐํ๊ธฐ ์ํด, Structured State Space Models(SSM)์ด ๋ฑ์ฅํ์ต๋๋ค.
- SSM์ ์ฌ๊ท์ ์ ๊ฒฝ๋ง(RNN)๊ณผ ํฉ์ฑ๊ณฑ ์ ๊ฒฝ๋ง(CNN)์ ์ด์ ์ ๊ฒฐํฉํ ๋ชจ๋ธ๋ก, ์ํ์ค์ ๊ธธ์ด์ ๋น๋กํ๋ ์ ํ์ ์ธ ๊ณ์ฐ ๋ณต์ก๋๋ฅผ ๊ฐ์ง๊ณ ์์ด ๋งค์ฐ ํจ์จ์ ์ ๋๋ค.
- ๊ทธ๋ฌ๋ SSM์ ์ ๋ณด ๋ฐ๋๊ฐ ๋์ ํ ์คํธ ๋ฐ์ดํฐ์์๋ Transformer๋งํผ์ ์ฑ๋ฅ์ ๋ด์ง ๋ชปํ์ต๋๋ค.
-
Mamba ๋ชจ๋ธ์ ๋ฑ์ฅ: Mamba๋
์ ํ ๋ฉ์ปค๋์ฆ
์ ๋์ ํ Selective State Space Model(์ ํ์ ์ํ ๊ณต๊ฐ ๋ชจ๋ธ)์ ๊ธฐ๋ฐ์ผ๋ก ํ๋ฉฐ, ๊ธด ์ํ์ค๋ฅผ ๋ค๋ฃจ๋ฉด์๋ Transformer ์์ค์ ์ฑ๋ฅ์ ์ ์งํ๋ฉด์๋ ๊ณ์ฐ ๋น์ฉ์ ์ค์ผ ์ ์๋ ๋ชจ๋ธ์ ๋๋ค.- ํนํ, ํ ์คํธ, ์ค๋์ค, ์ ์ ์ฒดํ(genomics) ๋ฑ์ ๋ค์ํ ๋ฐ์ดํฐ ์ ํ์์ ๋งค์ฐ ์ฐ์ํ ์ฑ๋ฅ์ ๋ณด์ฌ์ค๋๋ค.
2. State Space Models (์ํ ๊ณต๊ฐ ๋ชจ๋ธ)
์ด ์ฅ์์๋ ์ํ ๊ณต๊ฐ ๋ชจ๋ธ(SSM)์ ๊ธฐ๋ณธ ๊ฐ๋ ๊ณผ ์๋ ์๋ฆฌ์ ๋ํด ์ค๋ช ํฉ๋๋ค.
-
SSM์ ์๋ ์๋ฆฌ: ์ํ ๊ณต๊ฐ ๋ชจ๋ธ์ ์์คํ ์ ์ ๋ ฅ์ ๊ณ ์ฐจ์์ ์ ์ฌ ๊ณต๊ฐ(latent space)์ผ๋ก ๋ณํํ์ฌ ์ฒ๋ฆฌํ๋ ๋ฐฉ์์ผ๋ก ๋์ํฉ๋๋ค. ์ด๋ ์ํ์ ์ผ๋ก ์ฐ์ ์์คํ ์ ์ด์ฐํ(Discretization) ๊ณผ์ ์ผ๋ก ํํ๋ฉ๋๋ค.
- ์์์ผ๋ก ํํํ์๋ฉด, ์ ๋ ฅ x(t)x(t)x(t)์ ์ ์ฌ ์ํ h(t)h(t)h(t)๋ก ๋ณํํ๊ณ , ์ด๋ฅผ ํตํด ์ถ๋ ฅ y(t)y(t)y(t)์ ๋์ถํ๋ ๋ฐฉ์์ ๋๋ค.
- ์ด๋, ๊ฐ ์์ ์์ ์ํ ๊ณต๊ฐ์ ๋ณํ๋ฅผ ๋ํ๋ด๋ ์ฃผ์ ๋งค๊ฐ๋ณ์ A,B,CA, B, CA,B,C๊ฐ ์ฃผ์ด์ง๋๋ค.h(t)=Aโ h(tโ1)+Bโ x(t)h(t) = A \cdot h(t-1) + B \cdot x(t)h(t)=Aโ h(tโ1)+Bโ x(t) y(t)=Cโ h(t)y(t) = C \cdot h(t)y(t)=Cโ h(t)
-
์ฐ์ ์์คํ ์์ ์ด์ฐ ์์คํ ์ผ๋ก์ ๋ณํ: SSM์์๋ ์ฐ์์ ์ธ ์์คํ ์ ๋งค๊ฐ๋ณ์๋ฅผ ์ด์ฐํ(discretization)ํ์ฌ ๊ณ์ฐํฉ๋๋ค. ์ด๋ฅผ ํตํด ๋ชจ๋ธ์ ์ฐ์ ๋ฐ์ดํฐ๋ฅผ ํจ์จ์ ์ผ๋ก ์ฒ๋ฆฌํ ์ ์์ต๋๋ค.
- ์ด ๋, ์ ํ ์๊ฐ ๋ถ๋ณ ์์คํ (LTI, Linear Time-Invariant System)์ ๊ฐ๋ ์ด ์ฌ์ฉ๋ฉ๋๋ค. ์ด ์์คํ ์ ์๊ฐ์ ๋ฐ๋ผ ๋ณํ์ง ์๋ ์ ํ์ ์ธ ์ฐ์ฐ์ ์ํํ๋ฏ๋ก, ๊ธด ์ํ์ค ์ฒ๋ฆฌ์ ๋งค์ฐ ํจ์จ์ ์ ๋๋ค.
๐ก ๊ฒฐ๊ตญ ์ด์ฐํํ๋ฉด RNN์ด๋ ๊ฐ์ ๊ฑฐ ์๋๊ฐ?
DSBA ์ฐ๊ตฌ์ค์ ์ฒ์ฌ์ ์์ฌ์์ PYSR์ ๋ณด๋ฉด ์ด์ ๋ํ ๋ต๋ณ์ ์ป์ ์ ์์ต๋๋ค.
โ๏ธ SSM (State Space Model)์ ์ฐ์์ฑ
- SSM์ ์ฐ์์ ์ธ ์๊ฐ ํ๋ฆ์ ๋ฐ๋ผ ์์คํ ์ ์ํ๋ฅผ ๋ชจ๋ธ๋งํ๋ ๋ฐฉ์์ ๋๋ค. ์ด๋ A์ B๋ ์ฐ์์ ์์คํ ์ ํํํ๋ ์ค์ํ ๋งคํธ๋ฆญ์ค๋ค๋ก, ์๊ฐ์ ๋ฐ๋ฅธ ์์คํ ์ ์ํ ๋ณํ๋ฅผ ๊ธฐ์ ํฉ๋๋ค.
- A: ์ํ ๋ณํ๋ฅผ ๊ฒฐ์ ํ๋ ๋งคํธ๋ฆญ์ค. ์ด์ ์ํ Xtโ1X_{t-1}Xtโ1โ์ ๊ณฑํด์ ธ์ ์์คํ ์ ์ํ๊ฐ ์ด๋ป๊ฒ ๋ณํ๋์ง ์ ์ํฉ๋๋ค.
- B: ์ ๋ ฅ์ ์ํ๋ก ๋ณํํ๋ ๋งคํธ๋ฆญ์ค. ์ ๋ ฅ UtU_tUtโ๋ฅผ ๋ฐ์ ์ํ์ ๋ฐ์ํ๋ ์ญํ ์ ํฉ๋๋ค.
- Aฬ ์ Bฬ ๋ SSM์์ ์ด์ฐํ ๋ ๋ฒ์ ์ ๋งคํธ๋ฆญ์ค๋ค๋ก, ์ฐ์์ ์ธ ์์คํ ์ ์ด์ฐ์ ์ธ ํํ๋ก ๋ณํํ์ฌ ์ํ์ค ๋ฐ์ดํฐ๋ฅผ ์ฒ๋ฆฌํ ์ ์๊ฒ ๋ง๋ญ๋๋ค. Aฬ ์ Bฬ ๋ ์ฐ์์ ์ธ SSM ๋ชจ๋ธ์ ๋ํจ์๋ฅผ ๊ธฐ๋ฐ์ผ๋ก ์ด์ฐ์ ์ํ์ค ์ฒ๋ฆฌ์ ๋ง๊ฒ ๋ณํ๋ ๊ฒ์ ๋๋ค.
- SSM์ ์ฅ์ ์ ์ด๋ฌํ ์ฐ์์ ์ธ ํ๋ฆ์ ๊ธฐ๋ฐ์ผ๋ก ์์คํ ์ ๋ฏธ์ธํ ๋ณํ๋ฅผ ๋ ์ ๋ชจ๋ธ๋งํ ์ ์๋ค๋ ์ ์ ๋๋ค.
- ์๊ฐ ๋ณํ๊ฐ ์ฐ์์ ์ธ ์์คํ ์์ ๋ฐ์ดํฐ๋ฅผ ์ ๋ฐ์ํ ์ ์๊ธฐ ๋๋ฌธ์ ์์คํ ์ ๋ฌผ๋ฆฌ์ ์ฑ์ง์ ๋ ์ ํํ๊ฒ ๋ฐ์ํ ์ ์์ต๋๋ค.
โ๏ธ RNN (Recurrent Neural Network)์ ์ด์ฐํ
- RNN์ ๊ธฐ๋ณธ์ ์ผ๋ก ์ด์ฐํ๋ ์ํ์ค ๋ฐ์ดํฐ๋ฅผ ์ฒ๋ฆฌํ๋ ๋ฐ ์ค์ ์ ๋ก๋๋ค. RNN์ ๊ฐ ์๊ฐ ์คํ ์์ ์ด์ ์ํ์ ํ์ฌ ์ ๋ ฅ์ ๋ฐํ์ผ๋ก ๋ค์ ์ํ๋ฅผ ์ ๋ฐ์ดํธํฉ๋๋ค.
- RNN์ ์ฐ์์ ์ธ ์๊ฐ ํ๋ฆ์ ๋ช ์์ ์ผ๋ก ๋ชจ๋ธ๋งํ์ง ์์ผ๋ฉฐ, ์ด์ ์ํ์ ํ์ฌ ์ํ ๊ฐ์ ๋จ์ํ ๊ด๊ณ์ ์์กดํฉ๋๋ค.
- RNN์ ํ๊ณ๋ ์๊ฐ์ ์ฐ์์ฑ์ ๋ช ํํ๊ฒ ๋ค๋ฃจ์ง ์๊ธฐ ๋๋ฌธ์, ์๊ฐ์ ๋ฐ๋ฅธ ๋ฏธ์ธํ ๋ณํ๋ฅผ ๋ฐ์ํ๋ ๋ฐ์๋ ํ๊ณ๊ฐ ์์ ์ ์๋ค๋ ์ ์ ๋๋ค.
- ์ฐ์์ ์ธ ์๊ฐ ํ๋ฆ์ ๋ฐ์ํ์ง ์๋ ๊ตฌ์กฐ์ด๋ฏ๋ก, ๋ฌผ๋ฆฌ์ ์๊ฐ ํ๋ฆ์ด ์ค์ํ ๋ฌธ์ ์ ๋ํด์๋ ์ฑ๋ฅ์ด ์ ํ์ ์ผ ์ ์์ต๋๋ค.
- SSM ๊ตฌ์กฐ: SSM์ ๊ตฌ์กฐ๋ ์ฃผ๋ก ์ ๋ ฅ ๋ฐ์ดํฐ๋ฅผ ์ ์ฌ ์ํ๋ก ๋ณํํ ํ ์ด๋ฅผ ๊ธฐ๋ฐ์ผ๋ก ์ถ๋ ฅ์ ๋์ถํ๋ ๋ฐฉ์์ผ๋ก, ๊ฐ ์ฑ๋์ด ๋ ๋ฆฝ์ ์ผ๋ก ์๋ํ๋ ํน์ง์ด ์์ต๋๋ค. ์ด๋ก ์ธํด ๊ณ์ฐ์ ๋ณ๋ ฌํ๊ฐ ๊ฐ๋ฅํด์ ธ ๋งค์ฐ ํจ์จ์ ์ผ๋ก ์๋ํ ์ ์์ต๋๋ค.
-
SSM ์ํคํ ์ฒ ๊ฐ์ : SSM(์ํ๊ณต๊ฐ๋ชจ๋ธ) ์ํคํ ์ฒ๋ ๋ ๋ฆฝ์ ์ธ ์ํ์ค ๋ณํ ๋ชจ๋ธ๋ก, ์๋ ํฌ ์๋ ์ ๊ฒฝ๋ง ์ํคํ ์ฒ์ ํตํฉ๋ ์ ์์ต๋๋ค.
- SSM ์ํคํ ์ฒ๋ SSNN(State Space Neural Networks)๋ผ๊ณ ๋ ํ๋ฉฐ, ์ด ๊ฒฝ์ฐ์๋ SSM ๋ ์ด์ด๊ฐ CNN(ํฉ์ฑ๊ณฑ์ ๊ฒฝ๋ง) ๋ ์ด์ด์ ์ ์ฌํ ์ญํ ์ ํฉ๋๋ค.
- Introduction์๋ ์๋ ์ ์๋ ค์ง ๋ช ๊ฐ์ง SSM ์ํคํ
์ฒ๋ฅผ ๊ฐ๋จํ๊ฒ ์๊ฐํฉ๋๋ค:
- Linear Attention (Katharopoulos et al. 2020): ์๊ฐ ์ฃผ์์ ๊ทผ์ฌ๋ก, ์ฌ๊ท์ฑ์ ํฌํจํ๋ฏ๋ก ์ผ์ข ์ ์ ํ SSM์ผ๋ก ๋ณผ ์ ์์ต๋๋ค.
- H3 (Dao, Fu, Saab et al. 2023): ์ด ๋ชจ๋ธ์ S4๋ฅผ ์ฌ์ฉํ๋ ์ฌ๊ท๋ฅผ ์ผ๋ฐํํ๋ฉฐ, SSM์ด ๋ ๊ฐ์ ๊ฒ์ดํธ๊ฐ ์๋ ์ฐ๊ฒฐ ์ฌ์ด์ ์์นํ๋ ํํ์ ๋๋ค. H3๋ ํ์ค ์ง์ญ ํฉ์ฑ์ ์ถ๊ฐํ์ฌ ์ด๋ฅผ shift-SSM์ผ๋ก ๊ฐ์ฃผํฉ๋๋ค.
- Hyena (Poli et al. 2023): H3์ ๋์ผํ ์ํคํ ์ฒ๋ฅผ ์ฌ์ฉํ์ง๋ง, S4 ๋ ์ด์ด๋ฅผ MLP ๋งค๊ฐ๋ณ์ํ๋ ์ ์ญ ํฉ์ฑ์ผ๋ก ๋์ฒดํฉ๋๋ค.
- RetNet (Y. Sun et al. 2023): ์ด ์ํคํ ์ฒ๋ ์ถ๊ฐ์ ์ธ ๊ฒ์ดํธ๋ฅผ ๋ํ๊ณ , ๋ ๋จ์ํ SSM์ ์ฌ์ฉํ์ฌ ๋ค์ค ํค๋ ์ฃผ์(MHA) ๋ณํ์ ํตํด ๋์์ ์ธ ๋ณ๋ ฌ ๊ณ์ฐ ๊ฒฝ๋ก๋ฅผ ์ ๊ณตํฉ๋๋ค.
- RWKV (B. Peng et al. 2023): ์ด ๋ชจ๋ธ์ ์ธ์ด ๋ชจ๋ธ๋ง์ ์ํด ์ค๊ณ๋ RNN์ผ๋ก, ๋ค๋ฅธ ์ ํ ์ฃผ์ ๊ทผ์ฌ์ ์ผ์ข ์ธ Attention-free Transformer(S. Zhai et al. 2021)๋ฅผ ๊ธฐ๋ฐ์ผ๋ก ํฉ๋๋ค. ์ฃผ์ โWKVโ ๋ฉ์ปค๋์ฆ์ LTI ์ฌ๊ท๋ฅผ ํฌํจํ๋ฉฐ, ๋ ๊ฐ์ SSM์ ๋น์จ๋ก ๋ณผ ์ ์์ต๋๋ค.
(์ฐธ๊ณ ) LSSL ๋ฐ deepSSM ์ฐจ์ ๊ณ์ฐ ๋ฐฉ์
- ์ฌ๋ผ์ด๋์ ๋์จ ์์๊ณผ ์ ๊ฐ ์ด ์์์ด ์์ดํฉ๋๋ค. ์ ๋ ์์ ์ด ์์์ ๋ง๋๋ก hidden dim์ h๋ก, time input์ x๋ก ์ ์ํ์ต๋๋ค.
1. LTI(Linear Time-Invariant) ์์คํ ์ ์ ์
- LTI ์์คํ
์ ์๊ฐ์ ๋ฐ๋ผ ์์คํ
์ ํน์ฑ์ด ๋ณํ์ง ์์ผ๋ฉฐ, ์ ํ์ ์ด๊ณ ์๊ฐ ๋ถ๋ณ์ ์
๋๋ค. ์ฆ, ๊ฐ์ ์
๋ ฅ์ด ์ฃผ์ด์ง๋ฉด ์ธ์ ๋ ์ง ๋์ผํ ๋ฐฉ์์ผ๋ก ์ฒ๋ฆฌ๋๊ณ , ์๊ฐ์ ํ๋ฆ์ ๋ฐ๋ผ ์์คํ
์ ํ๋์ด ๋ฌ๋ผ์ง์ง ์๋ ํน์ฑ์ ๊ฐ์ง๋๋ค.
- ์ด๋ ๋ ๊ฐ์ง ์ฃผ์ ํน์ฑ์ ์ํด ์ ์๋ฉ๋๋ค:
- ์ ํ์ฑ: ์ ๋ ฅ๊ณผ ์ถ๋ ฅ์ ๊ด๊ณ๊ฐ ์ ํ์ ๋๋ค. ์ฆ, ์ ๋ ฅ์ ํฉ์ด ์ถ๋ ฅ์ ํฉ์ผ๋ก ์ ํ์ ์ผ๋ก ๋ณํ๋ฉ๋๋ค.
- ์๊ฐ ๋ถ๋ณ์ฑ: ์์คํ ์ ์ํ ๋ณํ๋ ์๊ฐ์ ์์กดํ์ง ์์ต๋๋ค. ์ฆ, ์๊ฐ์ ๋ฐ๋ผ ์์คํ ์ ์ฑ๋ฅ์ด๋ ๋์์ด ๋ฌ๋ผ์ง์ง ์์ต๋๋ค.
- ์ด๋ ๋ ๊ฐ์ง ์ฃผ์ ํน์ฑ์ ์ํด ์ ์๋ฉ๋๋ค:
โ ๊ทธ๋ ๋ค๋ฉด ์์ฐ์ค๋ฝ๊ฒ ๋๋ ์๋ฌธ์ ์?
- ๐ค : ํ .. ์๋ฅผ dmodeld_{\text{model}}dmodelโ ์ฐจ์์ผ๋ก ์ด๋ป๊ฒ ํ์ฅ์ ์ํฌ๊น? ํ๋ ์๊ฐ์ ํ๊ฒ ๋ฉ๋๋ค.
- ๋จ์ํ๊ฒ ์๊ฐํด๋ณด๋ฉด, ์๋ ๊ทธ๋ฆผ์ SSM ๋ชจ๋ธ์ฒ๋ผ ์ ๋ ฅ๊ณผ ์ถ๋ ฅ์ด dmodeld_{\text{model}}dmodelโ ์ฐจ์์ผ๋ก ํ์ฅํ๋ ์์ผ๋ก ์๊ฐํด๋ณผ ์๋ ์๊ธด ํฉ๋๋ค๋ง!โ ๏ธโ ๏ธ
ํ์ฅ๋ ๋ฐฉ์ ์:
h(t)=Aโ h(tโ1)+Bโ x(t)h(t) = A \cdot h(t-1) + B \cdot x(t)h(t)=Aโ h(tโ1)+Bโ x(t)
y(t)=Cโ h(t)y(t) = C \cdot h(t)y(t)=Cโ h(t)
๊ฐ ์์์ ๋๋ฉ์ ๋ณํ:
h(t)โRnh(t) \in \mathbb{R}^nh(t)โRn: ํ๋ ์คํ ์ดํธ๋ ์ฌ์ ํ nnn์ฐจ์์ ๊ฐ์ง๋๋ค.
x(t)โRdmodelx(t) \in \mathbb{R}^{d_{\text{model}}}x(t)โRdmodelโ: ์ ๋ ฅ์ด ์ด์ dmodeld_{\text{model}}dmodelโ ์ฐจ์์ ๋ฒกํฐ๋ก ํ์ฅ๋ฉ๋๋ค.
y(t)โRdmodely(t) \in \mathbb{R}^{d_{\text{model}}}y(t)โRdmodelโ: ์ถ๋ ฅ ์ญ์ dmodeld_{\text{model}}dmodelโ ์ฐจ์์ ๋ฒกํฐ๋ก ํ์ฅ๋ฉ๋๋ค.
AโRnรnA \in \mathbb{R}^{n \times n}AโRnรn: ํ๋ ์คํ ์ดํธ์ ์ ๋ฐ์ดํธ๋ฅผ ๋ด๋นํ๋ฉฐ ์ฐจ์์ ๋ณํ์ง ์์ต๋๋ค.
BโRnรdmodelB \in \mathbb{R}^{n \times d_{\text{model}}}BโRnรdmodelโ: ์ ๋ ฅ์ ํ๋ ์คํ ์ดํธ๋ก ๋ณํํ๋ ์ญํ ์ ํ๋ฉฐ, dmodeld_{\text{model}}dmodelโ ์ฐจ์์ ์ฒ๋ฆฌํฉ๋๋ค.
CโRdmodelรnC \in \mathbb{R}^{d_{\text{model}} \times n}CโRdmodelโรn: ํ๋ ์คํ ์ดํธ๋ฅผ ์ถ๋ ฅ์ผ๋ก ๋ณํํ๋ ๋งคํธ๋ฆญ์ค.
๐ก ํ์ง๋ง!! ์์ ๊ฐ์ ๋ฐฉ๋ฒ์ ์ฑ๋ฆฝํ์ง ์์ต๋๋ค.
- ๊ฒฐ๋ก ๋ถํฐ ๋ง์๋๋ฆฌ์๋ฉด, State Space Model(SSM)์ ๋ณธ์ง์ ์ผ๋ก LTI(Linear Time-Invariant) ์์คํ ์ด๋ฉฐ, ์ด ํน์ฑ ๋๋ฌธ์ ์ ๋ ฅ๊ณผ ์ถ๋ ฅ์ด dmodeld_{\text{model}}dmodelโ ์ฐจ์์ผ๋ก ํ์ฅ๋์์ ๋, ์ฐจ์๋ณ๋ก ๋ ๋ฆฝ์ ์ธ SSM ์ฒ๋ฆฌ๊ฐ ํ์ํ๊ฒ ๋ฉ๋๋ค.
โ ์ ๊ทธ๋ผ ์ฑ๋ฆฝํ์ง ์๋๊ฐ
=> SSM์ ๋ณธ์ง์ ์ผ๋ก LTI ์์คํ
-
SSM ์์ฒด๊ฐ LTI ์์คํ ์ด๊ธฐ ๋๋ฌธ์, ๊ธฐ๋ณธ์ ์ผ๋ก ์ ํ์ฑ๊ณผ ์๊ฐ ๋ถ๋ณ์ฑ์ด ๋ณด์ฅ๋์ด์ผ ํฉ๋๋ค. ์ด๋ฅผ ๊ธฐ๋ฐ์ผ๋ก, SSM์ ๋ค์์ ๋ ๊ฐ์ง ์ค์ํ ํน์ฑ์ ๊ฐ์ง๋๋ค:
- ์ ํ์ฑ: ์์คํ ์ ์ ๋ ฅ๊ณผ ์ถ๋ ฅ ์ฌ์ด์ ๊ด๊ณ๋ ์ ํ์ ๋๋ค. ์ฆ, ์ ๋ ฅ์ด ๋ณํ๋ฉด ์ถ๋ ฅ๋ ์ ํ์ ์ผ๋ก ๋ณํ๋ฉฐ, ์ด๋ ์์คํ ์ ๋์์ ๊ฒฐ์ ํ๋ ์ ํ ๋งคํธ๋ฆญ์ค์ ์ํด ์ ์ด๋ฉ๋๋ค.
- ์๊ฐ ๋ถ๋ณ์ฑ: ์์คํ ์ ์๊ฐ์ ๋ฐ๋ผ ๋ณํํ์ง ์๊ณ ํญ์ ๋์ผํ ๋ฐฉ์์ผ๋ก ๋์ํฉ๋๋ค. ์ ๋ ฅ์ด ์ธ์ ๋ค์ด์ค๋ , ์์คํ ์ ์ํ์ ์ถ๋ ฅ์ ๋์ผํ ๋ฐฉ์์ผ๋ก ๊ณ์ฐ๋ฉ๋๋ค.
-
๋ฐ๋ผ์ SSM์ ๋ชจ๋ ์ฐ์ฐ์ ์๊ฐ์ ๋ฐ๋ผ ๋ณํ์ง ์์์ผ ํ๋ฉฐ, ์ ๋ ฅ ์ฐจ์๊ณผ ์๊ด์์ด ๋์ผํ ๋ฐฉ์์ผ๋ก ๋์ํด์ผ ํฉ๋๋ค.
์ ๋ ฅ๊ณผ ์ถ๋ ฅ์ด dmodeld_{\text{model}}dmodelโ ์ฐจ์์ผ๋ก ํ์ฅ๋ ๋์ ๋ฌธ์
-
๋ง์ฝ ์ ๋ ฅ๊ณผ ์ถ๋ ฅ์ด dmodeld_{\text{model}}dmodelโ ์ฐจ์์ผ๋ก ํ์ฅ๋๋ค๋ฉด, ๋ชจ๋ ์ฐจ์์ ๋์์ ์ฒ๋ฆฌํ๊ธฐ ์ํด ํ๋์ ๊ณตํต๋ SSM ๋งคํธ๋ฆญ์ค(AAA, BBB, CCC)๋ฅผ ์ฌ์ฉํ๋ ๋ฐฉ์์ผ๋ก๋ ๋ฌธ์ ๊ฐ ์๊ธธ ์ ์์ต๋๋ค. ์๋ํ๋ฉด:
- SSM์ ๋ณธ์ง์ ์ผ๋ก LTI ์์คํ ์ด๊ธฐ ๋๋ฌธ์, ๊ฐ ์ฐจ์์ ์๋ก ๋ ๋ฆฝ์ ์ผ๋ก ์ฒ๋ฆฌ๋์ด์ผ๋ง ์ ํ์ฑ๊ณผ ์๊ฐ ๋ถ๋ณ์ฑ์ด ๋ณด์ฅ๋ฉ๋๋ค.
- ์ฌ๋ฌ ์ฐจ์์ ํ๋์ SSM ์์คํ ์์ ์ฒ๋ฆฌํ๋ ๋ฐฉ์์, ์ฐจ์ ๊ฐ ์ํธ์์ฉ์ ์ผ๊ธฐํ ์ ์์ต๋๋ค. ์ด๋ ๊ฐ ์ฐจ์์ด ๊ฐ๋ณ์ ์ผ๋ก ๋ ๋ฆฝ์ ์ผ๋ก ์ฒ๋ฆฌ๋์ง ์๊ธฐ ๋๋ฌธ์, ์๊ฐ์ ๋ฐ๋ฅธ ์ ๋ ฅ ์ฒ๋ฆฌ ๋ฐฉ์์ด ๋ฌ๋ผ์ง ๊ฐ๋ฅ์ฑ์ด ์๊น๋๋ค.
-
๋ฐ๋ผ์ dmodeld_{\text{model}}dmodelโ ์ฐจ์์ ์ ๋ ฅ์ ํ๋์ SSM์ผ๋ก ์ฒ๋ฆฌํ๋ ๋ฐฉ์์ LTI ์์คํ ์ ์๊ตฌ ์ฌํญ์ ์๋ฐฐํ ์ ์์ต๋๋ค. ์ฐจ์๋ณ๋ก ๋ ๋ฆฝ์ ์ธ ์ฒ๋ฆฌ๊ฐ ์ด๋ฃจ์ด์ง์ง ์์ผ๋ฉด ์ฐจ์ ๊ฐ์ ์ํธ์์ฉ์ด ๋ฐ์ํ๊ณ , ์๊ฐ์ ๋ฐ๋ผ ๊ฒฐ๊ณผ๊ฐ ๋ฌ๋ผ์ง ์ ์๊ธฐ ๋๋ฌธ์ ๋๋ค.
๐ก ๊ทธ๋ ๋ค๋ฉด ์ด๋ป๊ฒ ๋๋๊ฐ?
=> ๊ฐ ์ฐจ์์ SSM์ ๋ ๋ฆฝ์ ์ผ๋ก ์ฒ๋ฆฌ
- ์ค๋ช
:
- dmodeld_{\text{model}}dmodelโ ์ฐจ์์ ์ ๋ ฅ ๋ฒกํฐ๊ฐ ์ฃผ์ด์ง ๋, ๊ฐ ์ฐจ์์ ๊ฐ๋ณ์ ์ธ SSM์ ๊ฑฐ์นฉ๋๋ค.
- ์ฆ, dmodeld_{\text{model}}dmodelโ๊ฐ์ SSM์ด ๋ ๋ฆฝ์ ์ผ๋ก ์กด์ฌํ๋ฉฐ, ๊ฐ๊ฐ nรnn \times nnรn ํฌ๊ธฐ์ AAA ๋งคํธ๋ฆญ์ค๋ฅผ ์ฌ์ฉํ์ฌ ํ๋ ์คํ ์ดํธ๋ฅผ ์ ๋ฐ์ดํธํฉ๋๋ค.
- ์ด ๊ฒฝ์ฐ ์ ๋ ฅ๊ณผ ์ถ๋ ฅ์ ์ฐจ์์ ๋ง์ถ๊ธฐ ์ํด BBB์ CCC์ ์ฐจ์๋ dmodeld_{\text{model}}dmodelโ์ ๋ฐ๋ผ ํ์ฅ๋ฉ๋๋ค.
- ์ต์ข
๋๋ฉ์
:
- AโRnรnรdmodelA \in \mathbb{R}^{n \times n \times d_{\text{model}}}AโRnรnรdmodelโ
- BโRnรdmodelB \in \mathbb{R}^{n \times d_{\text{model}}}BโRnรdmodelโ
- CโRdmodelรnC \in \mathbb{R}^{d_{\text{model}} \times n}CโRdmodelโรn
3. Selective State Space Models (์ ํ์ ์ํ ๊ณต๊ฐ ๋ชจ๋ธ)
์ด ์ฅ์์๋ ์ ํ์ ์ํ ๊ณต๊ฐ ๋ชจ๋ธ(SSSM, Selective State Space Models)์ ๋ํด ์ค๋ช ํ๋ฉฐ, ์ด๋ฅผ ํตํด ๊ธฐ์กด SSM์ ์ฑ๋ฅ์ ์ด๋ป๊ฒ ๊ฐ์ ํ ์ ์๋์ง ๋ค๋ฃน๋๋ค.
๐ Figure 1 ์ค๋ช
- xtx_txtโ (์ ๋ ฅ ๋ฐ์ดํฐ)
- ์ ๋ ฅ ๋ฐ์ดํฐ xtx_txtโ (์ด๋ก)๋ ์ํ์ค์ ํ์ฌ ์์ ์์ ๋ค์ด์ค๋ ๋ฐ์ดํฐ์ ๋๋ค. ์ด ๋ฐ์ดํฐ๋ ์ฌ๋ฌ ์ฑ๋(D)๋ก ๋๋์ด์ ธ ์๊ณ , ๊ฐ๊ฐ์ ์ฑ๋์ด ๋ ๋ฆฝ์ ์ผ๋ก ์ฒ๋ฆฌ๋ฉ๋๋ค.
- ์๋ฅผ ๋ค์ด, ๊ทธ๋ฆผ์์๋ D=5D = 5D=5๋ก, 5๊ฐ์ ์ ๋ ฅ ์ฑ๋์ ์๋ฏธํฉ๋๋ค.
- htโ1h_{t-1}htโ1โ (์ด์ ์์ ์ ์ํ)
- htโ1h_{t-1}htโ1โ๋ ์ด์ ์์ ์์ ๊ณ์ฐ๋ ์ ์ฌ ์ํ(latent space)๋ฅผ ์๋ฏธํฉ๋๋ค. ์ด ์ ์ฌ ์ํ๋ ์๊ฐ์ ๋ฐ๋ผ ์ด์ด์ ธ ์์ผ๋ฉฐ, ์ด์ ์์ ์ ์ ๋ณด๊ฐ ํ์ฌ ์์ ์ ์ํฅ์ ๋ฏธ์นฉ๋๋ค.
- ์๋ฅผ ๋ค์ด ๊ทธ๋ฆผ์์๋ N=4N = 4N=4์ ๊ณ ์ฐจ์ ๊ณต๊ฐ์์ ์ ์๋์ด ์์ต๋๋ค.
- Bt,Ct,A,ฮtB_t, C_t, A, \Delta_tBtโ,Ctโ,A,ฮtโ (SSM์ ์ฃผ์ ๋งค๊ฐ๋ณ์)
- ์ด
๋ค ๊ฐ์ง ๋งค๊ฐ๋ณ์
๋ SSM์์ ์ค์ํ ์ญํ ์ ํฉ๋๋ค. ๊ฐ ๋งค๊ฐ๋ณ์๋ ์ํ์ค ๋ฐ์ดํฐ๋ฅผ ์ฒ๋ฆฌํ๊ณ , ์ ์ฌ ์ํ hth_thtโ๋ฅผ ์ ๋ฐ์ดํธํ๋ฉฐ, ์ต์ข ์ถ๋ ฅ yty_tytโ๋ฅผ ๊ณ์ฐํ๋ ๋ฐ ์ฌ์ฉ๋ฉ๋๋ค.
- BtB_tBtโ: ์ ๋ ฅ ๋ฐ์ดํฐ xtx_txtโ์ ์ํธ์์ฉํ์ฌ ์๋ก์ด ์ํ๋ฅผ ๋ง๋ญ๋๋ค.
- (์ฐธ๊ณ ์์) โถ h(t)=Aโ h(tโ1)+Bโ x(t)h(t) = A \cdot h(t-1) + B \cdot x(t)h(t)=Aโ h(tโ1)+Bโ x(t)
- ์ด๋ BtB_tBtโ๋ ์ ๋ ฅ ์์กด์ ์ด๋ฉฐ, ์ ๋ ฅ์ ๋ฐ๋ผ ๋์ ์ผ๋ก ๋ณํํฉ๋๋ค.
- AAA: ์ ์ฌ ์ํ hth_thtโ๋ฅผ ์ ๋ฐ์ดํธํ๋ ๋ฐ ํ์ํ ์ค์ํ ๋งค๊ฐ๋ณ์์ ๋๋ค. ์ด์ ์ํ์ ํ์ฌ ์ ๋ ฅ์ ๊ธฐ๋ฐํ์ฌ ์๋ก์ด ์ํ๋ฅผ ๊ณ์ฐํ ๋ ์ฌ์ฉ๋ฉ๋๋ค.
- CtC_tCtโ: ๊ณ์ฐ๋ ์ ์ฌ ์ํ hth_thtโ๋ฅผ ์ถ๋ ฅ yty_tytโ๋ก ๋ณํํ๋ ๋ฐ ์ฌ์ฉ๋ฉ๋๋ค.
- (์ฐธ๊ณ ์์) โถ y(t)=Cโ h(t)y(t) = C \cdot h(t)y(t)=Cโ h(t)
- ์ด๋ฅผ ํตํด ์ต์ข ์ ์ผ๋ก ์ํ์ค์ ๊ฐ ์์ ์์ ๋ชจ๋ธ์ ์ถ๋ ฅ์ ์ป์ต๋๋ค.
ฮt\Delta_tฮtโ: ์ด ๋งค๊ฐ๋ณ์๋ ์๊ฐ ์ฐจ์์ ์กฐ์ ํ๋ ์ญํ ์ ํฉ๋๋ค.
- ์๊ฐ์ ๋ฐ๋ผ ์ํ ๊ณต๊ฐ์์์ ๋ณํ๋ฅผ ์กฐ์ ํ์ฌ ๋ชจ๋ธ์ด ์ํ์ค๋ฅผ ๋ฐ๋ผ ์ค์ํ ์ ๋ณด๋ฅผ ๊ธฐ์ตํ๊ฑฐ๋ ์๋๋ก ๋์์ค๋๋ค.
- ์ ํ ๋ฉ์ปค๋์ฆ (Selection Mechanism)
- ์ ํ ๋ฉ์ปค๋์ฆ์ ์ด ๋ชจ๋ธ์ ํต์ฌ์ ์ธ ์์๋ก, SSSM์์ ์ ๋ ฅ ๋ฐ์ดํฐ๋ฅผ ์ ํ์ ์ผ๋ก ์ฒ๋ฆฌํ๊ฒ ๋ง๋ญ๋๋ค.
- ์ด ๋ฉ์ปค๋์ฆ์ ์ํ์ค ๋ฐ์ดํฐ๋ฅผ ํํฐ๋งํ์ฌ ์ค์ํ ์ ๋ณด๋ง์ ์ ํํ๊ณ , ๋ถํ์ํ ์ ๋ณด๋ ๋ฌด์ํ ์ ์๋๋ก ๋์ต๋๋ค. โจ ์ฆ, ์ ๋ณด ์์ถ ๋ฐ ์ ํ์ ๊ธฐ์ต์ ์ํํ๋ ๋ฐฉ์์ ๋๋ค. โจ
- ๊ทธ๋ฆผ์์๋ ์ด ๋ฉ์ปค๋์ฆ์ด ์ ๋ ฅ xtx_txtโ์ ์ํธ์์ฉํ์ฌ BtB_tBtโ์ ฮt\Delta_tฮtโ๋ฅผ ์กฐ์ ํ๋ ๋ฐฉ์์ผ๋ก ํํ๋ฉ๋๋ค. (ํ๋)
- GPU ๋ฉ๋ชจ๋ฆฌ ๊ณ์ธต (GPU Memory Hierarchy)
- SSSM์ ์ค์ํ ํน์ง ์ค ํ๋๋ ํ๋์จ์ด ์นํ์ ์ธ ์๊ณ ๋ฆฌ์ฆ์ ์ฌ์ฉํ๋ ๊ฒ์ ๋๋ค.
- ์ด ๋ชจ๋ธ์ GPU์ ๊ณ ์ ๋ฉ๋ชจ๋ฆฌ(SRAM)์ ๋์ฉ๋ ๋ฉ๋ชจ๋ฆฌ(HBM) ๊ณ์ธต์ ํจ๊ณผ์ ์ผ๋ก ํ์ฉํ์ฌ, ์ ์ฌ ์ํ์ ๊ณ์ฐ์ ์ต์ ํํฉ๋๋ค.
- ์ด ๊ตฌ์กฐ๋ฅผ ํตํด SSSM์ ํฐ ์ํ์ค ๋ฐ์ดํฐ๋ ํจ์จ์ ์ผ๋ก ์ฒ๋ฆฌํ ์ ์์ผ๋ฉฐ, ๊ณ์ฐ ์์์ ํจ์จ์ ์ผ๋ก ํ์ฉํ์ฌ ๋ ๋น ๋ฅด๊ฒ ๊ฒฐ๊ณผ๋ฅผ ๋์ถํ ์ ์์ต๋๋ค.
3.1 Motivation: Selection as a Means of Compression
์ด ์น์ ์ ์ ํ ๋ฉ์ปค๋์ฆ(selection mechanism)์ด ์ ์ค์ํ์ง๋ฅผ ์ค๋ช ํฉ๋๋ค. ์ฃผ์ ๋ด์ฉ์ ๋ฐ์ดํฐ ์์ถ ๊ณผ ๊ด๋ จ๋ ๋ฌธ์ ์ด๋ฉฐ, ์ ํ์ ๋ฉ์ปค๋์ฆ์ด ์ด๋ฅผ ์ด๋ป๊ฒ ํด๊ฒฐํ๋์ง๋ฅผ ๋ค๋ฃน๋๋ค.
- ์ ํ์ ํ์์ฑ: ์ํ์ค ๋ชจ๋ธ๋ง์ ์ฃผ์ ๊ณผ์ ์ค ํ๋๋ ์ปจํ
์คํธ(๋ฌธ๋งฅ) ์ ๋ณด๋ฅผ ํจ๊ณผ์ ์ผ๋ก ์์ถํ๋ ๊ฒ์
๋๋ค.
Transformer
๋ ์ ๋ณด์ ์์ถ์ ํ์ง ์๊ณ , ๋ชจ๋ ์ ๋ณด๋ฅผ ์ ์ฅํ์ฌ ์ฒ๋ฆฌํ๋ ๋ฐฉ์์ผ๋ก ์๋ํ์ง๋ง, ์ด๋ก ์ธํด ๋นํจ์จ์ ์ธ ๊ณ์ฐ์ด ๋ฐ์ํฉ๋๋ค.- ๋ฐ๋ฉด,
RNN
๊ณผ ๊ฐ์ ์ฌ๊ท ๋ชจ๋ธ์ ์ ๋ณด๋ฅผ ์์ถํ์ฌ ์ฒ๋ฆฌํ์ง๋ง, ์์ถ๋ ์ ๋ณด๊ฐ ์์ค๋๋ฉด ์ฑ๋ฅ์ด ๋จ์ด์ง ์ ์์ต๋๋ค.
- ์ ํ์ ์์ถ: SSSM(S4)์ ์ปจํ
์คํธ ์ ๋ณด๋ฅผ ์ ํ์ ์ผ๋ก ์์ถํ์ฌ ์ ์ฅํ๊ฑฐ๋ ์์ ์ ์์ต๋๋ค.
- ์ฆ, ์ํ์ค ๋ด์์ ์ค์ํ ์ ๋ณด๋ ๊ธฐ์ตํ๊ณ , ๋ถํ์ํ ์ ๋ณด๋ ๋ฌด์ํ๋ ๋ฐฉ์์ ๋๋ค. ์ด๋ฅผ ํตํด ์ ๋ณด ์์ถ์ ์ต์ ํํ๊ณ , ๋ชจ๋ธ์ด ์ํ์ค ์ ๋ฐ์ ๊ฑธ์ณ ์ค์ํ ์ ๋ณด๋ฅผ ๋์น์ง ์๋๋ก ํฉ๋๋ค.
๋ณธ๋ฌธ์์๋ Figure2
๋ฅผ ํตํด โ์ ๋ณด ์ ํ ๋ฐ ๋ณต์ฌ ์์
โ ๋๋ โ์ ๋ณด ํํฐ๋ง ์์
โ Task๋ค์ ์๊ฐํ๋ฉฐ, ๊ธฐ๋ณธ์ ์ธ LTI ์์คํ
์ ์ ํ์ ๋ฉ์ปค๋์ฆ๊ณผ ๋ฌธ๋งฅ ์ฒ๋ฆฌ ๋ฅ๋ ฅ์ด ํ์ํ Selective Copying Task์ Induction Heads Task์ ์ ํฉํ์ง ์์์ ์ง์ ํฉ๋๋ค.
-
1. Copying Task (Figure2 ์ผ์ชฝ ์ด๋ฏธ์ง)
-
๋ฌธ์
: Copying Task๋ ์ํ์ค ๋ฐ์ดํฐ๋ฅผ ๊ธฐ์ตํ๊ณ , ํน์ ์์น์์ ๋ณต์ฌํ๋ ์์ ์ ๋๋ค.- ์ด๋, ์ ๋ ฅ๊ณผ ์ถ๋ ฅ ๊ฐ์ ๊ฐ๊ฒฉ์ด ์ผ์ ํ๊ฒ ์ ์ง๋ฉ๋๋ค. ๋ชจ๋ธ์ ์ํ์ค์ ์ผ์ ํ ๊ฐ๊ฒฉ์ ์๋ ๋ฐ์ดํฐ๋ฅผ ๊ธฐ์ตํ๊ณ , ๊ทธ๋๋ก ๋ณต์ฌํ๋ฉด ๋ฉ๋๋ค.
-
ํด๊ฒฐ๋ฐฉ๋ฒ
: ์ด ์์ ์ ๋งค์ฐ ๊ฐ๋จํ ํจํด์ด๊ธฐ ๋๋ฌธ์ ์๊ฐ ๋ถ๋ณ ๋ชจ๋ธ(Time-Invariant Model)๋ก ์ฝ๊ฒ ํด๊ฒฐํ ์ ์์ต๋๋ค.- ์๊ฐ ๋ถ๋ณ ๋ชจ๋ธ์ ๋ชจ๋ ์์ ์์ ๋์ผํ ๋ฐฉ์์ผ๋ก ๋ฐ์ดํฐ๋ฅผ ์ฒ๋ฆฌํ๋ ๋ชจ๋ธ๋ก, ์ ํ ์ฌ๊ท ๋ชจ๋ธ(Linear Recurrence Model)์ด๋ ๊ธ๋ก๋ฒ ํฉ์ฑ๊ณฑ ๋ชจ๋ธ(Global Convolution Model) ๊ฐ์ ๋ฐฉ์์ด ์ฌ์ฉ๋ ์ ์์ต๋๋ค.
- ์ด๋ฌํ ๋ชจ๋ธ์ ์ ๋ ฅ ๊ฐ์ ์ผ์ ํ ๊ฐ๊ฒฉ์ ์ธ์ํ๊ณ , ๊ทธ์ ๋ฐ๋ผ ๋ฐ์ดํฐ๋ฅผ ๋ณต์ฌํ๋ ๋ฐ ์ ํฉํฉ๋๋ค.
-
๊ฒฐ๊ณผ
: Copying Task๋ ๊ฐ๊ฒฉ์ด ๊ณ ์ ๋์ด ์์ด, LTI(Linear Time-Invariant) ๋ชจ๋ธ๋ก ์ฝ๊ฒ ์ฒ๋ฆฌํ ์ ์๋ ๋จ์ํ ์์ ์ ๋๋ค.- ๋ชจ๋ธ์ ์๊ฐ ํ๋ฆ์ ๋ฐ๋ผ ๋์ผํ ๋ฐฉ์์ผ๋ก ๋ฐ์ดํฐ๋ฅผ ๋ณต์ฌํ์ฌ ์ด ์์ ์ ์๋ฒฝํ ํด๊ฒฐํ ์ ์์ต๋๋ค.
-
-
2. Selective Copying Task (Figure2 ์ค๋ฅธ์ชฝ ์ ์ด๋ฏธ์ง)
-
๋ฌธ์
: Selective Copying Task๋ Copying Task์ ๋ฌ๋ฆฌ, ์ ๋ ฅ๊ณผ ์ถ๋ ฅ ๊ฐ์ ๊ฐ๊ฒฉ์ด ์ผ์ ํ์ง ์๊ณ ๋๋คํ๊ฒ ๋ณ๋๋ฉ๋๋ค.- ๋ชจ๋ธ์ ์ํ์ค ๋ด์์ ์ค์ํ ์ ๋ณด๋ฅผ ์ ํ์ ์ผ๋ก ๊ธฐ์ตํ๊ณ , ๋๋จธ์ง ๋ถํ์ํ ์ ๋ณด๋ ๋ฌด์ํด์ผ ํฉ๋๋ค.
-
ํด๊ฒฐ๋ฐฉ๋ฒ
: ์ด ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ๋ ค๋ฉด ์๊ฐ ๊ฐ๋ณ ๋ชจ๋ธ(Time-Varying Model)๊ณผ ์ ํ์ ๋ฉ์ปค๋์ฆ(Selection Mechanism)์ด ํ์ํฉ๋๋ค.์ ํ์ ๋ฉ์ปค๋์ฆ
์ ๋ชจ๋ธ์ด ์ ๋ ฅ ๋ฐ์ดํฐ๋ฅผ ๋ถ์ํ์ฌ ์ค์ํ ์ ๋ณด๋ฅผ ์ ํ์ ์ผ๋ก ๊ธฐ์ตํ๊ณ , ๋ถํ์ํ ๋ฐ์ดํฐ๋ ๋ฌด์ํ ์ ์๊ฒ ๋ง๋ญ๋๋ค.- Selective State Space Model(SSSM)๊ณผ ๊ฐ์ ๊ตฌ์กฐ๋ฅผ ์ฌ์ฉํ๋ฉด, ์ ๋ ฅ ์ํ์ค ๋ด์์ ์ด๋ค ์ ๋ณด๊ฐ ์ค์ํ์ง ์ ํํ๊ณ , ์ค์ํ ์ ๋ณด๋ฅผ ์ ํ์ ์ผ๋ก ๊ธฐ์ตํ ์ ์์ต๋๋ค.
-
๊ฒฐ๊ณผ
: Selective Copying Task๋ ์๊ฐ ๊ฐ๋ณ์ ์ฒ๋ฆฌ์ ์ ํ์ ๋ฉ์ปค๋์ฆ์ ํตํด ํด๊ฒฐ๋ฉ๋๋ค.- ์ด๋ก ์ธํด ๋ชจ๋ธ์ ๊ฐ ์์ ์์ ์ค์ํ ๋ฐ์ดํฐ๋ฅผ ์ ํํ๊ณ , ๋ถํ์ํ ์ ๋ณด๋ ๋ฌด์ํ์ฌ ๋ถ๊ท์นํ ์ ๋ ฅ ๊ฐ๊ฒฉ์์๋ ์ ํํ๊ฒ ๋ณต์ฌ ์์ ์ ์ํํ ์ ์์ต๋๋ค.
-
-
3. Induction Heads Task (Figure2 ์ค๋ฅธ์ชฝ ํ๋จ ์ด๋ฏธ์ง)
-
๋ฌธ์
: Induction Heads Task๋ ์ฐ๊ด ๊ธฐ์ต(Associative Recall) ๋ฌธ์ ๋ก, ๋ชจ๋ธ์ด ์ด์ ์ ํ์ต๋ ์ ๋ณด๋ฅผ ๋ฐํ์ผ๋ก ๋ฌธ๋งฅ(Context)์ ์ดํดํ๊ณ , ๋ฌธ๋งฅ์ ๋ง๋ ์ถ๋ ฅ์ ์ ์ถํด์ผ ํฉ๋๋ค.- ์ด ์์ ์์๋ ์ฃผ์ด์ง ์ํ์ค์์ ํน์ ํจํด์ด ์ฃผ์ด์ง ํ, ๋น์ทํ ํจํด์ด ๋ค์ ๋์ฌ ๋ ์ฌ๋ฐ๋ฅธ ์ถ๋ ฅ์ ์์ธกํ ์ ์์ด์ผ ํฉ๋๋ค. ์ด๋ ๋ฌธ๋งฅ ๊ธฐ๋ฐ์ ํ์ต๊ณผ ํ์์ ์๊ตฌํฉ๋๋ค.
-
ํด๊ฒฐ๋ฐฉ๋ฒ
: ์ด ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ๊ธฐ ์ํด์๋ ๋ชจ๋ธ์ด ๋ฌธ๋งฅ์ ํ์ตํ๊ณ ์ฐ๊ด ์ง์ด ๊ธฐ์ตํ ์ ์์ด์ผ ํฉ๋๋ค.- ๋จ์ํ Copying ์์ ๊ณผ ๋ฌ๋ฆฌ, ๋ชจ๋ธ์ ์ด์ ์์ ์ ๋ฌธ๋งฅ์ ์ฐ๊ด ์ง์ด ๊ธฐ์ตํ๊ณ , ํ์ํ ์์ ์์ ์ด๋ฅผ ํ์ํ์ฌ ์ ์ ํ ์ถ๋ ฅ์ ์ ์ถํด์ผ ํฉ๋๋ค.
- ์ ํ์ ๋ฉ์ปค๋์ฆ๊ณผ ํจ๊ป ์ฐ๊ด ๊ธฐ์ต ๋ฉ์ปค๋์ฆ์ ์ฌ์ฉํ๋ฉด, ๋ชจ๋ธ์ด ๋ฌธ๋งฅ์ ๊ธฐ๋ฐ์ผ๋ก ํ์ํ ์ ๋ณด๋ฅผ ๊ธฐ์ตํ๊ณ , ์๋ก์ด ์ ๋ ฅ์ด ๋ค์ด์ฌ ๋ ๊ทธ ์ ๋ณด๋ฅผ ๋ค์ ์ฌ์ฉํ ์ ์์ต๋๋ค.
- Selective State Space Model (SSSM)์ ์ด๋ฌํ ๋ฌธ๋งฅ ๊ธฐ๋ฐ ํ์ต์ ์ ํฉํ ๊ตฌ์กฐ๋ฅผ ์ ๊ณตํฉ๋๋ค.
-
๊ฒฐ๊ณผ
: Induction Heads Task๋ ๋ฌธ๋งฅ ๊ธฐ๋ฐ์ ์ฐ๊ด ๊ธฐ์ต์ด ์ค์ํ ์์ ์ ๋๋ค. ์ ํ์ ๋ฉ์ปค๋์ฆ๊ณผ ์ฐ๊ด ๊ธฐ์ต ๋ฉ์ปค๋์ฆ์ ํตํด ๋ชจ๋ธ์ ๋ฌธ๋งฅ ์ ๋ณด๋ฅผ ํ์ตํ๊ณ , ์๋ก์ด ์ ๋ ฅ๊ณผ ๊ด๋ จ๋ ํจํด์ ์ฐ๊ด ์ง์ด ํ์ํ ์ ์๊ฒ ๋ฉ๋๋ค. ์ด๋ ๋ฌธ๋งฅ์ ๋ง๋ ์์ธก์ ๊ฐ๋ฅํ๊ฒ ํด์ค๋๋ค.
-
์ ๋ฆฌ
Copying Task
: ์ผ์ ํ ๊ฐ๊ฒฉ์ ์ํ์ค ๋ฐ์ดํฐ๋ฅผ ๋จ์ํ ๋ณต์ฌํ๋ ์์ ์ผ๋ก, ์๊ฐ ๋ถ๋ณ ๋ชจ๋ธ(LTI)์ ํตํด ์ฝ๊ฒ ํด๊ฒฐ๋ฉ๋๋ค. ์ด ์์ ์ ๊ณ ์ ๋ ๊ตฌ์กฐ์ LTI ๋ชจ๋ธ์ด ์๊ฐ ์ธ์์ ํ์๋ก ํ์ง๋ง, ์ ๋ณด ์ ํ์ด ๋จ์ํ๊ธฐ ๋๋ฌธ์ ์ฒ๋ฆฌํ ์ ์์ต๋๋ค.Selective Copying Task
: ์ ๋ ฅ๊ณผ ์ถ๋ ฅ ๊ฐ์ ๊ฐ๊ฒฉ์ด ๋๋คํ๊ฒ ๋ณ๋ํ๋ ์์ ์ผ๋ก, ์๊ฐ ๊ฐ๋ณ ๋ชจ๋ธ๊ณผ ์ ํ์ ๋ฉ์ปค๋์ฆ์ ํตํด ์ค์ํ ์ ๋ณด๋ง ์ ํ์ ์ผ๋ก ๊ธฐ์ตํจ์ผ๋ก์จ ํด๊ฒฐ๋ฉ๋๋ค. ์ด ์์ ์ ๋ด์ฉ ์ธ์์ด ํ์์ ์ด๋ฏ๋ก, ๊ธฐ์กด LTI ๋ชจ๋ธ์์๋ ํจ๊ณผ์ ์ผ๋ก ์ฒ๋ฆฌํ๊ธฐ ์ด๋ ต์ต๋๋ค.Induction Heads Task
: ์ฐ๊ด ๊ธฐ์ต ๋ฌธ์ ๋ก, ๋ฌธ๋งฅ ์ ๋ณด๋ฅผ ํ์ตํ๊ณ ์ฐ๊ด ์ง์ด ํ์ํ๋ ๋ฅ๋ ฅ์ด ํ์ํฉ๋๋ค. ์ ํ์ ๋ฉ์ปค๋์ฆ๊ณผ ๋ฌธ๋งฅ ๊ธฐ๋ฐ ๊ธฐ์ต์ ํตํด ํด๊ฒฐํ ์ ์์ต๋๋ค. ์ด ์์ ์ ๋ณต์กํ ์๊ด๊ด๊ณ๋ฅผ ์ดํดํ๊ณ ๊ธฐ์ตํ๋ ๊ฒ์ด ์ค์ํ์ฌ, ๋ณด๋ค ๋ฐ์ ๋ ๋ชจ๋ธ ๊ตฌ์กฐ๊ฐ ์๊ตฌ๋ฉ๋๋ค.
3.2 Improving SSMs with Selection (์ ํ์ ํตํ SSM ์ฑ๋ฅ ํฅ์)
์ด ์น์
์์๋ ์ ํ ๋ฉ์ปค๋์ฆ์ ๊ธฐ๋ณธ SSM ๊ตฌ์กฐ์ ํตํฉํ์ฌ, ๋ชจ๋ธ์ ์ฑ๋ฅ์ ํฅ์์ํค๋ ๋ฐฉ๋ฒ์ ์ค๋ช
ํฉ๋๋ค. SSM (S4)
๋ ๊ณ ์ ๋ ํ๋ผ๋ฏธํฐ๋ฅผ ์ฌ์ฉํ์ฌ ๊ฐ๋จํ ๊ตฌ์กฐ๋ก ์๋ํ๋ ๋ฐ๋ฉด, SSM + Selection (S6)
๋ ์
๋ ฅ ๋ฐ์ดํฐ๋ฅผ ๋์ ์ผ๋ก ์ฒ๋ฆฌํ์ฌ ์ ํ์ ์ผ๋ก ์ ๋ณด๋ฅผ ๊ฐ์กฐํ๊ฑฐ๋ ๋ฌด์ํ ์ ์๋ ๋ณด๋ค ๋ณต์กํ ๊ตฌ์กฐ์
๋๋ค.
-
์ ํ ๋ฉ์ปค๋์ฆ: SSSM(S4+Selection, S6)์ SSM์ ์ฃผ์ ๋งค๊ฐ๋ณ์(ฮ,B,C\Delta, B, Cฮ,B,C)๋ฅผ ์ ๋ ฅ์ ๋ฐ๋ผ ์ ํ์ ์ผ๋ก ๋ณ๋์ํด์ผ๋ก์จ ๋ฐ์ดํฐ๋ฅผ ์ฒ๋ฆฌํฉ๋๋ค.
- ์ด๋ฅผ ํตํด ๋ชจ๋ธ์ ์ํ์ค์ ์ค์ํ ๋ถ๋ถ์ ์ ํ์ ์ผ๋ก ๊ธฐ์ตํ๊ณ , ๋ถํ์ํ ๋ถ๋ถ์ ๋ฌด์ํ ์ ์์ต๋๋ค.
-
์๊ฐ ๋ถ๋ณ์ฑ์ ํฌ๊ธฐํ๊ณ ํจ์จ์ฑ ๊ทน๋ํ: ์ด ์ ํ ๋ฉ์ปค๋์ฆ์ ์๊ฐ ๋ถ๋ณ์ฑ์ ์ ์งํ์ง ์๊ธฐ ๋๋ฌธ์, ์๊ฐ์ ๋ฐ๋ผ ๋์ ์ผ๋ก ๋ณํํ๋ ์ํ์ค๋ฅผ ์ฒ๋ฆฌํ ์ ์์ง๋ง, ์ด๋ฅผ ํจ์จ์ ์ผ๋ก ๊ตฌํํ๊ธฐ ์ํด์๋ ์ถ๊ฐ์ ์ธ ์๊ณ ๋ฆฌ์ฆ์ด ํ์ํฉ๋๋ค.
๊ธฐ๋ณธ SSM (S4)
๊ณผ SSM + Selection (S6)
์ ์๊ณ ๋ฆฌ์ฆ์ ํ๋ฒ ๋น๊ตํ๋ฉด์ ์ฐจ์ด์ ์ ์ดํด๋ณด๊ฒ ์ต๋๋ค.
๊ธฐ๋ณธ SSM ๊ตฌ์กฐ (Algorithm 1: SSM (S4))
-
์ ๋ ฅ:
- xxx: ์ ๋ ฅ ๋ฐ์ดํฐ๋ก, ํํ๋ (B, L, D)์ ๋๋ค.
- ์ฌ๊ธฐ์ B๋ ๋ฐฐ์น ํฌ๊ธฐ(batch size), L์ ์ํ์ค ๊ธธ์ด, D๋ ์ฑ๋ ์๋ฅผ ๋ํ๋ ๋๋ค.
-
์ถ๋ ฅ:
- yyy: ์ถ๋ ฅ ๋ฐ์ดํฐ๋ก, ์ ๋ ฅ๊ณผ ๊ฐ์ ํํ๋ฅผ ๊ฐ์ต๋๋ค. ์ฆ, (B, L, D)์ ๋๋ค.
-
๋งค๊ฐ๋ณ์:
- AAA, BBB, CCC: ์ด ์ธ ๊ฐ์ ๋งค๊ฐ๋ณ์๋ SSM์ ํต์ฌ ํ๋ผ๋ฏธํฐ๋ก, ๊ฐ๊ฐ ์ ์ฌ ์ํ์ ์ ๋ ฅ ๋ฐ์ดํฐ๋ฅผ ๋ณํํ๋ ์ญํ ์ ํฉ๋๋ค.
- ฮ\Deltaฮ: ์๊ฐ ์ค์ผ์ผ์ ์กฐ์ ํ๋ ๋งค๊ฐ๋ณ์๋ก, ์ํ ๊ณต๊ฐ ๋ชจ๋ธ์์ ์๊ฐ์ ๋ณํ์ ๊ด๋ จ๋ ์ญํ ์ ํฉ๋๋ค.
-
์๋ ๋ฐฉ์:
- AAA, BBB, CCC ํ๋ผ๋ฏธํฐ๊ฐ ์ค์ ๋ฉ๋๋ค.
- ฮ\Deltaฮ ๊ฐ์ด ํ๋ผ๋ฏธํฐ๋ก ์ค์ ๋ฉ๋๋ค.
- ์ฃผ์ด์ง ฮ\Deltaฮ์ AAA, BBB, CCC ๊ฐ๋ค์ ๊ธฐ๋ฐ์ผ๋ก ์ด์ฐํ(discretization)๊ฐ ์ด๋ฃจ์ด์ง๋๋ค.
- ์ํ ๊ณต๊ฐ ๋ชจ๋ธ์ด ์๊ฐ ๋ถ๋ณ(Time-invariant)์ธ ์ฌ๊ท(recursion) ๋๋ ํฉ์ฑ๊ณฑ(convolution)์ ํตํด ๊ณ์ฐ๋ฉ๋๋ค.
๐ก ํต์ฌ ํน์ง: ์ด ์๊ณ ๋ฆฌ์ฆ์ ์๊ฐ ๋ถ๋ณ์ ๊ตฌ์กฐ๋ก, ์ ๋ ฅ ๋ฐ์ดํฐ๋ฅผ ๊ณ ์ ๋ ๋ฐฉ์์ผ๋ก ์ฒ๋ฆฌํฉ๋๋ค. ์ด๋ ๋์ผํ ๋งค๊ฐ๋ณ์๋ฅผ ๋ชจ๋ ์์ ์ ์ ์ฉํ๋ค๋ ์๋ฏธ์ ๋๋ค.
โฑ ๊ณ์ฐ ๋ฐฉ์:
์ ํ ์ฌ๊ท
๋๋ํฉ์ฑ๊ณฑ ์ฐ์ฐ
์ ์ฌ์ฉํ์ฌ ์๊ฐ ๋ถ๋ณ์ ์ฒ๋ฆฌ๋ง ๊ฐ๋ฅํฉ๋๋ค.
๐ (์ฌํ) S4 (SSM) ์๊ณ ๋ฆฌ์ฆ์ ํ ์ ์ฐ์ฐ ๋ฐ ์ฐจ์ ๋ณํ
- S4๋ ๊ณ ์ ๋ ํ๋ผ๋ฏธํฐ์ ์ด์ฐํ ๋ฐฉ์์ ์ฌ์ฉํ์ฌ ๋ชจ๋ ์ํ์ค์ ๋์ผํ ์ฐ์ฐ์ ์ ์ฉํ๊ณ , ๊ทธ ๊ฒฐ๊ณผ๋ก ์ผ์ ํ recurrence ๋๋ convolution์ ์ํํฉ๋๋ค.
-
์ ๋ ฅ ํ ์ (x):
(B, L, D)
์ ํํ๋ฅผ ๊ฐ์ง๋ฉฐ, B๋ ๋ฐฐ์น ํฌ๊ธฐ, L์ ์ํ์ค ๊ธธ์ด, D๋ ๊ฐ ํ ํฐ์ ์ฐจ์์ ๋ํ๋ ๋๋ค.- ์ฆ,
B
๊ฐ์ ์ํ์ ๋ํดL
๊ฐ์ ์ํ์ค๋ฅผ ์ฒ๋ฆฌํ๋ฉฐ, ๊ฐ ์ํ์ค๋D
์ฐจ์์ ๋ฒกํฐ๋ก ํํ๋ฉ๋๋ค.
- ์ฆ,
-
S4์์๋
(D, N)
์ ๊ณ ์ ๋ ํ๋ผ๋ฏธํฐ๊ฐ ๋ชจ๋ ์ํ์ค์ ๋์ผํ๊ฒ ์ ์ฉ๋ฉ๋๋ค. -
ํ๋ผ๋ฏธํฐ A, B, C:
A
,B
,C
๋ ๋ชจ๋(D, N)
ํํ๋ก ์กด์ฌํ๋ฉฐ, ์ฌ๊ธฐ์ D๋ ์ ๋ ฅ ์ฐจ์, N์ ์จ๊ฒจ์ง ์ฐจ์(hidden state)์ ํฌ๊ธฐ์ ๋๋ค.A
: Structured NรNN \times NNรN ๋งคํธ๋ฆญ์ค๋ก ์ฐ์ฐ์ ๋ด๋นํฉ๋๋ค.B
:(D, N)
ํฌ๊ธฐ์ ํ๋ผ๋ฏธํฐ๋ก์, ์ ๋ ฅ ํ ์์ ๊ณฑํด์ ธ ์๋ก์ด ์ํ(state)๋ฅผ ์์ฑํ๋ ์ญํ ์ ํฉ๋๋ค.C
: ์ญ์(D, N)
ํฌ๊ธฐ๋ฅผ ๊ฐ์ง๋ฉฐ ์ถ๋ ฅ ํ ์ ์์ฑ์ ์ํฅ์ ๋ฏธ์นฉ๋๋ค.
- ์ด์ฐํ (discretization): ์ฐ์์ ์์คํ
์ ์ด์ฐํํ์ฌ ์ฐ์ฐ์ ์ํํ๋๋ฐ, ์ด๋ ์ฌ์ฉํ๋
ฮ
๋(D)
ํฌ๊ธฐ์ ํ๋ผ๋ฏธํฐ์ ๋๋ค. ์ด ํ๋ผ๋ฏธํฐ๋ ์๊ฐ ๊ฐ๊ฒฉ์ ์ด์ฐํํ์ฌ A, B ๋งคํธ๋ฆญ์ค์ ๊ฐ์ ๋ณํํฉ๋๋ค. - S4์์๋ ๊ณ ์ ๋ ฮ๊ฐ ์ฌ์ฉ๋ฉ๋๋ค. ์ด์ฐํ๋ ฮ๋ ๊ฐ ์ํ์ค์ ๋ํด ๊ฐ๊ฐ์ ๋งคํธ๋ฆญ์ค A, B์ ๊ณฑํด์ ธ hth_thtโ๋ฅผ ์ ๋ฐ์ดํธํฉ๋๋ค.
- ์ต์ข
์ถ๋ ฅ y:
(B, L, D)
ํฌ๊ธฐ์ ๊ฒฐ๊ณผ๋ฅผ ๋ฐํํ๋ฉฐ, ์ด๋ time-invariant ๋ฐฉ์์ผ๋ก recurrence๋ convolution ์ฐ์ฐ์ด ์ ์ฉ๋ฉ๋๋ค.
โ (๊ฒฐ๋ก ) S4๋ ๊ณ ์ ๋ ํ๋ผ๋ฏธํฐ์ ์ด์ฐํ ๋ฐฉ์์ ์ฌ์ฉํ์ฌ ๋ชจ๋ ์ํ์ค์ ๋์ผํ ์ฐ์ฐ์ ์ ์ฉํ๊ณ , ๊ทธ ๊ฒฐ๊ณผ๋ก ์ผ์ ํ recurrence ๋๋ convolution์ ์ํํฉ๋๋ค.
์ ํ์ SSM ๊ตฌ์กฐ (Algorithm 2: SSM + Selection (S6))
-
์ ๋ ฅ ๋ฐ ์ถ๋ ฅ:
- S4์ ๋์ผํ๊ฒ, ์ ๋ ฅ ๋ฐ์ดํฐ๋ xxx๋ก (B, L, D)์ ํํ๋ฅผ ๊ฐ์ง๋ฉฐ, ์ถ๋ ฅ ๋ฐ์ดํฐ๋ ๊ฐ์ (B, L, D) ๊ตฌ์กฐ์ ๋๋ค.
-
์ฃผ์ ์ฐจ์ด์ :
- ์ ํ ๋ฉ์ปค๋์ฆ ์ ์ฉ: ์
๋ ฅ ๋ฐ์ดํฐ์ ๋ฐ๋ผ ํ๋ผ๋ฏธํฐ๊ฐ ๋ณํํฉ๋๋ค.
- ์ฆ, S6์์๋ ์ ๋ ฅ ์์กด์ ์ธ ์ ํ(selectivity)์ด ์ถ๊ฐ๋์ด ์์ ์ ๋ฐ๋ผ ๋ฌ๋ผ์ง๋ ๋ฐฉ์์ผ๋ก ์ฒ๋ฆฌ๋ฉ๋๋ค.
- ์ ํ์ SSM์์๋ BtB_tBtโ, CtC_tCtโ, ฮt\Delta_tฮtโ ํ๋ผ๋ฏธํฐ๊ฐ ์ ๋ ฅ์ ๋ฐ๋ผ ๋์ ์ผ๋ก ๋ณํํฉ๋๋ค
- ์ ํ ๋ฉ์ปค๋์ฆ ์ ์ฉ: ์
๋ ฅ ๋ฐ์ดํฐ์ ๋ฐ๋ผ ํ๋ผ๋ฏธํฐ๊ฐ ๋ณํํฉ๋๋ค.
-
๋งค๊ฐ๋ณ์ ๋ณํ:
- BBB, CCC: S4์์๋ ๊ณ ์ ๋ ๋งค๊ฐ๋ณ์์์ผ๋, S6์์๋ sB(x)s_B(x)sBโ(x), sC(x)s_C(x)sCโ(x)์ ๊ฐ์ ํจ์๋ก ์ ๋ ฅ xxx์ ๋ฐ๋ผ ๋ณํํฉ๋๋ค.
- ฮ\Deltaฮ: S4์์๋ ๊ณ ์ ๋ ๊ฐ์ด์์ผ๋, S6์์๋ sฮ(x)s_{\Delta}(x)sฮโ(x)๋ฅผ ํตํด ์ ๋ ฅ์ ๋ฐ๋ผ ๋ณํํฉ๋๋ค. ์ด๋ ๋ชจ๋ธ์ด ์์ ์ ๋ฐ๋ผ ๊ฐ๋ณ์ ์ธ ์๊ฐ ์ค์ผ์ผ์ ์ ์ฉํ ์ ์๊ฒ ๋ง๋ญ๋๋ค.
-
์๋ ๋ฐฉ์:
- ๋งค๊ฐ๋ณ์ AAA, BBB, CCC์ ์๊ฐ ์ค์ผ์ผ ฮ\Deltaฮ๋ ์ ๋ ฅ ๋ฐ์ดํฐ xxx์ ๋ฐ๋ผ ๋ณํํฉ๋๋ค.
- ๊ฐ ์์ ์์ ์ฌ๊ท์ ๊ณ์ฐ(recurrence)๋ง ์ํ๋๋ฉฐ, ์ด๋ ์๊ฐ์ ๋ฐ๋ผ ๋ณํํ๋ time-varying ๋ชจ๋ธ์ ๋๋ค.
๐ก ํต์ฌ ํน์ง: S6๋ ์๊ฐ ๊ฐ๋ณ(time-varying) ๊ตฌ์กฐ๋ก, ์ ๋ ฅ์ ๋ฐ๋ผ ๋งค๋ฒ ๋ค๋ฅธ ๋ฐฉ์์ผ๋ก ๋ฐ์ดํฐ๋ฅผ ์ฒ๋ฆฌํ ์ ์์ต๋๋ค. ์ด ์ ํ ๋ฉ์ปค๋์ฆ์ ์ค์ํ ์ ๋ณด๋ ๊ธฐ์ตํ๊ณ , ๋ถํ์ํ ์ ๋ณด๋ ๋ฌด์ํ ์ ์๋๋ก ๋ง๋ญ๋๋ค.
โฑ ๊ณ์ฐ ๋ฐฉ์: ์๊ฐ ๊ฐ๋ณ์ ์ด๊ธฐ ๋๋ฌธ์ ์ฌ๊ท์ ์ฐ์ฐ๋ง์ ์ํํ๋ฉฐ, ์ํ์ค ์ ๋ฐ์์ ์ค์ํ ์ ๋ณด๋ฅผ ์ ํ์ ์ผ๋ก ์ฒ๋ฆฌํ ์ ์์ต๋๋ค.
๐ (์ฌํ) S6 (SSM + Selection) ์๊ณ ๋ฆฌ์ฆ์ ํ ์ ์ฐ์ฐ ๋ฐ ์ฐจ์ ๋ณํ
- S6๋ ์ ๋ ฅ ์ข ์์ ์ธ ํ๋ผ๋ฏธํฐ์ ์๊ฐ ๋ณ์ด์ ๋ฐ๋ฅธ ์ฐ์ฐ์ ์ฌ์ฉํ์ฌ ์ํ์ค๋ง๋ค ๋ค๋ฅธ ์ฐ์ฐ์ด ์ด๋ฃจ์ด์ง๋ฉฐ, ์ด๋ ๋์ ๋ชจ๋ธ๋ง์ ๋ ์ ํฉํฉ๋๋ค.
- ์
๋ ฅ ํ
์ (x):
(B, L, D)
๋ก ๋์ผํ์ง๋ง, ์ ๋ ฅ ๊ฐ์ ๋ฐ๋ผ ๋ค์ํ ๋งคํธ๋ฆญ์ค ์ฐ์ฐ์ด ์ด๋ฃจ์ด์ง๋๋ค. -
S6์์๋ ์ ๋ ฅ
x
์ ๋ฐ๋ผ(B, L, N)
์ฐจ์์ ํ๋ผ๋ฏธํฐ๋ค์ด ๊ฐ ์ํ์ค๋ณ๋ก ๋ค๋ฅด๊ฒ ์์ฑ๋ฉ๋๋ค. ์ด๋ฅผ ํตํด ๊ฐ ์ํ์ค๋ง๋ค ๋ค๋ฅธ ๋งคํ์ด ์ผ์ด๋ฉ๋๋ค. -
ํ๋ผ๋ฏธํฐ ๋ณํ (sB, sC, sฮ): S6์์๋ ํ๋ผ๋ฏธํฐ๋ค์ด ์ ๋ ฅ์ ์ข ์์ ์ผ๋ก ๋ณํ๋ฉ๋๋ค.
sB(x)
: ์ ๋ ฅx
์ ๋ฐ๋ผ(B, L, N)
์ฐจ์์ผ๋ก ๋ณํ๋ฉ๋๋ค. ์ฆ, ๊ฐ ๋ฐฐ์น B์ ๊ฐ ์ํ์ค L์ ๋ํด, ์จ๊ฒจ์ง ์ฐจ์ N์ ์์ฑํฉ๋๋ค. ์ด๋ ๊ธฐ์กด S4์์ ๋ชจ๋ ์ํ์ค๊ฐ ๋์ผํ B ๋งคํธ๋ฆญ์ค๋ฅผ ์ฌ์ฉํ๋ ๊ฒ๊ณผ ๋ฌ๋ฆฌ, ์ด์ ๋ ๊ฐ ์ํ์ค๋ง๋ค ๋ค๋ฅธ B๊ฐ ์ ์ฉ๋ฉ๋๋ค.sC(x)
: ์ญ์ ์ ๋ ฅ์ ๋ฐ๋ผ(B, L, N)
์ฐจ์์ผ๋ก ๋ณํ๋ฉ๋๋ค.sฮ(x)
: ์ ๋ ฅ์ ๋ฐ๋ผ ์๊ฐ ๊ฐ๊ฒฉ์ ๋ํ๋ด๋ฮ
์ญ์(B, L, D)
๋ก ๋ณํ๋ฉ๋๋ค. ์ฆ, ๊ฐ ๋ฐฐ์น์ ๊ฐ ์ํ์ค๋ง๋ค ๋ค๋ฅธ ฮ ๊ฐ์ด ์ฃผ์ด์ง๋๋ค.
- ์ด์ฐํ ์ฐ์ฐ (discretization): S4์ ์ ์ฌํ๊ฒ ์ด์ฐํ๋ฅผ ํตํด ์ฐ์ฐ์ด ์ด๋ฃจ์ด์ง์ง๋ง, ์ฌ๊ธฐ์๋ ์ ๋ ฅ์ ์ข ์์ ์ผ๋ก ๋ณํ๋ ฮ๊ฐ ์ฌ์ฉ๋๋ฏ๋ก ๋ ๋ณต์กํ ํํ์ ์ฐ์ฐ์ด ๋ฐ์ํฉ๋๋ค.
-
S6์์๋
sฮ(x)
๊ฐ ๊ฐ ์ํ์ค๋ณ๋ก ๋ค๋ฅด๊ฒ ๊ณ์ฐ๋ฉ๋๋ค. ์ด์ฐํ๋ ฮ๋ ๊ฐ ์ํ์ค์ ๋ํด ๊ฐ๊ฐ์ ๋งคํธ๋ฆญ์ค A, B์ ๊ณฑํด์ ธ hth_thtโ๋ฅผ ์ ๋ฐ์ดํธํฉ๋๋ค.
โ๏ธ ์ด๋ ๊ฐ ์ํ์ค๋ ๊ณ ์ ํ ฮ๋ฅผ ๊ฐ์ง๊ณ ์๊ธฐ ๋๋ฌธ์ S6์์๋ ํ ํฐ๋ง๋ค ์๋ก ๋ค๋ฅธ ์ฐ์ฐ์ด ์ ์ฉ๋ฉ๋๋ค.
- ์ต์ข
์ถ๋ ฅ y: S6์ ์ต์ข
์ถ๋ ฅ ์ญ์
(B, L, D)
์ฐจ์์ ๊ฐ์ง์ง๋ง, S4์ ๋ฌ๋ฆฌ ์ ๋ ฅ์ ๋ฐ๋ผ ๋์ ์ผ๋ก ๋ณํํ recurrence ๋ฐฉ์์ ์ฌ์ฉํฉ๋๋ค. ํนํ, ๊ฐ ์ํ์ค๋ง๋ค ๋ค๋ฅด๊ฒ ์ด์ฐํ๋ ํ๋ผ๋ฏธํฐ๊ฐ ์ ์ฉ๋๊ธฐ ๋๋ฌธ์ ๊ฐ ํ ํฐ์ ๋ง๋ ์ฐ์ฐ์ด ์ํ๋ฉ๋๋ค.
โ (๊ฒฐ๋ก ) S6๋ ์ ๋ ฅ ์ข ์์ ์ธ ํ๋ผ๋ฏธํฐ์ ์๊ฐ ๋ณ์ด์ ๋ฐ๋ฅธ ์ฐ์ฐ์ ์ฌ์ฉํ์ฌ ์ํ์ค๋ง๋ค ๋ค๋ฅธ ์ฐ์ฐ์ด ์ด๋ฃจ์ด์ง๋ฉฐ, ์ด๋ ๋์ ๋ชจ๋ธ๋ง์ ๋ ์ ํฉํฉ๋๋ค.
๐ฌ Time In-Variant? ๋ฌด์จ ๋ป์ด์ง?
์๊ฐ ๋ถ๋ณ์ ์ฒ๋ฆฌ
๋ผ๋ ๊ฐ๋ ์ โ๋ชจ๋ธ์ ๋งค๊ฐ๋ณ์(๊ฐ์ค์น)๊ฐ ์๊ฐ์ ๋ฐ๋ผ ๋ณํ์ง ์๋๋ค๋ ๊ฒโ์ ์๋ฏธํฉ๋๋ค.
- ์ฆ, ๋ชจ๋ธ์ด ์ ๋ ฅ ์ํ์ค์ ๊ฐ ์์ (t)์ ๋ํด ๋์ผํ ๋งค๊ฐ๋ณ์๋ฅผ ์ฌ์ฉํ๋ค๋ ๊ฒ์ ์๋ฏธํฉ๋๋ค.
- ์ด๋ ๋ค์๊ณผ ๊ฐ์ ์๋ฏธ๋ฅผ ๊ฐ์ง๋๋ค:
์ ์ ์ธ ๊ฐ์ค์น
: ๊ธฐ์กด SSM ๋ชจ๋ธ์์ ์ฌ์ฉ๋๋ ๋งค๊ฐ๋ณ์ A, B, C ๋ฑ์ ์์ ๋ง๋ค ๊ณ ์ ๋์ด ์์ต๋๋ค. ๋ฐ๋ผ์, ๊ฐ์ ์ ๋ ฅ์ ๋ํด์๋ ํญ์ ๊ฐ์ ์ถ๋ ฅ์ ์์ฑํฉ๋๋ค.
- ์๋ฅผ ๋ค์ด, ๊ณผ๊ฑฐ์ ์ ๋ ฅ์ด ๋ฏธ๋์ ์ถ๋ ฅ์ ์ํฅ์ ๋ฏธ์น ๋, ์ถ๋ ฅ์ ์ํฅ์ ์ฃผ๋ ๊ฐ์ค์น๊ฐ ๋ณํ์ง ์๊ธฐ ๋๋ฌธ์ ํน์ ์ ๋ ฅ์ ๋ํด ์ ํฉํ๊ฒ ์กฐ์ ๋์ง ์์ต๋๋ค.
์ ๋ ฅ ์์กด์ฑ ๋ถ์กฑ
: ์ ๋ ฅ ๋ฐ์ดํฐ์ ํน์ฑ์ ๋ฐ๋ผ ๋ชจ๋ธ์ด ๋์ ์ผ๋ก ๋ฐ์ํ์ง ๋ชปํฉ๋๋ค.
- ์๋ฅผ ๋ค์ด, ์ด๋ค ํน์ ์ ๋ ฅ์ด ๋งค์ฐ ์ค์ํ ๋ ๊ทธ ์ ๋ ฅ์ ๋ํ ๋ฐ์์ ๊ฐํํ๊ฑฐ๋, ๋ฐ๋๋ก ๋ ์ค์ํ ์ ๋ณด๋ ๋ฌด์ํ๋ ๋ฐฉ๋ฒ์ด ์์ต๋๋ค. ๊ทธ๋์, ์ ๋ ฅ ์ํ์ค์ ๋งฅ๋ฝ์ด๋ ์ค์ํ ์ ๋ณด์ ๋ฐ๋ผ ๋ชจ๋ธ์ด ํ์ต๋ ํ๋์ ๋ณํ์ํฌ ์ ์๋ ๋ฅ๋ ฅ์ด ์ ํ๋ฉ๋๋ค.
๐ S4์ S6์ ์ฐจ์ด์ ๋ถ์
- ์ ๋ ฅ ์์กด์ฑ
- S4: ๊ณ ์ ๋ ํ๋ผ๋ฏธํฐ๋ฅผ ์ฌ์ฉํ์ฌ, ๋ชจ๋ ์์ ์์ ๋์ผํ ๊ณ์ฐ์ ์ํํฉ๋๋ค. ์ฆ, ๋ชจ๋ ์์ ์์ ๋์ผํ ๋ฐฉ์์ผ๋ก ์ ๋ ฅ ๋ฐ์ดํฐ๋ฅผ ์ฒ๋ฆฌํฉ๋๋ค.
- S6: ์ ํ ๋ฉ์ปค๋์ฆ์ ํตํด ์ ๋ ฅ ๋ฐ์ดํฐ xxx์ ๋ฐ๋ผ ๋งค๊ฐ๋ณ์๋ค์ด ๋์ ์ผ๋ก ๋ณํํฉ๋๋ค. ์ด๋ ๋ฐ์ดํฐ์ ํน์ฑ์ ๋ฐ๋ผ ๊ฐ ์์ ์์ ํ์ํ ์ ๋ณด๋ฅผ ์ ํ์ ์ผ๋ก ์ฒ๋ฆฌํ ์ ์๊ฒ ๋ง๋ญ๋๋ค.
- ์๊ฐ ๋ถ๋ณ์ฑ(Time-invariant) vs ์๊ฐ ๊ฐ๋ณ์ฑ(Time-varying)
- S4: ์๊ฐ ๋ถ๋ณ์ ์ธ ๊ตฌ์กฐ๋ก, ๋์ผํ ํ๋ผ๋ฏธํฐ๊ฐ ๋ชจ๋ ์์ ์ ์ ์ฉ๋ฉ๋๋ค. ์ด๋ ์ฃผ๋ก ํฉ์ฑ๊ณฑ(convolution)์ด๋ ์ฌ๊ท(recursion) ํํ๋ก ๋ฐ์ดํฐ๋ฅผ ์ฒ๋ฆฌํฉ๋๋ค.
- S6: ์๊ฐ ๊ฐ๋ณ์ ์ธ ๊ตฌ์กฐ๋ก, ์ ๋ ฅ์ ๋ฐ๋ผ ๋งค๊ฐ๋ณ์๋ค์ด ๋ณํํ๊ณ , ์ฌ๊ท์ ๋ฐฉ์์ผ๋ก ๊ณ์ฐ์ด ์ด๋ฃจ์ด์ง๋๋ค. ์ด๋ฅผ ํตํด ์ํ์ค์ ๊ฐ ์์ ์์ ์ค์ํ ์ ๋ณด๋ ๊ธฐ์ตํ๊ณ , ๋ถํ์ํ ์ ๋ณด๋ ๋ฌด์ํ๋ ์ ํ์ ์ฒ๋ฆฌ๊ฐ ๊ฐ๋ฅํฉ๋๋ค.
- ํจ์จ์ฑ
- S4: ์๊ฐ ๋ถ๋ณ์ฑ์ ์ ์งํ๋ SSM์ ๊ณ์ฐ์ ๋ณ๋ ฌํ๊ฐ ๊ฐ๋ฅํ์ฌ, ๋น๊ต์ ํจ์จ์ ์ธ ๊ณ์ฐ์ ์ํํ ์ ์์ต๋๋ค.
- S6: ์ ํ ๋ฉ์ปค๋์ฆ์ ์ถ๊ฐํจ์ผ๋ก์จ ๋ ๋ง์ ๊ณ์ฐ์ด ํ์ํ ์ ์์ง๋ง, GPU์ ๋ฉ๋ชจ๋ฆฌ ๊ณ์ธต์ ํ์ฉํ ํ๋์จ์ด ์ต์ ํ๊ฐ ๊ฐ๋ฅํด์ ธ ํจ์จ์ฑ์ ์ ์งํฉ๋๋ค. ํนํ, ์ ๋ ฅ ๋ฐ์ดํฐ์ ํน์ฑ์ ๋ฐ๋ผ ๋์ ์ธ ๊ณ์ฐ์ ์ํํ์ฌ ๋ ๋์ ์ฑ๋ฅ์ ๋ผ ์ ์์ต๋๋ค.
3.3 Efficient Implementation of Selective SSMs (ํจ์จ์ ์ธ ์ ํ์ SSM ๊ตฌํ)
์ด ์น์ ์์๋ Selective State Space Model(SSSM)์ ํ๋์จ์ด์์ ํจ์จ์ ์ผ๋ก ๊ตฌํํ๋ ๋ฐฉ๋ฒ์ ๋ค๋ฃน๋๋ค. ํนํ, GPU ๋ฑ์ ํ์ฉํ์ฌ ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๊ณผ ๊ณ์ฐ์ ์ต์ ํํ๋ ๋ฐฉ๋ฒ์ ๋ํด ์ค๋ช ํ๊ณ ์์ต๋๋ค.
3.3.1 Motivation of Prior Models (์ด์ ๋ชจ๋ธ๋ค์ ๋๊ธฐ)
์ด ํญ๋ชฉ์์๋ ์ ํ์ ์ํ ๊ณต๊ฐ ๋ชจ๋ธ(SSSM)์ด ๋์ค๊ธฐ ์ , ๊ธฐ์กด ์ํ ๊ณต๊ฐ ๋ชจ๋ธ(SSM)์ด ์ด๋ป๊ฒ ๋์ํ๋์ง, ๊ทธ๋ฆฌ๊ณ ์ ๊ฐ์ ์ด ํ์ํ๋์ง๋ฅผ ์ค๋ช ํฉ๋๋ค.
-
1. ๊ธฐ์กด ์ํ ๊ณต๊ฐ ๋ชจ๋ธ์ ๋์ ์๋ฆฌ
- SSM(Structured State Space Model)์ ์ํ์ค ๋ฐ์ดํฐ๋ฅผ ์ฒ๋ฆฌํ๊ธฐ ์ํด ์ ์ฌ ์ํ(latent state)๋ฅผ ์ฌ์ฉํ๋ ๋ชจ๋ธ์ ๋๋ค. ์ด ๋ชจ๋ธ์ ์ํ์ค ๋ด์ ์ ๋ณด๋ฅผ ์ฌ๊ท์ ์ผ๋ก ์ฒ๋ฆฌํ์ฌ, ์์ ๊ฐ์ ์์กด์ฑ์ ์ ์งํ๋ฉด์ ๊ธด ์ํ์ค ๋ฐ์ดํฐ๋ฅผ ํจ์จ์ ์ผ๋ก ์ฒ๋ฆฌํ ์ ์์ต๋๋ค.
- SSM์ ์๊ฐ ๋ถ๋ณ์ (time-invariant)์ผ๋ก ์ค๊ณ๋์ด, ๊ฐ ์์ ์์ ๋์ผํ ๋ฐฉ์์ผ๋ก ๋ฐ์ดํฐ๋ฅผ ์ฒ๋ฆฌํฉ๋๋ค. ์ด๋ ์ ํ ์ฌ๊ท์ ๊ตฌ์กฐ(linear recurrence)๋ ํฉ์ฑ๊ณฑ ์ฐ์ฐ(convolution)์ ์ฌ์ฉํ์ฌ ๊ณ์ฐ๋ฉ๋๋ค.
-
2. ๊ธฐ์กด ๋ชจ๋ธ์ ํ๊ณ
๊ณ ์ ๋ ํ๋ผ๋ฏธํฐ
: ๊ธฐ์กด์ SSM์ ๋ชจ๋ ์์ ์์ ๋์ผํ ํ๋ผ๋ฏธํฐ๋ก ๋ฐ์ดํฐ๋ฅผ ์ฒ๋ฆฌํฉ๋๋ค. ์ด๋ ์ ๋ ฅ ๋ฐ์ดํฐ์ ํน์ฑ์ด๋ ์ค์๋์ ๋ฐ๋ผ ๊ฐ๋ณ์ ์ผ๋ก ์ฒ๋ฆฌํ ์ ์๋ค๋ ํ๊ณ๋ฅผ ์ง๋๊ณ ์์ต๋๋ค. ์ฆ, ๋ชจ๋ ๋ฐ์ดํฐ๋ฅผ ๋์ผํ๊ฒ ์ทจ๊ธํ๊ธฐ ๋๋ฌธ์ ์ค์ํ ์ ๋ณด๋ง ์ ํ์ ์ผ๋ก ์ฒ๋ฆฌํ๋ ๊ธฐ๋ฅ์ด ๋ถ์กฑํฉ๋๋ค.๋ณต์กํ ๊ณ์ฐ
: ์ํ์ค๊ฐ ๊ธธ์ด์ง์๋ก ๊ณ์ฐ ๋ณต์ก๋๊ฐ ํฌ๊ฒ ์ฆ๊ฐํฉ๋๋ค. ํนํ ๊ณ ์ฐจ์์ ์ํ ๊ณต๊ฐ์์ ์์ ํ ๊ฒฝ์ฐ, ๊ณ์ฐ ๋น์ฉ์ด ๋งค์ฐ ๋์์ง๋ฉฐ ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋๋ ํฌ๊ฒ ์ฆ๊ฐํฉ๋๋ค.ํจ์จ์ฑ ๋ฌธ์
: SSM์ ์ํ์ค ๋ฐ์ดํฐ๋ฅผ ์ฒ๋ฆฌํ ๋ ๋ชจ๋ ์ ๋ณด๋ฅผ ๊ธฐ์ตํด์ผ ํ๊ธฐ ๋๋ฌธ์, ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋์ด ๋งค์ฐ ํฌ๊ณ ๊ณ์ฐ ์๊ฐ๋ ๊ธธ์ด์ง๋๋ค. ์ด๋ ๊ธด ์ํ์ค๋ฅผ ์ฒ๋ฆฌํ ๋ ํจ์จ์ฑ์ด ๋จ์ด์ง๋ ๋ฌธ์ ๋ฅผ ์ด๋ํฉ๋๋ค.
-
3. ๊ฐ์ ํ์์ฑ
- ์ ๋ ฅ์ ๋ฐ๋ผ ์ ๋์ ์ธ ์ฒ๋ฆฌ๊ฐ ํ์ํฉ๋๋ค. ๊ธฐ์กด ๋ชจ๋ธ์ ๋ชจ๋ ์์ ์์ ๋์ผํ ๋ฐฉ์์ผ๋ก ๋ฐ์ดํฐ๋ฅผ ์ฒ๋ฆฌํ์ง๋ง, ์ ๋ ฅ์ ์ค์๋์ ๋ฐ๋ผ ์ ํ์ ์ผ๋ก ๋ฐ์ดํฐ๋ฅผ ์ฒ๋ฆฌํ๋ ๊ธฐ๋ฅ์ด ์์ผ๋ฉด ๋ ํจ์จ์ ์ผ๋ก ์๋ํ ์ ์์ต๋๋ค.
- ๋ํ, ๋ฉ๋ชจ๋ฆฌ์ ๊ณ์ฐ ์์์ ๋ ํจ์จ์ ์ผ๋ก ์ฌ์ฉํ๊ธฐ ์ํด, ๊ธฐ์กด ๋ชจ๋ธ๋ณด๋ค ๋ ์ ์ ์์์ผ๋ก ๋์ ์ฑ๋ฅ์ ๋ผ ์ ์๋ ์ต์ ํ๋ ๋ฐฉ์์ด ํ์ํ์ต๋๋ค.
3.3.2 Overview of Selective Scan: Hardware-Aware State Expansion (์ ํ์ ์ค์บ: ํ๋์จ์ด ์ธ์ ์ํ ํ์ฅ์ ๊ฐ์)
์ด ํญ๋ชฉ์์๋ Selective Scan์ ๊ฐ๋ ๊ณผ, ์ด๋ฅผ ํตํด SSSM์ด ํ๋์จ์ด ์์์ ์ด๋ป๊ฒ ํจ์จ์ ์ผ๋ก ๊ตฌํ๋ ์ ์๋์ง๋ฅผ ์ค๋ช ํฉ๋๋ค. ์ฌ๊ธฐ์ ์ค์ํ ๊ฐ๋ ์ ํ๋์จ์ด์ ๋ฉ๋ชจ๋ฆฌ ๊ณ์ธต์ ์ต์ ํํ์ฌ ์ ํ์ ์ํ ๊ณต๊ฐ ๋ชจ๋ธ์ ์ฑ๋ฅ์ ๊ทน๋ํํ๋ ๋ฐฉ๋ฒ์ ๋๋ค.
-
1. Selective Scan์ ๊ฐ๋
Selective Scan
์ ์ํ์ค ๋ด์์ ์ค์ํ ์ ๋ณด๋ฅผ ์ ํ์ ์ผ๋ก ์ฒ๋ฆฌํ๊ณ , ๋ถํ์ํ ์ ๋ณด๋ ๋ฌด์ํ๋ ๊ณผ์ ์ ์๋ฏธํฉ๋๋ค. ์ด๋ฅผ ํตํด ๋ชจ๋ธ์ ์ค์ํ ์ ๋ณด๋ง์ ์ ํ์ ์ผ๋ก ๊ธฐ์ตํ๋ฉด์, ๋ถํ์ํ ์ฐ์ฐ์ ์ค์ผ ์ ์์ต๋๋ค.- ์๊ฐ ๊ฐ๋ณ์ (time-varying)์ด๋ผ๋ ํน์ฑ์ ๊ฐ์ง Selective Scan์ ๊ฐ ์์ ์์ ๋์ ์ผ๋ก ๋ณํํ๋ ์ํ๋ฅผ ๊ธฐ๋ฐ์ผ๋ก ์ฐ์ฐ์ ์ํํฉ๋๋ค.
- ์ด๋ ๊ฐ ์์ ์์ ์ ๋ ฅ ๋ฐ์ดํฐ๋ฅผ ๋ถ์ํ์ฌ ์ค์ํ ์ ๋ณด๋ง ์ ํ์ ์ผ๋ก ์ฒ๋ฆฌํ๊ธฐ ๋๋ฌธ์, ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๊ณผ ๊ณ์ฐ ์์์ ์ ์ฝํ ์ ์์ต๋๋ค.
-
2. ํ๋์จ์ด-์ธ์ ์ํ ํ์ฅ (Hardware-Aware State Expansion)
Hardware-Aware State Expansion
์ ์ ํ์ ์ํ ๊ณต๊ฐ ๋ชจ๋ธ์ ๊ณ์ฐ์ ํ๋์จ์ด ํจ์จ์ฑ์ ๊ณ ๋ คํ์ฌ ์ต์ ํํ๋ ๋ฐฉ๋ฒ์ ๋๋ค.-
GPU ๋ฉ๋ชจ๋ฆฌ ๊ณ์ธต ํ์ฉ: ํ๋ GPU๋ ๊ณ ์ ๋ฉ๋ชจ๋ฆฌ(SRAM)์ ๋์ฉ๋ ๋ฉ๋ชจ๋ฆฌ(HBM)๋ฅผ ๊ฐ์ง๊ณ ์์ต๋๋ค. ์ ํ์ SSM์์๋ ์ด๋ฌํ ๋ฉ๋ชจ๋ฆฌ ๊ณ์ธต์ ์ ์ ํ ํ์ฉํ์ฌ, ์์ฃผ ์ฌ์ฉ๋๋ ์ค์ํ ๋ฐ์ดํฐ๋ ๊ณ ์ ๋ฉ๋ชจ๋ฆฌ์ ์ ์ฅํ๊ณ , ๋ ์ค์ํ ๋ฐ์ดํฐ๋ ๋์ฉ๋ ๋ฉ๋ชจ๋ฆฌ์ ์ ์ฅํจ์ผ๋ก์จ ๊ณ์ฐ ์๋์ ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ์ ์ต์ ํํ ์ ์์ต๋๋ค.
- ๊ณ ์ SRAM์ ์ฆ๊ฐ์ ์ธ ๋ฐ์ดํฐ ์ ๊ทผ์ ์ ๊ณตํ๊ณ , ๋์ฉ๋ HBM์ ๋ณต์กํ ์ฐ์ฐ์ ํ์ํ ๋๋์ ๋ฐ์ดํฐ๋ฅผ ์ ์ฅํฉ๋๋ค.
- ์ด ์กฐํฉ์ GPU๊ฐ ๊ทธ๋ํฝ ๋ ๋๋ง, ๋จธ์ ๋ฌ๋, ๊ณผํ์ ์๋ฎฌ๋ ์ด์ ๋ฑ ๋ฐ์ดํฐ ์ง์ฝ์ ์ธ ์์ ์ ํจ๊ณผ์ ์ผ๋ก ์ํํ ์ ์๊ฒ ํฉ๋๋ค.
Kernel Fusion : Hardware-aware Algorithm
์ฒ์ ๋ฑ์ฅํ๋ ๊ฐ๋ ์ ์๋๊ณ , โFlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awarenessโ์ ๋์ค๋ idea๋ผ๊ณ ํฉ๋๋ค.
- GPU์ ์ฃผ์ ๋ณ๋ชฉ ํ์์ SRAM๊ณผ DRAM ์ฌ์ด์ Copy and PASTE์์ ๋ฐ์ํ๋ ๊ฒ์ ํ์ธํ์๊ณ , ์ ์๋ ์ด๋ฌํ memory IO ๋ก ๋ฐ์ํ๋ ๋ณ๋ชฉํ์์ ์ค์ด๊ธฐ ์ํ์ฌ kernel fusion์ ์ฌ์ฉํ์์ต๋๋ค.
- Mamba๋ ๊ณ์ฐ ์์ฒด๋ณด๋ค๋ ๋ฉ๋ชจ๋ฆฌ ์ ์ก ๊ณผ์ ์์ ๋ณ๋ชฉ์ด ๋ฐ์ํ๋ GPU์ ๊ตฌ์กฐ๋ฅผ ๊ณ ๋ คํด ์ฑ๋ฅ์ ๊ทน๋ํํ์ต๋๋ค.
- ์ ๋ ฅ ๋ฒกํฐ์ ๊ฐ์ค์น ๋งค๊ฐ๋ณ์๋ฅผ ๊ณ ์ฑ๋ฅ ๋ฉ๋ชจ๋ฆฌ๋ก ์ ์กํ ํ ๋ชจ๋ ๊ณ์ฐ์ ํ ๋ฒ์ ์ฒ๋ฆฌํ๊ณ , ๋ค์ ๋ฉ์ธ ๋ฉ๋ชจ๋ฆฌ๋ก ๋ฐ์ดํฐ๋ฅผ ์ ์กํฉ๋๋ค.
- ์ด๋ก ์ธํด ๋ฐ์ดํฐ ์ ์ก ์๊ฐ์ ๊ทธ๋๋ก ์ ์ง๋๋ฉด์๋ 16๋ฐฐ ํ์ฅ๋ ๋ฒกํฐ๋ฅผ ์ฌ์ฉํ๋ ๋ฐ ํ์ํ ์ถ๊ฐ ๊ณ์ฐ ์๊ฐ์ ๊ฑฐ์ ๋ฌด๋ฃ๋ก ์ฌ์ฉํ ์ ์๊ฒ ๋ฉ๋๋ค.
์๋ ๊ทธ๋ฆผ์ ๊ฐ๋ต์ ์ผ๋ก โ์ ๋ ฅ ๋ฒกํฐ์ ๊ฐ์ค์น ๋งค๊ฐ๋ณ์๋ฅผ ๊ณ ์ฑ๋ฅ ๋ฉ๋ชจ๋ฆฌ๋ก ์ ์กํ ํ ๋ชจ๋ ๊ณ์ฐ์ ํ ๋ฒ์ ์ฒ๋ฆฌํ๊ณ , ๋ค์ ๋ฉ์ธ ๋ฉ๋ชจ๋ฆฌ๋ก ๋ฐ์ดํฐ๋ฅผ ์ ์กโํ๋ ๊ณผ์ ์ ๋์ํํ ๊ทธ๋ฆผ์ ๋๋ค. (Source: https://youtu.be/N6Piou4oYx8)
- Scan Operation: Selective Scan ๊ณ์ฐ์ ๋ณ๋ ฌ๋ก ์ฒ๋ฆฌํ ์ ์๋๋ก ์ค๊ณํ์ฌ, ์ํ์ค์ ์ฌ๋ฌ ์์ ์ ๋์์ ์ฒ๋ฆฌํ๋ ๋ฐฉ์์ ์ฌ์ฉํฉ๋๋ค.
- ๋ณ๋ ฌ ์ค์บ ์๊ณ ๋ฆฌ์ฆ(Parellel Scan Operation)์ ํตํด ์ ๋ ฅ ์ํ์ค๋ฅผ ๋์์ ์ฒ๋ฆฌํ๋ฉด์, ์ ํ์ ์ผ๋ก ํ์ํ ์ ๋ณด๋ฅผ ์ฒ๋ฆฌํ๊ณ ๋๋จธ์ง๋ ๊ฑด๋๋ฐ๋ ๋ฐฉ์์ ๋๋ค.
- ์ด๋ฅผ ํตํด ์ ์ฒด ์ํ์ค๋ฅผ ์์ฐจ์ ์ผ๋ก ์ฒ๋ฆฌํ๋ ๋ฐฉ์๋ณด๋ค ๋ ๋น ๋ฅด๊ณ ํจ์จ์ ์ผ๋ก ์ฐ์ฐ์ ์ํํ ์ ์์ต๋๋ค.
-
๊ฒฐํฉ๊ท์น(association rule) ๊ธฐ๋ฐ ์ ๊ทผ๋ฒ์ ์ฌ์ฉํ์ฌ โ๋จผ์ ๊ณ์ฐํ ์ ์๋ ๊ฒ์ ๊ณ์ฐํด์ฃผ์!โ ๋ผ๋ ๊ฐ๋จํ์ง๋ง ๊ฐ๋ ฅํ ๋ฐฉ๋ฒ์ ์ฌ์ฉํ์ต๋๋ค.
- ์ด ๋ฐฉ๋ฒ์ ์ฌ์ฉํ๋ฉด O(logโก(n))O(\log(n))O(log(n))์ ์๊ฐ ๋ด์ ๋ณ๋ ฌ ์ฒ๋ฆฌ๋ฅผ ์ํํ ์ ์์ด ๊ณ์ฐ ์๋๊ฐ ํฌ๊ฒ ํฅ์๋ฉ๋๋ค.
- 3. ์๊ฐ ๊ฐ๋ณ์ ์ ํ ์ฒ๋ฆฌ
- Selective SSM์ ์๊ฐ ๊ฐ๋ณ์ ์ด๊ธฐ ๋๋ฌธ์, ๊ฐ ์์ ๋ง๋ค ๋ค๋ฅธ ๋ฐฉ์์ผ๋ก ๋ฐ์ดํฐ๋ฅผ ์ฒ๋ฆฌํฉ๋๋ค. ์ด๋ก ์ธํด, ๊ฐ ์์ ์์ ์ค์ํ ์ ๋ณด๋ฅผ ์ ํ์ ์ผ๋ก ์ฒ๋ฆฌํ๊ณ , ์ฌ๊ท์ ์ฐ์ฐ(recurrent operation)์ ํตํด ์ด์ ์ํ๋ฅผ ๊ธฐ๋ฐ์ผ๋ก ๋ค์ ์ํ๋ฅผ ๊ณ์ฐํฉ๋๋ค.
- ๊ฐ ์์ ์์ ์ฒ๋ฆฌ๋๋ ๋ฐ์ดํฐ์ ์์ ์ค์ด๊ธฐ ์ํด, Selective Scan์ ํตํด ์ ๋ ฅ ๋ฐ์ดํฐ๋ฅผ ๋ถ์ํ๊ณ ํ์ํ ์ ๋ณด๋ง ์ ํํฉ๋๋ค. ์ด๋ฅผ ํตํด ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋๊ณผ ๊ณ์ฐ๋์ ์ค์ด๊ณ , ์ฐ์ฐ ์๋๋ฅผ ํฌ๊ฒ ํฅ์์ํฌ ์ ์์ต๋๋ค.
3.4 A Simplified SSM Architecture (๋จ์ํ๋ SSM ์ํคํ ์ฒ)
- Mamba ์ํคํ
์ฒ: SSM์ MLP ๋ธ๋ก๊ณผ ๊ฒฐํฉํ์ฌ ๊ฐ๋จํ ํํ์ ์ํคํ
์ฒ๋ฅผ ๋ง๋ค์์ต๋๋ค. ์ด ์ํคํ
์ฒ๋ Transformer์ฒ๋ผ ๋ณต์กํ ๊ตฌ์กฐ๋ฅผ ๊ฐ์ง ์์ผ๋ฉฐ, ๋จ์ํ์ง๋ง ๊ฐ๋ ฅํ ์ฑ๋ฅ์ ๋ณด์ฌ์ค๋๋ค.
- ์ฌ๋ฌ ๊ฐ์ Mamba ๋ธ๋ก์ ๋ฐ๋ณต์ ์ผ๋ก ์์ ๋ชจ๋ธ์ ๊น์ด๋ฅผ ํ์ฅํ ์ ์์ต๋๋ค.
- Mamba ๋ธ๋ก์ ํ๋ ์ ๊ฒฝ๋ง์ ๋ค์ธต ํผ์ ํธ๋ก (MLP) ๋ธ๋ก๊ณผ ๋๋ถ๋ถ์ SSM ์ํคํ ์ฒ์ ๊ธฐ์ด๊ฐ ๋๋ H3 ๋ธ๋ก์ ์กฐํฉํ ๊ฒ์ ๋๋ค.
์ ๊ทธ๋ฆผ์ H3 ๋ธ๋ก, Gated MLP ๋ธ๋ก, ๊ทธ๋ฆฌ๊ณ Mamba ๋ธ๋ก์ ๊ตฌ์กฐ๋ฅผ ๋น๊ตํ ๊ฒ์ ๋๋ค. ๊ฐ ๋ธ๋ก์ ํ๋ ์ ๊ฒฝ๋ง์์ ์ฌ์ฉ๋๋ ๊ตฌ์กฐ์ ์ฐจ์ด์ ์ ์๊ฐ์ ์ผ๋ก ๋ณด์ฌ์ฃผ๋ฉฐ, ์ด๋ฅผ ํตํด Mamba ์ํคํ ์ฒ๊ฐ ์ด๋ป๊ฒ ์ค๊ณ๋์๋์ง๋ฅผ ์ค๋ช ํฉ๋๋ค. ๊ทธ๋ฆผ์ ๋ํ ์ฃผ์ ํด์์ ์๋์ ๊ฐ์ต๋๋ค.
1. H3 Block
H3
๋ ๊ณ ์ฐจ์ ๋ฐ์ดํฐ๋ฅผ ํจ๊ณผ์ ์ผ๋ก ์ฒ๋ฆฌํ๊ธฐ ์ํด ์ค๊ณ๋ ๋ธ๋ก ๊ตฌ์กฐ๋ก, RNN๊ณผ CNN์ ์ฅ์ ์ ๊ฒฐํฉํ์ฌ ์์ฐจ์ ๋ฐ์ดํฐ๋ฅผ ์ฒ๋ฆฌํ๋ ๋ฐ ๊ฐ๋ ฅํ ์ฑ๋ฅ์ ๋ณด์ฌ์ฃผ๊ณ ์์ต๋๋ค.- ์ด ๋ธ๋ก์ ๊ณผ๊ฑฐ์ ์ ๋ณด์ ํ์ฌ ์ ๋ ฅ์ ๊ธฐ๋ฐ์ผ๋ก ๋ฏธ๋์ ์ถ๋ ฅ์ ์์ธกํ๋ ๋ฐฉ์์ผ๋ก ์๋ํฉ๋๋ค.
- H3์ ์ฃผ์ ํน์ง์ ๊ธด ์ํ์ค๋ฅผ ํจ์จ์ ์ผ๋ก ์ฒ๋ฆฌํ ์ ์๋ ๋ฅ๋ ฅ๊ณผ, ์ํ์ค์ ๋ชจ๋ ์์์ ๋ํ ์ ๋ณด๋ฅผ ๊ณ ๋ คํ ์ ์๋ ๊ตฌ์กฐ์ ํน์ฑ์ ๋๋ค.
- ๊ตฌ์ฑ ์์:
- SSM: ์ํ ๊ณต๊ฐ ๋ชจ๋ธ(State Space Model)์ ์ํ์ค ๋ณํ์ ๋ด๋นํฉ๋๋ค. ์ด๋ ์ฃผ๋ก ์ฌ๊ท์ ์ธ ํน์ฑ์ ํ์ฉํ์ฌ ์ํ์ค ๋ฐ์ดํฐ๋ฅผ ์ฒ๋ฆฌํฉ๋๋ค.
- Conv: ํฉ์ฑ๊ณฑ ์ธต์ด ์ถ๊ฐ๋์ด ๋ก์ปฌ ์ ๋ณด๋ฅผ ์ฒ๋ฆฌํฉ๋๋ค. ํฉ์ฑ๊ณฑ์ ์ผ๋ฐ์ ์ผ๋ก ๊ณต๊ฐ์ ์ฐ๊ด์ฑ์ ๋ค๋ฃจ๋ ๋ฐ ํจ๊ณผ์ ์ ๋๋ค.
- ๊ณฑ์ ๊ฒ์ดํธ(Multiplicative Gate): SSM๊ณผ Conv ์ฌ์ด์ ๊ณฑ์ ๊ฒ์ดํธ๊ฐ ์์ด ๋ฐ์ดํฐ์ ํ๋ฆ์ ์กฐ์ ํฉ๋๋ค.
- ๋์ ์๋ฆฌ:
- H3 ๋ธ๋ก์ SSM๊ณผ ํฉ์ฑ๊ณฑ ์ธต์ ๊ต์ฐจ ๋ฐฐ์นํ์ฌ ๊ฐ๊ฐ์ ์ํ์ค ๋ฐ์ดํฐ์ ๋ก์ปฌ ๋ฐ์ดํฐ๋ฅผ ์ฒ๋ฆฌํ๋ฉฐ, ๊ณฑ์ ๊ฒ์ดํธ๋ฅผ ํตํด ์ค์ํ ์ ๋ณด๋ฅผ ํต๊ณผ์ํค๊ฑฐ๋ ์ต์ ํ ์ ์์ต๋๋ค. ์ด ๊ตฌ์กฐ๋ ๋ฐ์ดํฐ๋ฅผ ์ฒ๋ฆฌํ ๋ ์ ์ฐ์ฑ์ ์ ๊ณตํ์ง๋ง, ๊ณ์ฐ์ ์ผ๋ก ๋ณต์กํ ์ ์์ต๋๋ค.
2. Gated MLP
-
Gated MLP๋ ๋ค์ธต ํผ์ ํธ๋ก (MLP)๊ณผ ๊ณฑ์ ๊ฒ์ดํธ(Multiplicative Gate)๋ฅผ ๊ฒฐํฉํ์ฌ ์ ๋ ฅ ๋ฐ์ดํฐ์ ๋ํ ๋น์ ํ ๋ณํ์ ์ํํ๋ ๊ตฌ์กฐ์ ๋๋ค.
- MLP๋ ์ผ๋ฐ์ ์ผ๋ก ๊ณ ์ ๋ ์ ๋ ฅ ๋ฐ์ดํฐ๋ฅผ ์ฒ๋ฆฌํ๋ ๋ฐ ํ์ํ ๋ฅ๋ ฅ์ ๊ฐ์ง๊ณ ์์ผ๋ฉฐ, ๊ณฑ์ ๊ฒ์ดํธ๋ ๋ฐ์ดํฐ์ ํน์ฑ์ ๋ฐ๋ผ ์ค์ํ ์ ๋ณด๋ฅผ ๊ฐ์กฐํ๊ณ , ๋ถํ์ํ ์ ๋ณด๋ ์ต์ ํ ์ ์๋๋ก ๋์์ค๋๋ค.
-
๊ตฌ์ฑ ์์:
- MLP: ๋ค์ธต ํผ์ ํธ๋ก (Multi-Layer Perceptron) ๋ธ๋ก์ ์ ๋ ฅ ๋ฐ์ดํฐ๋ฅผ ๋น์ ํ์ ์ผ๋ก ๋ณํํฉ๋๋ค.
- ๊ณฑ์ ๊ฒ์ดํธ(Multiplicative Gate): H3์ ๋ง์ฐฌ๊ฐ์ง๋ก ๊ณฑ์ ๊ฒ์ดํธ๊ฐ ์ถ๊ฐ๋์ด, ๋ฐ์ดํฐ๋ฅผ ์ ํ์ ์ผ๋ก ์ฒ๋ฆฌํ๋ ์ญํ ์ ํฉ๋๋ค.
-
๋์ ์๋ฆฌ:
- Gated MLP๋ MLP์ ๊ณฑ์ ๊ฒ์ดํธ๋ฅผ ๊ฒฐํฉํ์ฌ, ์ ๋ ฅ ๋ฐ์ดํฐ์ ๋ํ ์ ์ฐํ ๋ณํ์ ์ํํฉ๋๋ค. ๋ฐ์ดํฐ์ ํน์ฑ์ ๋ฐ๋ผ ์ค์ํ ์ ๋ณด๋ ๊ณฑ์ ์ ํตํด ๊ฐ์กฐ๋๊ณ , ๋ถํ์ํ ์ ๋ณด๋ ์ต์ ๋ ์ ์์ต๋๋ค.
- ๊ทธ๋ฌ๋ Gated MLP๋ ๋น์ํ์ค ๋ฐ์ดํฐ๋ฅผ ์ฒ๋ฆฌํ๋ ๋ฐ๋ ํจ๊ณผ์ ์ด์ง๋ง, ์ํ์ค ์ฒ๋ฆฌ์ ํ์ํ ๋ฉ์ปค๋์ฆ(์: SSM)์ด ํฌํจ๋์ด ์์ง ์๋ค๋ ์ ์์ ์ํ์ค ๊ธฐ๋ฐ ์์ ์ ์ต์ ํ๋ ๊ตฌ์กฐ๋ ์๋๋๋ค.
3. Mamba Block
-
Mamba ๋ธ๋ก์ ํ๋ ์ ๊ฒฝ๋ง์ ๋ค์ธต ํผ์ ํธ๋ก (MLP) ๋ธ๋ก๊ณผ SSM(Structured State Space Model) ์ํคํ ์ฒ์์ ์ค์ํ ์ญํ ์ ํ๋ H3 ๋ธ๋ก์ ๊ฒฐํฉํ ์ค๊ณ์ ๋๋ค.
- ์ด ๊ตฌ์กฐ๋ ๊ธฐ์กด์ MLP์ SSM ๋ธ๋ก์ ๊ต์ฐจํ๊ฑฐ๋ ํผํฉํ๋ ๋์ , ๋์ผํ Mamba ๋ธ๋ก์ ๋์ง์ ์ผ๋ก ๋ฐ๋ณตํ๋ ๋ฐฉ์์ผ๋ก ์ค๊ณ๋์์ต๋๋ค.
-
๊ตฌ์ฑ ์์:
- SSM: H3 ๋ธ๋ก๊ณผ ๋ง์ฐฌ๊ฐ์ง๋ก ์ํ ๊ณต๊ฐ ๋ชจ๋ธ(SSM)์ด ์กด์ฌํ์ฌ ์ํ์ค ๋ฐ์ดํฐ๋ฅผ ์ฒ๋ฆฌํฉ๋๋ค.
- Conv: ํฉ์ฑ๊ณฑ ์ธต์ด ์ถ๊ฐ๋์ด, ์ํ์ค ๋ด์ ๊ตญ์์ ์ธ ์ ๋ณด ์ฒ๋ฆฌ์ ๊ธฐ์ฌํฉ๋๋ค.
- ํ์ฑํ ํจ์(SiLU/Swish): H3์๋ ๋ค๋ฅด๊ฒ, ๊ณฑ์ ๊ฒ์ดํธ ๋์ ํ์ฑํ ํจ์๊ฐ ์ฌ์ฉ๋ฉ๋๋ค. ์ด ํจ์๋ ๋น์ ํ์ฑ์ ์ถ๊ฐํ์ฌ ๋ฐ์ดํฐ์ ํํ๋ ฅ์ ๋์ ๋๋ค.
-
๋์ ์๋ฆฌ:
- Mamba ๋ธ๋ก์ SSM๊ณผ Conv๋ฅผ ๊ฒฐํฉํ์ฌ ์ํ์ค ๋ฐ ๋ก์ปฌ ๋ฐ์ดํฐ๋ฅผ ๋ชจ๋ ์ฒ๋ฆฌํฉ๋๋ค.
- ๊ณฑ์ ๊ฒ์ดํธ ๋์ ๋น์ ํ ํ์ฑํ ํจ์(SiLU ๋๋ Swish)๋ฅผ ์ฌ์ฉํ์ฌ ๊ณ์ฐ ๋ณต์ก์ฑ์ ์ค์ด๊ณ ํจ์จ์ฑ์ ๋์์ต๋๋ค. ํ์ฑํ ํจ์๋ ๋ฐ์ดํฐ์ ํ๋ฆ์ ์กฐ์ ํ๋ ๋ฐ ๋ ๊ฐ๋จํ ๋ฐฉ์์ ์ฌ์ฉํ๋ฉฐ, ๊ฒ์ดํธ์ ํ์์ฑ์ ์ ๊ฑฐํ์ฌ ๋ ๋จ์ํ ๊ตฌ์กฐ๋ฅผ ๋ง๋ค์์ต๋๋ค.
- Mamba ๋ธ๋ก์ ๋ ๊ฐ์ง ์ฃผ์ ์ฐ์ฐ(SMM, Conv)์ ํ ๋ธ๋ก ๋ด์์ ๋ฐ๋ณตํ๋ ๊ฐ๋จํ๋ฉด์๋ ๊ฐ๋ ฅํ ๊ตฌ์กฐ๋ฅผ ์ฑํํ์ต๋๋ค.
-
Mamba์ ๋ค๋ฅธ ๋ธ๋ก์ ์ฐจ์ด์
- ๋ณต์ก์ฑ ๊ฐ์: H3 ๋ธ๋ก์์๋ ๊ณฑ์ ๊ฒ์ดํธ๋ฅผ ์ฌ์ฉํ์ฌ ๋ฐ์ดํฐ๋ฅผ ์ ํ์ ์ผ๋ก ์ฒ๋ฆฌํ์ง๋ง, Mamba๋ ์ด๋ฅผ ์ ๊ฑฐํ๊ณ ํ์ฑํ ํจ์๋ก ๋์ฒดํจ์ผ๋ก์จ ๊ณ์ฐ ๋ณต์ก์ฑ์ ์ค์์ต๋๋ค.
- ๋จ์ํ๋ ์ํคํ ์ฒ: Mamba๋ SSM๊ณผ Conv๋ฅผ ๋์ผํ ๋ธ๋ก ๋ด์์ ๋ฐ๋ณต์ ์ผ๋ก ์ฌ์ฉํ์ฌ ๋งค์ฐ ๊ท ์ผํ๊ณ ๋จ์ํ ์ํคํ ์ฒ๋ฅผ ์ ์งํ๋ฉด์๋ ์ฑ๋ฅ์ ์ต์ ํํ์ต๋๋ค. ์ด๋ ์ ์ฒด์ ์ผ๋ก ๋ ๋น ๋ฅด๊ณ ํจ์จ์ ์ธ ๊ณ์ฐ์ ๊ฐ๋ฅํ๊ฒ ํฉ๋๋ค.
- ๋น์ ํ ํ์ฑํ: Mamba์์๋ ๊ณฑ์ ๊ฒ์ดํธ ๋์ SiLU/Swish์ ๊ฐ์ ๋น์ ํ ํ์ฑํ ํจ์๋ฅผ ์ฌ์ฉํ์ฌ ๋ฐ์ดํฐ์ ํ๋ฆ์ ์กฐ์ ํฉ๋๋ค. ์ด๋ ๋ชจ๋ธ์ ํํ๋ ฅ์ ์ ์งํ๋ฉด์๋ ๊ณ์ฐ ๋ณต์ก์ฑ์ ์ค์ด๋ ์ค์ํ ๋ณํ์ ๋๋ค.
โจ Mamba ๋ธ๋ก์ H3 ๋ฐ Gated MLP์ ๋น๊ตํ์ฌ ๋ ๋จ์ํ๊ณ ํจ์จ์ ์ธ ์ํคํ ์ฒ๋ฅผ ์ ๊ณตํฉ๋๋ค. SSM๊ณผ Conv๋ฅผ ์ ์ ํ ๊ฒฐํฉํ์ฌ ์ํ์ค ๋ฐ์ดํฐ์ ๋ก์ปฌ ์ ๋ณด๋ฅผ ๋์์ ์ฒ๋ฆฌํ๋ฉฐ, ํ์ฑํ ํจ์๋ก ๋น์ ํ์ฑ์ ๋ถ์ฌํ์ฌ ๋ชจ๋ธ์ ์ฑ๋ฅ์ ๋์์ต๋๋ค.
3.5 Properties of Selection Mechanisms (์ ํ ๋ฉ์ปค๋์ฆ์ ํน์ฑ)
- ์ ํ ๋ฉ์ปค๋์ฆ์ ํจ๊ณผ: ์ ํ ๋ฉ์ปค๋์ฆ์ ๋ถํ์ํ ๋ฐ์ดํฐ๋ฅผ ๋ฌด์ํ๊ณ , ์ค์ํ ๋ฐ์ดํฐ๋ฅผ ์ ํ์ ์ผ๋ก ์ฒ๋ฆฌํ๋ ๋ฅ๋ ฅ์ ๊ฐ์ง๋๋ค. ์ด๋ฅผ ํตํด ์ํ์ค์ ๊ธด ๋ฌธ๋งฅ์ ์ฒ๋ฆฌํ ๋ ํจ์จ์ฑ์ด ์ฆ๊ฐํ๋ฉฐ, ๊ธด ์ํ์ค์์๋ ์ฑ๋ฅ ์ ํ๊ฐ ๋ฐ์ํ์ง ์์ต๋๋ค.
- ๋ณ์ ๊ฐ ์ํธ์์ฉ: ์ ํ ๋ฉ์ปค๋์ฆ์ ์ํ์ค์ ๊ฐ ์์๋ค์ด ์ํธ์์ฉํ๋ ๋ฐฉ์์ ์กฐ์ ํ๋ฉฐ, ์ด๋ฌํ ์กฐ์ ๋ฅ๋ ฅ์ ํนํ ํ ์คํธ๋ DNA์ ๊ฐ์ ์ด์ฐ์ ์ธ ๋ฐ์ดํฐ์์ ํจ๊ณผ์ ์ ๋๋ค.
3.6 Additional Model Details (์ถ๊ฐ ๋ชจ๋ธ ์ธ๋ถ์ฌํญ)
- ์ค์ ๋ฐ ๋ณต์์ ์ฒ๋ฆฌ: ์ ํ์ SSM์ ๋ณต์์์ ์ค์๋ฅผ ๋ชจ๋ ์ฒ๋ฆฌํ ์ ์์ง๋ง, ํน์ ์์ ์์๋ ์ค์ ๊ธฐ๋ฐ ๋ชจ๋ธ์ด ๋ ๋์ ์ฑ๋ฅ์ ๋ณด์ผ ์ ์์ต๋๋ค.
- ์ด๊ธฐํ ๋ฐ ํ๋ผ๋ฏธํฐํ: ์ ํ์ ํ๋ผ๋ฏธํฐ์ ์ด๊ธฐํ ๋ฐฉ์์ ๋ฐ๋ผ ๋ชจ๋ธ์ ์ฑ๋ฅ์ด ๋ฌ๋ผ์ง๋ฉฐ, ๊ฐ ํ๋ผ๋ฏธํฐ์ ๋ํ ์์ธํ ์ค๋ช ์ ํตํด ๋ชจ๋ธ์ ์์ ์ฑ์ ์ ์งํฉ๋๋ค.
4. Empirical Evaluation
Mamba ๋ชจ๋ธ์ ๋ค์ํ ๋ฐ์ดํฐ ์ ํ
๊ณผ ์ํ์ค ๊ธธ์ด
์์ ํ
์คํธํ ๊ฒฐ๊ณผ๋ฅผ ์๊ฐํฉ๋๋ค.
4.1 Synthetic Tasks (ํฉ์ฑ ์์ )
- Selective Copying: ์ ํ ๋ฉ์ปค๋์ฆ์ ์ฌ์ฉํ Mamba๋ ์ํ์ค์ ์ค์ํ ๋ถ๋ถ์ ๊ธฐ์ตํ๊ณ ๋๋จธ์ง๋ฅผ ๋ฌด์ํ๋ ์์ ์์ ๋ฐ์ด๋ ์ฑ๋ฅ์ ๋ณด์ ๋๋ค. ์ํ์ค ๊ธธ์ด๊ฐ ๋งค์ฐ ๊ธธ์ด๋ Mamba๋ ์ฑ๋ฅ ์ ํ ์์ด ์ ํ๋๋ฅผ ์ ์งํฉ๋๋ค.
- Induction Heads: LLM์ ๋งฅ๋ฝ ํ์ต ๋ฅ๋ ฅ์ ํ๊ฐํ๋ Induction Heads ์์ ์์๋ Mamba๋ ์ค์ํ ํ ํฐ์ ๊ธฐ์ตํ๋ฉฐ ์ฑ๋ฅ์ ์ ์งํฉ๋๋ค. ํ๋ จ ์ 256 ๊ธธ์ด์ ์ํ์ค๋ก ํ์ตํ ๋ชจ๋ธ์ด 1๋ฐฑ๋ง ๊ธธ์ด์ ์ํ์ค์์๋ ์ ํํ ๊ฒฐ๊ณผ๋ฅผ ๋ณด์ ๋๋ค.
4.2 Language Modeling (์ธ์ด ๋ชจ๋ธ๋ง)
- Mamba์ ์ธ์ด ๋ชจ๋ธ๋ง ์ฑ๋ฅ: Mamba๋ ํ ์คํธ ๋ฐ์ดํฐ์์ Transformer์ ๋น์ทํ๊ฑฐ๋ ๋ ๋์ ์ฑ๋ฅ์ ๋ณด์ ๋๋ค. ํนํ, 1B ์ด์์ ํ๋ผ๋ฏธํฐ๋ฅผ ๊ฐ์ง ๋ชจ๋ธ์์๋ Transformer์ ๋น์ทํ ์์ค์ ์ฑ๋ฅ์ ๋ด๋ ์ฒซ ๋ฒ์งธ ์ ํ ์ํ์ค ๋ชจ๋ธ์ ๋๋ค.
4.3 DNA Modeling (DNA ๋ชจ๋ธ๋ง)
- DNA ์ํ์ค ์ฒ๋ฆฌ: Mamba๋ ๊ธด ์ํ์ค ๋ฐ์ดํฐ๋ฅผ ์ฒ๋ฆฌํ๋ ๋ฐ ๋ฐ์ด๋๋ฉฐ, ๊ธฐ์กด์ Transformer ๋ชจ๋ธ๋ณด๋ค ๋ ๋์ ์ฑ๋ฅ์ ๋ณด์ ๋๋ค. DNA์ ๊ฒฝ์ฐ ๊ธด ๋ฌธ๋งฅ ์์กด์ฑ์ด ์ค์ํ๋ฐ, Mamba๋ ์ด๋ฌํ ๋ฐ์ดํฐ๋ฅผ ์ ์ฒ๋ฆฌํฉ๋๋ค.
4.4 Audio Modeling and Generation (์ค๋์ค ๋ชจ๋ธ๋ง ๋ฐ ์์ฑ)
- ์ค๋์ค ๋ฐ์ดํฐ ์ฒ๋ฆฌ: Mamba๋ ์ค๋์ค ๋ฐ์ดํฐ๋ฅผ ์ฒ๋ฆฌํ ๋๋ ํจ์จ์ ์ด๋ฉฐ, ๊ธฐ์กด ๋ชจ๋ธ๋ณด๋ค ๋ ๊ธด ์ํ์ค๋ฅผ ์ฒ๋ฆฌํ ์ ์์ต๋๋ค.
4.5 Speed and Memory Benchmarks (์๋ ๋ฐ ๋ฉ๋ชจ๋ฆฌ ๋ฒค์น๋งํฌ)
- ์ฒ๋ฆฌ ์๋: Mamba๋ Transformer๋ณด๋ค 5๋ฐฐ ๋น ๋ฅธ ์ถ๋ก ์๋๋ฅผ ๋ณด์ด๋ฉฐ, ๊ธด ์ํ์ค๋ฅผ ์ฒ๋ฆฌํ ๋ ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋๋ ๋งค์ฐ ์ ์ต๋๋ค.
4.6 Model Ablations (๋ชจ๋ธ ์๋ธ๋ ์ด์ )
- ํ๋ผ๋ฏธํฐ ๋ถ์: ์ ํ์ ํ๋ผ๋ฏธํฐ๋ฅผ ์ถ๊ฐํ ์๋ก ์ฑ๋ฅ์ด ํฅ์๋๋ฉฐ, ํนํ Delta ํ๋ผ๋ฏธํฐ๊ฐ ๋ชจ๋ธ ์ฑ๋ฅ์ ๊ฐ์ฅ ์ค์ํ ์ํฅ์ ๋ฏธ์นฉ๋๋ค.
Reference
Paper
- Mamba ๋ ผ๋ฌธ : https://arxiv.org/pdf/2312.00752
Blogs
- A Visual Guide to Mamba and State Space Models: https://newsletter.maartengrootendorst.com/p/a-visual-guide-to-mamba-and-state
Youtube
- MAMBA from Scratch (https://youtu.be/N6Piou4oYx8)
- Mamba Paper Review, DSBA ์ฒ์ฌ์ (https://youtu.be/JjxBNBzDbNk)
- Mamba Paper Review, AirLab ์ด์ ์ด (https://youtu.be/l-dQCTv9wIg)