[Paper Review] Resurrecting Recurrent Neural Networks for Long Sequences
์๋ณธ ๊ฒ์๊ธ: https://velog.io/@euisuk-chung/Paper-Review-Resurrecting-Recurrent-Neural-Networks-for-Long-Sequences
๋ ผ๋ฌธ โResurrecting Recurrent Neural Networks for Long Sequencesโ๋ 25 Apr 2023์ publish๋์์ผ๋ฉฐ, ICML 2023 OralPoster์ ๋ฐํ๋ ๋ ผ๋ฌธ์ ๋๋ค.
ํด๋น ๋ ผ๋ฌธ์ โRecurrent Neural Networks (RNN)์ ์ฑ๋ฅ์ ๋ณต์ํ์ฌ ๊ธด ์ํ์ค์์์ ํจ์จ์ ์ธ ํ์ต๊ณผ ์ถ๋ก โ์ ๋ค๋ฃจ๊ณ ์์ต๋๋ค. ๋ณธ paper-review์์๋ ๊ฐ ์ฑํฐ๋ณ๋ก ์ฃผ์ ๋ด์ฉ์ ์ ๋ฆฌํด๋ณด์์ต๋๋ค.
- ๋
ผ๋ฌธ์ ์๋ก ์
Recurrent Neural Networks(RNNs)
๊ฐ ๊ธด ์ํ์ค๋ฅผ ์ฒ๋ฆฌํ ๋ ๊ฒช๋ ๋ฌธ์ ์ ๊ณผ ๊ทธ์ ๋ํ ํด๊ฒฐ์ฑ ์ ์ ์ํ๋ฉฐ, ์ต๊ทผ์ ์ฐ๊ตฌ ๋ํฅ์ ์ค๋ช ํ๊ณ ์์ต๋๋ค.
1.1. RNN์ ์ค์์ฑ๊ณผ ํ๊ณ
RNN
์ ์์ฐจ์ ๋ฐ์ดํฐ(์๊ณ์ด, ์์ฐ์ด ๋ฑ)๋ฅผ ์ฒ๋ฆฌํ๋ ๋ฐ ์ค๋ ์๊ฐ ๋์ ์ค์ํ ์ญํ ์ ํด์์ต๋๋ค. RNN์ ๋ฐ์ดํฐ๋ฅผ ์์ฐจ์ ์ผ๋ก ์ฒ๋ฆฌํ๊ธฐ ๋๋ฌธ์ ์ํ์ค ๊ฐ์ ์์กด์ฑ์ ํ์ตํ ์ ์๋ ๋ฅ๋ ฅ์ด ๋ฐ์ด๋ฉ๋๋ค. ๊ทธ๋ฌ๋ ์ค์ ์์ RNN์ ํ์ต์ํค๋ ๊ฒ์ ๋งค์ฐ ์ด๋ ค์ด ์ผ์
๋๋ค.
์ฃผ๋ ์ด์ ๋ ๊ธฐ์ธ๊ธฐ ์์ค(vanishing gradient)๊ณผ ๊ธฐ์ธ๊ธฐ ํญ๋ฐ(exploding gradient) ๋ฌธ์ ๋๋ฌธ์ ๋๋ค. ์ด๋ RNN์ด ๊ธด ์ํ์ค๋ฅผ ํ์ตํ ๋, ์ด๋ฐ์ ์ ๋ ฅ๋ ์ ๋ณด๊ฐ ๋คํธ์ํฌ๋ฅผ ํตํด ์ ๋ฌ๋๋ฉด์ ์ ์ ์ฝํด์ง๊ฑฐ๋, ๋๋ฌด ๊ฐํด์ ธ์ ํ์ต์ด ๋ถ๊ฐ๋ฅํด์ง๋ ํ์์ ์๋ฏธํฉ๋๋ค.
-
๊ธฐ์ธ๊ธฐ ์์ค
๋ฌธ์ ๋ ๋คํธ์ํฌ์ ์ด๊ธฐ ์ธต์ด๋ ๊น์ ์ธต์์ ๊ธฐ์ธ๊ธฐ๊ฐ ๋งค์ฐ ์์์ ธ์, ๊ฐ์ค์น ์ ๋ฐ์ดํธ๊ฐ ์ด๋ฃจ์ด์ง์ง ์๊ฑฐ๋ ๋งค์ฐ ๋๋ฆฌ๊ฒ ์ด๋ฃจ์ด์ง๋ ํ์์ ๋๋ค.- ๋น์ ํ ํ์ฑํ ํจ์: ์ ๋ ฅ๊ฐ์ด ๋น์ ํ ํจ์์ ๊ทนํ์ ๋๋ฌํ ๋ ํด๋น ํจ์์ ๊ธฐ์ธ๊ธฐ๊ฐ 0์ด ๋๊ธฐ ๋๋ฌธ์, ๊ทธ ์ดํ์ ์ธต์ผ๋ก ์ ๋ฌ๋๋ ๊ธฐ์ธ๊ธฐ๊ฐ ๋ชจ๋ 0์ ๊ฐ๊น์์ง๋ค.
- ์ธต์ด ๊น์ด์ง์๋ก: RNN์ ์๊ฐ์ ๋ฐ๋ผ ์ฐ์์ ์ผ๋ก ์ฌ๋ฌ ์ธต์ ์์ ๋๊ฐ๊ธฐ ๋๋ฌธ์, ๊ฐ ๋จ๊ณ์์ ์ ํธ๊ฐ ์ ํ๋ ๋ ๊ธฐ์ธ๊ธฐ๊ฐ ๊ณฑํด์ ธ์ ๊ฐ์ํ๊ฒ ๋๋ค.
-
๊ธฐ์ธ๊ธฐ ํญ๋ฐ
๋ฌธ์ ๋ ๊ธฐ์ธ๊ธฐ๊ฐ ๋๋ฌด ์ปค์ ธ์ ๊ฐ์ค์น ์ ๋ฐ์ดํธ๊ฐ ๊ทน์ฌํด์ง๋ ํ์์ ๋๋ค. ์ด๋ก ์ธํด ๋ชจ๋ธ์ด ํ์ต ์ค์ ๋ฐ์ฐํ๊ฒ ๋ ์ ์์ต๋๋ค.- ๊ธฐ์ธ๊ธฐ ์กฐํฉ: RNN ๋ชจ๋ธ์์ ์ฌ๋ฌ ์ธต์ ๊ฐ์ค์น๊ฐ ๊ณฑํด์ง๋ฉด์ ๊ธฐ์ธ๊ธฐ๊ฐ ์ง์์ ์ผ๋ก ์ฆ๊ฐํ ์ ์์ต๋๋ค.
- ๋ถ์ ์ ํ ์ด๊ธฐํ: ๊ฐ์ค์น๋ฅผ ์๋ชป ์ด๊ธฐํํ๋ฉด ํ๋ จ ์ค์ ๊ธฐ์ธ๊ธฐ๊ฐ ํฐ ๊ฐ์ ๊ฐ์ง๊ฒ ๋๋ฉด์, ์ ๋ฐ์ดํธ๋ ๋น์ ์์ ์ผ๋ก ์ปค์ง๊ฒ ๋ฉ๋๋ค.
์ด ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ๊ธฐ ์ํด ๋ช ๊ฐ์ง ๊ธฐ์ ๋ค์ด ๊ฐ๋ฐ๋์์ต๋๋ค. ์๋ฅผ ๋ค์ด, LSTM(Long Short-Term Memory)๊ณผ GRU(Gated Recurrent Units) ๊ฐ์ ๊ฒ์ดํธ ๋ฉ์ปค๋์ฆ์ ์ฌ์ฉํ๋ ๋ชจ๋ธ๋ค์ด ์์ต๋๋ค. ์ด๋ค์ RNN์ ํ๊ณ๋ฅผ ๊ทน๋ณตํ๋ ค๊ณ ๊ณ ์๋ ๋ฐฉ์์ด์ง๋ง, ์ฌ์ ํ ํ์ต ๊ณผ์ ์์ ์๋๊ฐ ๋๋ฆฌ๊ณ , ๋๊ท๋ชจ ๋ฐ์ดํฐ์์ ํ์ฅ์ฑ(scalability)์ด ๋ถ์กฑํ๋ค๋ ๋ฌธ์ ๊ฐ ๋จ์ ์์ต๋๋ค.
1.2. Transformer ๋ชจ๋ธ์ ๋๋
์ต๊ทผ์๋ Transformer
๋ชจ๋ธ์ด ๋ฑ์ฅํ๋ฉด์ ์์ฐจ์ ๋ฐ์ดํฐ ์ฒ๋ฆฌ์์ ํฐ ์ฑ๊ณต์ ๊ฑฐ๋์์ต๋๋ค.
Transformer๋ ์ฃผ์(attention) ๋ฉ์ปค๋์ฆ์ ์ฌ์ฉํ์ฌ ์ํ์ค์ ๊ฐ ์์ ๊ฐ์ ์ํธ์์ฉ์ ์ง์ ๋ชจ๋ธ๋งํฉ๋๋ค. ์ด๋ก ์ธํด RNN๊ณผ ๋ฌ๋ฆฌ ๊ธฐ์ธ๊ธฐ ์์ค ๋ฌธ์ ๊ฐ ๋ฐ์ํ์ง ์์ผ๋ฉฐ, ๋ณ๋ ฌํ๊ฐ ๊ฐ๋ฅํด ๋๊ท๋ชจ ๋ฐ์ดํฐ ํ์ต์ ์ ๋ฆฌํฉ๋๋ค. ์ด๋ฌํ ์ฅ์ ๋๋ถ์ ์์ฐ์ด ์ฒ๋ฆฌ, ์ด๋ฏธ์ง ์ฒ๋ฆฌ ๋ฑ ๋ค์ํ ๋ถ์ผ์์ ํ์ํ ์ฑ๋ฅ์ ๋ฐํํ๊ณ ์์ต๋๋ค.
๊ทธ๋ฌ๋ Transformer์ ๊ฐ์ฅ ํฐ ๋ฌธ์ ๋ ๋ฉ๋ชจ๋ฆฌ ๋ฐ ๊ณ์ฐ ๋น์ฉ์ด ์ํ์ค ๊ธธ์ด์ ๋ฐ๋ผ Quadraticํ๊ฒ ์ฆ๊ฐํ๋ค๋ ์ ์ ๋๋ค. ์ํ์ค ๊ธธ์ด๊ฐ ๊ธธ์ด์ง์๋ก ๋ฉ๋ชจ๋ฆฌ์ ์ฐ์ฐ ๋น์ฉ์ด ๊ธ๊ฒฉํ ์ฆ๊ฐํ์ฌ, ๊ธด ์ํ์ค๋ฅผ ๋ค๋ฃจ๋ ๋ฐ ํจ์จ์ ์ด์ง ์์ต๋๋ค.
์ด์ ๋ฐํด, RNN์ ์ํ์ค ๊ธธ์ด์ ๋น๋กํ๋ Linearํ Cost๋ง์ ์๊ตฌํ๋ฏ๋ก, ๊ธด ์ํ์ค์์ ์ถ๋ก ํ ๋ ์ฌ์ ํ ๋ ๋น ๋ฆ ๋๋ค.
1.3. ์ํ๊ณต๊ฐ๋ชจ๋ธ(SSM)์ ๋ฑ์ฅ
์ด๋ฌํ Transformer์ ๋ฌธ์ ์ ์ ํด๊ฒฐํ๊ธฐ ์ํด Gu et al. (2021)์ด ์ ์ํ State Space Model(SSM)
์ ํ์ฉํ S4(Structured State Space Model)๋ชจ๋ธ์ด ์ฃผ๋ชฉ๋ฐ๊ธฐ ์์ํ์ต๋๋ค.
S4 ๋ชจ๋ธ์ Long Range Arena(LRA)๋ผ๋ ๊ธด ์ํ์ค๋ฅผ ๋ค๋ฃจ๋ ๋ฒค์น๋งํฌ์์ ๋ฐ์ด๋ ์ฑ๋ฅ์ ๋ณด์ฌ์ฃผ๋ฉฐ, ํนํ ๊ธด ์ํ์ค๋ฅผ ํจ์จ์ ์ผ๋ก ์ฒ๋ฆฌํ ์ ์๋ ๋ฅ๋ ฅ์ ๊ฐ์ง๊ณ ์์ต๋๋ค. SSM์ ์ํ์ค ๊ฐ์ ์ํธ์์ฉ์ ์์ฐจ์ ์ผ๋ก ์ฒ๋ฆฌํ๋ฉด์๋ ๋ณ๋ ฌ ์ฒ๋ฆฌ๊ฐ ๊ฐ๋ฅํ๊ณ , ์ถ๋ก ์๋๋ ๋น ๋ฆ ๋๋ค. (๋ค์ paper-review์์ ์ ๋ฆฌํด๋ณผ๊ฒ์!)
์ด๋ RNN๊ณผ ์ ์ฌํ ๋ฐฉ์์ผ๋ก ์๋ํ๋ฉด์๋, ํ์ต ์๋์ ์ฑ๋ฅ ๋ฉด์์ Transformer๋ณด๋ค ํจ์จ์ ์ ๋๋ค.
1.4. ์ฐ๊ตฌ ๋ชฉํ์ ํต์ฌ ๊ธฐ์ฌ
์๋ก ์ ๋ง์ง๋ง ๋ถ๋ถ์์๋ ์ด ๋ ผ๋ฌธ์ ์ฐ๊ตฌ ๋ชฉํ๋ฅผ ์ ์ํฉ๋๋ค. RNN๊ณผ SSM์ ์ฑ๋ฅ ์ฐจ์ด๋ฅผ ๋ถ์ํ๊ณ , RNN์ ๊ฐ์ ํ์ฌ SSM ์์ค์ ์ฑ๋ฅ์ ๋ณต์ํ๋ ๋ฐฉ๋ฒ์ ํ๊ตฌํ๋ ๊ฒ์ด ์ด ๋ ผ๋ฌธ์ ์ฃผ๋ ๋ชฉํ์ ๋๋ค. ์ด๋ฅผ ์ํด, RNN์ ๊ตฌ์กฐ๋ฅผ ์ธ๋ฐํ๊ฒ ์กฐ์ ํ๋ ์ฌ๋ฌ ๊ฐ์ง ๋ฐฉ๋ฒ์ ์ ์ํฉ๋๋ค. ์ด ๋ฐฉ๋ฒ๋ค์ ํตํด RNN๋ SSM๊ณผ ๋์ผํ ์ฑ๋ฅ์ ๋ฐํํ ์ ์์์ ๋ณด์ฌ์ฃผ๊ณ , ์ด๋ฌํ ์์ ์ด RNN์ ํ์ต ์๋์๋ ๊ธ์ ์ ์ธ ์ํฅ์ ๋ฏธ์น๋ค๋ ์ ์ ๊ฐ์กฐํฉ๋๋ค.
๋ ผ๋ฌธ์์๋ ์๋์ ๊ฐ์ ํ๋๋ฅผ ๋์ง๋ฉฐ ํด๋น ์ง๋ฌธ์ ๊ธ์ ์ (positive)๋ผ๊ณ ์ด์ผ๊ธฐํฉ๋๋ค.
์ฐ๊ตฌ์๋ค์ deepRNN์ ์ฌ์ฉํ์ฌ ๊น์ ์ฐ์ ์๊ฐ ์ํ ๊ณต๊ฐ ๋ชจ๋ธ(SSM)์ ์ฑ๋ฅ๊ณผ ํจ์จ์ฑ์ ์ผ์น์ํฌ ์ ์์์ ์ฃผ์ฅํฉ๋๋ค. ์ด๋ฅผ ๋ฌ์ฑํ๊ธฐ ์ํด ๋ ผ๋ฌธ์์ ์ ์ํ ๋ช ๊ฐ์ง ์ฃผ์ ์ ๊ทผ ๋ฐฉ์์ ์ค๋ช ํ๊ฒ ์ต๋๋ค.
- Linear Recurrences (์ ํ ์ฌ๊ท): ๊ธฐ์กด์ tanh ๋๋ ReLU ํ์ฑํ๋ฅผ ์ฌ์ฉํ๋ RNN ๊ณ์ธต ๋์ ๋น์ ํ์ฑ์ ์ ๊ฑฐํ๊ณ ์ ํ ์ฌ๊ท(์ ํ์ ์ผ๋ก ๋ฐ๋ณต๋๋ ๊ตฌ์กฐ)๋ฅผ ์ฌ์ฉํ์ฌ ์ฑ๋ฅ์ด ํฌ๊ฒ ํฅ์๋๋ค๋ ๊ฒ์ ๋ฐ๊ฒฌํ์ต๋๋ค. ์ฌ๊ธฐ์ ์ ํ์ ์ผ๋ก ๋ฐ๋ณต๋๋ ๊ตฌ์กฐ๋, ์๋ฅผ ๋ค์ด tanh๋ ReLU์ ๊ฐ์ ํ์ฑํ ํจ์๋ฅผ ์ฌ์ฉํ์ง ์๊ณ ๋จ์ํ ํ๋ ฌ ๊ณฑ์ ๊ณผ ๋ง์ ๋ง์ ํตํด ์ํ๋ฅผ ๊ฐฑ์ ํ๋ ๋ฐฉ์์ ๋๋ค. ์ ํ ์ฌ๊ท๋ฅผ ์ฌ์ฉํ๋ฉด gradient์ ์์ค ๋๋ ํญ๋ฐ์ ์ง์ ์ ์ดํ ์ ์์ผ๋ฉฐ, ๋ณ๋ ฌํ๋ ํ๋ จ์ด ๊ฐ๋ฅํด์ง๋๋ค.
-
Complex Diagonal Recurrent Matrices (๋ณต์์ ๋๊ฐ ์ฌ๊ท ํ๋ ฌ): ๋ฐ์ง ์ ํ RNN ๊ณ์ธต์ ๋ณต์์ ๋๊ฐ ํํ๋ก ์ฌ๊ตฌ์ฑํ๋ ๊ฒ์ผ๋ก๋ ์ฑ๋ฅ์ด ํฅ์๋ฉ๋๋ค. ๋ฐ์ง ์ ํ RNN ๋ ์ด์ด๋ ๋คํธ์ํฌ์ ํํ๋ ฅ์ ์์์ํค์ง ์์ผ๋ฉด์ ๋ณต์กํ ๋๊ฐ ํํ๋ก ์ฌ๋งค๊ฐ๋ณ์ํํ ์ ์์ต๋๋ค. ์ด๋ฅผ ํตํด ์ด๊ธฐํ์์์ ํน์ฑ๋ ์ ์ง๋ฉ๋๋ค. ๋๊ฐ ํ๋ ฌ์ ๋ฐ๋ณต์ ์ธ ๊ณผ์ ์ ๋ณ๋ ฌ๋ก ํ ์ ์๊ฒ ํด์ฃผ์ด ํ๋ จ ์๋๋ฅผ ํฌ๊ฒ ํฅ์์ํต๋๋ค.
- ๋๊ฐํ๋ ฌ๋ค์ ๊ณฑ์ ์ฐ์ฐ์์ ๊ฒฐํฉ๋ฒ์น์ ๋ง์กฑํ๊ธฐ ๋๋ฌธ์, ๊ฐ ์ฐ์ฐ์ ๋ณ๋ ฌํํ ์ ์์ต๋๋ค. ์ด๋ฅผ ํตํด RNN๊ณผ ๊ฐ์ ๋ชจ๋ธ์ ํ๋ จ ์๋๋ฅผ ํฌ๊ฒ ๊ฐ์ ํ ์ ์์ต๋๋ค. (Martin & Cundy, 2017).
- (์ฐธ๊ณ )
๋๊ฐ ํ๋ ฌ ๊ณฑ์ ์ ํน์ฑ
: ๋๊ฐ ํ๋ ฌ์ ๋น๋๊ฐ์ ์์๊ฐ ๋ชจ๋ 0์ด๊ธฐ ๋๋ฌธ์, ๊ณฑ์ ์ฐ์ฐ์ด ๊ฐ ๋๊ฐ์ ์์๋ผ๋ฆฌ๋ง ์ด๋ฃจ์ด์ง๋๋ค. ์ฆ, ๊ฐ๊ฐ์ ์์๊ฐ ๋ ๋ฆฝ์ ์ด๊ธฐ ๋๋ฌธ์, ๋ณ๋ ฌ๋ก ์ฐ์ฐ์ด ๊ฐ๋ฅํด์ง๋๋ค. - (์ฐธ๊ณ )
๊ฒฐํฉ๋ฒ์น
: ๋๊ฐ ํ๋ ฌ์ ๊ณฑ์ ์ ๊ฒฐํฉ๋ฒ์น์ ๋ง์กฑํฉ๋๋ค. ์ฆ, Aร(BรC)=(AรB)รC์ ๊ฐ์ ๋ฐฉ์์ผ๋ก ์ฐ์ฐ ์์์ ์๊ด์์ด ๋์ผํ ๊ฒฐ๊ณผ๋ฅผ ์ป์ ์ ์์ต๋๋ค. ์ด๋ ๋ณ๋ ฌ ์ฒ๋ฆฌ์ ๋งค์ฐ ์ ๋ฆฌํ ํน์ฑ์ ๋๋ค.
- Stable Exponential Parameterization (์์ ์ ์ธ ์ง์ ํ๋ผ๋ฏธํฐํ): ๋๊ฐ ์ฌ๊ท ํ๋ ฌ์ ๋ํ ์ง์ ํ๋ผ๋ฏธํฐํ๋ฅผ ์ฌ์ฉํ๋ฉด ์์ ์ฑ์ ๋ณด์ฅํ ์ ์์ผ๋ฉฐ, ์ด๋ก ์ธํด ์ด๊ธฐํ ๋ถํฌ๋ฅผ ์กฐ์ ํ์ฌ ์ฅ๊ธฐ์ ์ธ ์ถ๋ก ์ฑ๋ฅ์ ํฅ์์ํฌ ์ ์์ต๋๋ค. ์ด๊ธฐํ ์ ๊ณ ์ ๊ฐ ๋ถํฌ๊ฐ ์ฅ๊ธฐ ์ถ๋ก ์ ์บก์ฒํ๋ ๋ฐ ์ค์ํ ์ญํ ์ ํ๋ค๊ณ ๊ฐ์กฐํฉ๋๋ค.
- Normalization (์ ๊ทํ): ํ๋ จ ๊ณผ์ ์์์ ์จ๊ฒจ์ง ํ์ฑํ๋ฅผ ์ ๊ทํํ๋ ๊ฒ์ด ๋งค์ฐ ์ค์ํ๋ค๊ณ ์ธ๊ธํ์ต๋๋ค. ์ด๋ฅผ ํตํด RNN์ด LRA(Long Range Arena) ๋ฒค์น๋งํฌ์ ๋ชจ๋ Task์์ SSM์ ์ฑ๋ฅ๊ณผ ์คํ๋ ์ฑ๋ฅ์ ๋ผ ์ ์์์ต๋๋ค.
๋ณธ ๋ ผ๋ฌธ์ RNN์ ์ฑ๋ฅ์ ๋์ด๋ฆฌ๊ธฐ ์ํด Linear Recurrent Unit(LRU)๋ผ๋ ์๋ก์ด ๋ธ๋ก์ ์ ์ํ๋ฉฐ, LRU๊ฐ SSM๊ณผ ์ ์ฌํ ์ฑ๋ฅ๊ณผ ํจ์จ์ฑ์ ๊ฐ์ถ ์ ์์์ ์คํ์ ์ผ๋ก ์ ์ฆํฉ๋๋ค. ์ ์๋ LRU๋ Long Range Arena(LRA) ๋ฒค์น๋งํฌ์์ SSM๊ณผ ๋๋ฑํ ์ฑ๋ฅ์ ๋ณด์ด๋ฉฐ, ๋ณ๋ ฌ ํ์ต ์๋๋ ์ผ์นํฉ๋๋ค.
์ด ์ฅ์์๋ ์ ํต์ ์ธ RNN๊ณผ ์ต๊ทผ์ S4์ ๊ฐ์ deepSSM์ ์ฃผ์ ์ฐจ์ด์ ์ ์ค๋ช ํฉ๋๋ค.
RNN (Recurrent Neural Network)
RNN
์ ์์ฐจ์ ์ผ๋ก ๋ฐ์ดํฐ๋ฅผ ์ฒ๋ฆฌํ๋ ์ ๊ฒฝ๋ง ๊ตฌ์กฐ์ ๋๋ค. ์ฌ๊ธฐ์ ์ค์ํ ๊ฒ์ ์ด์ ๋จ๊ณ์ ์ํ ์ ๋ณด๊ฐ ํ์ฌ ์ํ๋ฅผ ๊ณ์ฐํ๋ ๋ฐ ์ฌ์ฉ๋๋ค๋ ์ ์ ๋๋ค.
์ด๋ ๋ค์ ์์์ผ๋ก ํํ๋ฉ๋๋ค:
xk=ฯ(Axkโ1+Buk),yk=Cxk+Dukx_k = \sigma(Ax_{k-1} + Bu_k), \quad y_k = Cx_k + D u_kxkโ=ฯ(Axkโ1โ+Bukโ),ykโ=Cxkโ+Dukโ
์ฌ๊ธฐ์:
- AAA, BBB, CCC, DDD๋ ํ์ต ๊ฐ๋ฅํ ๋งค๊ฐ๋ณ์์ ๋๋ค.
- ฯ\sigmaฯ๋ ๋น์ ํ ํ์ฑํ ํจ์๋ก ๋ณดํต tanh๋ sigmoid๊ฐ ์ฌ์ฉ๋ฉ๋๋ค. ๋ง์ฝ ฯ\sigmaฯ๊ฐ ํญ๋ฑ ํจ์(identity function)๋ผ๋ฉด RNN์ ์ ํ RNN์ผ๋ก ๊ฐ์ฃผ๋ฉ๋๋ค.
- xkx_kxkโ๋ kkk๋ฒ์งธ ์์ ์์์ ํ๋ ์ํ(hidden state)์ด๊ณ , yky_kykโ๋ ์ถ๋ ฅ์ ๋๋ค.
์ผ๋ฐ์ ์ผ๋ก tanh๋ ReLU ๊ฐ์ ๋น์ ํ ํ์ฑํ ํจ์๋ฅผ ์ฌ์ฉํ์ฌ ์ํ ์ ์ด๋ฅผ ์ด๋ฃจ๋ฉฐ, ๋งค๋ฒ ์ด์ ์ํ์์ ์๋ก์ด ์ ๋ ฅ์ ๋ฐ์ ๋ฐ๋ณต์ ์ผ๋ก ์ํ๋ฅผ ์ ๋ฐ์ดํธํฉ๋๋ค.
- ํ์ง๋ง RNN์ ๋น์ ํ์ฑ์ ๊ฐ๊ณ ์์ด ํ์ต ๊ณผ์ ์์ ๊ธฐ์ธ๊ธฐ ์์ค ๋ฌธ์ ๋ก ์ธํด ํ์ต์ด ์ด๋ ค์ธ ์ ์์ต๋๋ค.
SSM (State Space Model)
SSM
์ ์ฐ์ ์๊ฐ ์์คํ ์ ๊ธฐ๋ฐ์ผ๋ก ํ ๋ชจ๋ธ์ ๋๋ค. ์ฐ์์ ์ธ ์๊ฐ ์ถ์์ ์ํ ๋ณํ๋ฅผ ๋ชจ๋ธ๋งํ๊ณ , ์ด๋ฅผ ์ด์ฐํํ์ฌ ์ํ์ค ๋ฐ์ดํฐ๋ฅผ ์ฒ๋ฆฌํ ์ ์์ต๋๋ค.
์ด๋ ์ํ์ ์ผ๋ก๋ ๋ค์๊ณผ ๊ฐ์ด ์ค๋ช ๋ฉ๋๋ค:
ddtxct(t)=A~xct(t)+B~uct(t)\frac{d}{dt}x_{ct}(t) = Aฬ x_{ct}(t) + Bฬ u_{ct}(t)dtdโxctโ(t)=A~xctโ(t)+B~uctโ(t) yct(t)=R(C~xct(t))+D~uct(t)y_{ct}(t) = \mathcal{R}(Cฬ x_{ct}(t)) + Dฬ u_{ct}(t)yctโ(t)=R(C~xctโ(t))+D~uctโ(t)
์ด ๋ชจ๋ธ์์ ์ ๋ ฅ ์ ํธ uctu_{ct}uctโ๋ ์ฐ์ ์๊ฐ์์ ์ํ๋ง๋ ์๊ทธ๋๋ก ๊ฐ์ฃผ๋ฉ๋๋ค. ์ด ์ํ ๊ณต๊ฐ ๋ชจ๋ธ์ ๊ธด ์ํ์ค์์๋ ์์ ์ ์ธ ๊ณ์ฐ์ ๋ณด์ฅํ๋ฉฐ, ํนํ ๋ณต์กํ ์ํธ์์ฉ์ ํจ์จ์ ์ผ๋ก ์ฒ๋ฆฌํ ์ ์์ต๋๋ค.
S4 (Structured State Space Sequence Model)
S4
๋ ์์ SSM์ ๊ธฐ๋ฐ์ผ๋ก ์ค๊ณ๋ ๋ฅ๋ฌ๋ ๋ชจ๋ธ๋ก, ๊ธด ์ํ์ค ๋ชจ๋ธ๋ง์์ ๋ฐ์ด๋ ์ฑ๋ฅ์ ๋ฐํํฉ๋๋ค. ์ด ๋ชจ๋ธ์ ๋ณต์์ ํ๋ ฌ์ ์ฌ์ฉํ์ฌ RNN๋ณด๋ค ๋ ํจ์จ์ ์ผ๋ก ํ์ตํ ์ ์์ผ๋ฉฐ, ๋ณ๋ ฌํ๊ฐ ๊ฐ๋ฅํ๋ค๋ ์ฅ์ ์ด ์์ต๋๋ค.
์ด๋ฏธ์ง ์ถ์ฒ : A Visual Guide to Mamba and State Space Models (๋งํฌ)
S4๋ ์ํ ๊ณต๊ฐ ๋ชจ๋ธ์ ๊ธฐ๋ฐ์ผ๋ก ํ ๋งค์ฐ ํจ์จ์ ์ธ ๋ชจ๋ธ์ ๋๋ค. S4๋ ๋ณต์กํ ๊ธด ์ํ์ค๋ฅผ ์ฒ๋ฆฌํ ์ ์๋ ๋ฅ๋ ฅ์ผ๋ก ์ฃผ๋ชฉ๋ฐ๊ณ ์์ผ๋ฉฐ, ๋ค์๊ณผ ๊ฐ์ ํน์ง์ ๊ฐ์ง๊ณ ์์ต๋๋ค:
- ์ฐ์ ์๊ฐ ๊ธฐ๋ฐ์ ์ํ ๊ณต๊ฐ ๋ชจ๋ธ: S4๋ ์ํ ๊ณต๊ฐ ๋ชจ๋ธ์ ์ฌ์ฉํ์ฌ ์ฐ์์ ์ธ ์ ํธ๋ฅผ ์ฒ๋ฆฌํ๊ณ ์ด๋ฅผ ์ด์ฐ ์๊ฐ์ผ๋ก ๋ณํํฉ๋๋ค.
- ๋ณ๋ ฌ ์ฒ๋ฆฌ ๊ฐ๋ฅ: S4๋ ์ ํ ์์คํ ์ ๊ธฐ๋ฐ์ผ๋ก ํ์ฌ RNN๊ณผ ๋ฌ๋ฆฌ ํ์ต ๊ณผ์ ์์ ๋ณ๋ ฌ ์ฒ๋ฆฌ๊ฐ ๊ฐ๋ฅํฉ๋๋ค.
- HiPPO ์ด๋ก ์ ๊ธฐ๋ฐ์ผ๋ก ํ ์ด๊ธฐํ: S4๋ HiPPO ์ด๋ก ์ ๋ฐํ์ผ๋ก ๋ณต์กํ ์ด๊ธฐํ ๋ฐฉ๋ฒ์ ์ฌ์ฉํ์ฌ ๋ชจ๋ธ์ ์ฑ๋ฅ์ ๊ทน๋ํํฉ๋๋ค. ์ด ์ด๊ธฐํ ๊ณผ์ ์ SSM์ด ๊ธด ์ํ์ค๋ฅผ ์ฒ๋ฆฌํ ๋ ์ค์ํ ์ญํ ์ ํฉ๋๋ค.
S4๋ RNN๊ณผ ๋น๊ตํ์ ๋ ๊ธด ์ํ์ค๋ฅผ ๋ ํจ์จ์ ์ผ๋ก ์ฒ๋ฆฌํ๋ฉฐ, RNN์ ๋ณ๋ชฉ ํ์(Sequential Processing)์ ๊ทน๋ณตํ ์ ์๋ ์ฑ๋ฅ์ ๋ณด์ฌ์ค๋๋ค.
์ด ์ฅ์์๋ ๋ ผ๋ฌธ์์ ์ ์ํ๋ ์ฑ๋ฅ์ด ๋ฐ์ด๋ ๊น์ RNN(Deep RNN)์ ์ค๊ณํ๊ธฐ ์ํ ์ฃผ์ ๋จ๊ณ๋ฅผ ์ค๋ช ํฉ๋๋ค. ์ฐ๊ตฌ์ง์ SSM(์ํ ๊ณต๊ฐ ๋ชจ๋ธ)์ ๋ฐ์ด๋ ์ฑ๋ฅ์ ์ฌํํ๊ณ ์ RNN์ ๊ตฌ์กฐ์ ๋ณํ์ ํตํด SSM๊ณผ ๋น์ทํ ์ฑ๋ฅ์ ๋ฌ์ฑํ ์ ์์์ ๋ณด์ฌ์ค๋๋ค.
๋จผ์ ์์์ ๋ณด์ฌ๋๋ฆฐ ๊ทธ๋ฆผ์ ํ๊ตฌํด๋ด ์๋ค.
(Left) Deep Linear Recurrent Unit (LRU) Architecture
-
์ผ์ชฝ ๊ทธ๋ฆผ์ LRU ์ ์ธ๋ถ ๊ตฌ์กฐ๋ฅผ ๋ณด์ฌ์ค๋๋ค.
-
Linear Encoder (์ ํ ์ธ์ฝ๋) :
-
์ ๋ ฅ ์ํ์ค๋ฅผ ์ ํ ์ธ์ฝ๋์ ํต๊ณผ์ํต๋๋ค.
- ์ด ์ธ์ฝ๋๋ ๋ชจ๋ ํ์์คํ ์ ๋ํด ๋์ผํ๊ฒ ์ ์ฉ๋๋ฉฐ, ์ ๋ ฅ ๋ฐ์ดํฐ๋ฅผ ์ ์ ํ ํํ๋ก ๋ณํํฉ๋๋ค.
-
์ฐจ์ ์ถ์ ๋๋ ๋ณํ์ ํตํด ๋ชจ๋ธ์ด ์ฒ๋ฆฌํ ์ ์๋ ํํ๋ก ๋ฐ์ดํฐ๋ฅผ ์ ๋ฌํฉ๋๋ค.
-
-
Linear Recurrent Unit (LRU) :
-
LRU ๋ธ๋ก ์ ์ฌ๋ฌ ๊ฐ์ ์ธต์ผ๋ก ์์ธ ๊ตฌ์กฐ๋ฅผ ๊ฐ์ง๋ฉฐ, ๊ฐ ์ธต ์ฌ์ด์๋ MLP/GLU ๋ธ๋ก์ด ์ฝ์ ๋์ด ๋น์ ํ์ฑ์ ์ถ๊ฐํฉ๋๋ค.
- ์ฌ๊ธฐ์ MLP๋ Multi-Layer Perceptron ์ ์๋ฏธํ๊ณ , GLU๋ Gated Linear Unit ์ ์๋ฏธํฉ๋๋ค.
-
LRU๋ RNN์ ๊ธฐ๋ณธ์ ์ธ ์ํ ์ ๊ฒฝ๋ง ๊ตฌ์กฐ๋ฅผ ๋ฐ๋ฅด๊ณ ์์ง๋ง, ๋ช ๊ฐ์ง ์ค์ํ ๋ณํ์ ํตํด ์ฑ๋ฅ์ ๊ฐ์ ํ ๋ฒ์ ์ ๋๋ค.
- ํนํ, ๋น์ ํ์ฑ์ ์ ๊ฑฐํ๊ณ , ๋ณต์์ ๋๊ฐ ํ๋ ฌ์ ์ฌ์ฉํ๋ ๋ฑ, RNN์ ํจ์จ์ฑ์ ๋์ด๊ธฐ ์ํ ์ฌ๋ฌ ๊ฐ์ ์ฌํญ์ ํฌํจํ๊ณ ์์ต๋๋ค.
- ํ์ง๋ง ์ํ์ ์ธ ์ํ ์ ์ด์ ์ํ์ค ์ฒ๋ฆฌ๋ผ๋ RNN์ ํต์ฌ ํน์ฑ์ ์ ์งํ๊ณ ์๊ธฐ ๋๋ฌธ์ LRU๋ RNN์ ํ ํํ๋ผ๊ณ ํ ์ ์์ต๋๋ค.
- LRU์ ํต์ฌ์ ์ ํ ๋ฐ๋ณต ์ ์ฌ์ฉํ๋ฉฐ, ์ด๋ฅผ ํตํด RNN์ ๋น์ ํ์ฑ์ ์ ๊ฑฐํ๋ฉด์๋ ๊ฐ๋ ฅํ ์ฑ๋ฅ์ ์ ์งํ ์ ์์ต๋๋ค. ์ด ๊ตฌ์กฐ๋ ๋ณ๋ ฌํ๊ฐ ๊ฐ๋ฅํ๋ฉฐ, ํ์ต ์๋๋ฅผ ํฅ์์ํต๋๋ค.
-
LRU์ ๋ฐ๋ณต ์์
์ ์๋์ ๊ฐ์ต๋๋ค:xk=diag(ฮป)xkโ1+ฮณโBukx_k = \text{diag}(\lambda) x_{k-1} + \gamma \odot B u_kxkโ=diag(ฮป)xkโ1โ+ฮณโBukโ
- diag(ฮป)diag(ฮป)diag(ฮป) : ๋ณต์์ ๋๊ฐ ํ๋ ฌ๋ก, ์ด๋ ์ํ ์ ์ด๋ฅผ ๋ ํจ์จ์ ์ผ๋ก ๊ณ์ฐํ ์ ์๊ฒ ํด์ค๋๋ค.
- ฮณฮณฮณ : ํ์ต ๊ฐ๋ฅํ ์ ๊ทํ ํ๋ผ๋ฏธํฐ๋ก, ์ด๋ ๊ฐ ์ํ์ค ํ์์คํ ์์ ์ํ ์ ๋ณด๋ฅผ ์กฐ์ ํ๋ ์ญํ ์ ํฉ๋๋ค.
- BukB u_kBukโ : ์ ๋ ฅ ๋ฐ์ดํฐ uku_kukโ ์ ๋ํด ๊ฐ์ค์น BBB ๋ฅผ ๊ณฑํ์ฌ ๋ค์ ์ํ๋ก ์ ๋ฌํฉ๋๋ค.
-
์ธ๋ถ ํ๋ผ๋ฏธํฐ๋ฅผ ์ข ๋ ์ดํด๋ณด๊ฒ ์ต๋๋ค.
-
Normalization (์ ๊ทํ, ฮณj\gamma_jฮณjโ) :
- LRU ๋ด๋ถ์์๋ Pre-Layer Normalization ๋๋ Batch Normalization ์ด ์ ์ฉ๋์ด ๊ธด ์ํ์ค ํ์ต ์ค ๋ฐ์ํ ์ ์๋ ๊ธฐ์ธ๊ธฐ ์์ค ๋ฌธ์ ๋ฅผ ์ํํฉ๋๋ค.
-
์ ๊ทํ๋ ๊ฐ ์ธต์์ ๋ฐ์ํ๋ ์๋ ์ํ์ ์ค์ผ์ผ์ ์กฐ์ ํ์ฌ ํ์ต์ ์์ ํ์ํต๋๋ค.
ฮณj=(1โโฃฮปjโฃ2)1/2\gamma_j = (1 - ฮป_j ^2)^{1/2}ฮณjโ=(1โโฃฮปjโโฃ2)1/2
-
Stable Exponential Parameterization (์์ ์ ์ธ ์ง์ ๋งค๊ฐ๋ณ์ํ, ฮปjฮป_jฮปjโ)
- RNN์ ๋ฐ๋ณต ํ๋ ฌ์ ๋งค๊ฐ๋ณ์ํํ ๋, ํ์ต ๊ณผ์ ์์ ์์ ์ฑ์ ๋ณด์ฅํ๊ธฐ ์ํด ์ํ๋ฉ๋๋ค.
-
์ด๋ ๋ชจ๋ธ์ด ํ์ตํ๋ ๋์ ๊ธฐ์ธ๊ธฐ ์์ค(vanishing gradient) ๋๋ ๊ธฐ์ธ๊ธฐ ํญ๋ฐ(exploding gradient)์ ๋ฐฉ์งํ์ฌ, ํนํ ๊ธด ์ํ์ค์์ ์ค๋ฒํ๋ก์ฐ ๋๋ ์ธ๋ํ๋ก์ฐ ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ๋ ๋ฐ ์ค์ํ ์ญํ ์ ํฉ๋๋ค.
ฮปj=exp(โexp(ฮฝjlog)+iexp(ฮธjlog))ฮป_j = exp(-exp(ฮฝ_j^log) + i exp(ฮธ_j^log))ฮปjโ=exp(โexp(ฮฝjlโog)+iexp(ฮธjlโog))
- ฮป๋ ๋ณต์์ ๊ณ ์ ๊ฐ์ด๋ฉฐ, ์ด๋ ํ์ต ์ค์ ๋งค๊ฐ๋ณ์ํ๋์ด ์์ ์ฑ์ ๋ณด์ฅํฉ๋๋ค.
- ๊ณ ์ ๊ฐ์ ํฌ๊ธฐ์ ์์์ ๊ฐ๊ฐ ๋งค๊ฐ๋ณ์ํํ์ฌ ํ์ต ์ฑ๋ฅ์ ๋์ ๋๋ค.
-
-
-
Skip Connection :
- ๊ฐ ์ธต ์ฌ์ด์๋ Skip Connection ์ด ํฌํจ๋์ด ์์ต๋๋ค. ์ด๋ ๋ชจ๋ธ์ด ๊น์ด์ง์๋ก ๋ฐ์ํ ์ ์๋ ์ ๋ณด ์์ค์ ๋ฐฉ์งํ๊ณ , ๋ ๊น์ ๋คํธ์ํฌ์์ ํจ๊ณผ์ ์ธ ํ์ต์ ๊ฐ๋ฅํ๊ฒ ํฉ๋๋ค.
-
Linear Layer (์ ํ ์ถ๋ ฅ ๋ ์ด์ด) :
- ๋ง์ง๋ง์ผ๋ก, ํ์์คํ ๊ณผ ๊ด๋ จ๋ ์ถ๋ ฅ์ ์ํด ์ ํ ๋ ์ด์ด ๊ฐ ์ฌ์ฉ๋ฉ๋๋ค. ์ด๋ ๋ชจ๋ธ์ ์ต์ข ์ถ๋ ฅ์ผ๋ก ์ด์ด์ง๋ฉฐ, ํด๋์ค ์์ธก์ด๋ ๊ธฐํ ์ํ์ค ๋ชจ๋ธ๋ง ์์ ์ ์ฌ์ฉ๋ฉ๋๋ค.
(Right) Test accuracy on LRA tasks
์ค๋ฅธ์ชฝ ๊ทธ๋ํ๋ Long Range Arena (LRA) ๋ฒค์น๋งํฌ์์ ์ํํ ํ ์คํธ ๊ฒฐ๊ณผ๋ฅผ ์๊ฐ์ ์ผ๋ก ๋ณด์ฌ์ค๋๋ค. ์ด ๊ทธ๋ํ๋ RNN ๊ตฌ์กฐ์์ tanh ํ์ฑํ ํจ์๋ฅผ ์ฌ์ฉํ๋ ๊ธฐ๋ณธ RNN์์ ์์ํ์ฌ, LRU ๊ตฌ์กฐ๋ก ๋ณ๊ฒฝํด๊ฐ๋ฉด์ ์ฑ๋ฅ์ด ์ด๋ป๊ฒ ๋ณํํ๋์ง๋ฅผ ๋ํ๋ ๋๋ค.
-
Recurrent Block Variants (RNN ๋ธ๋ก ๋ณํ) :
- X์ถ์ ๋ฐ๋ณต ๋ชจ๋์ ์ฌ๋ฌ ๋ณํ์ ๋ํ๋ ๋๋ค. ์ด๊ธฐ tanh ํ์ฑํ RNN์์ ์์ํ์ฌ, ์ ํ ํ์ฑํ , ๋๊ฐ์ ํ(๋ณต์์ ๋๊ฐ ํ๋ ฌ) , ์์ ์ ์ธ ์ง์ ๋งค๊ฐ๋ณ์ํ , ๊ทธ๋ฆฌ๊ณ ์ต์ข ์ ์ผ๋ก ฮณ ์ ๊ทํ ๋ฅผ ํฌํจํ๋ LRU๊น์ง ๋ฐ์ ํฉ๋๋ค.
- ๊ฐ ๋จ๊ณ์์ RNN์ด ์ด๋ป๊ฒ ๊ฐ์ ๋๊ณ ์๋์ง ํ์ธํ ์ ์์ต๋๋ค.
-
Efficiency Boost (ํจ์จ์ฑ ํฅ์) :
- ๊ทธ๋ํ์์ ์ฃผ๋ชฉํ ์ ์ ์ ํ ๋๊ฐ ํ๋ ฌ๊ณผ ์์ ์ ์ธ ์ง์ ๋งค๊ฐ๋ณ์ํ๋ฅผ ๋์ ํ ์์ ์์ ์ฑ๋ฅ์ด ํฌ๊ฒ ํฅ์๋์๋ค๋ ๊ฒ์ ๋๋ค. ์ด๋ ๋ณ๋ ฌํ ๋ฐ ํจ์จ์ ์ธ ๊ณ์ฐ์ ํตํด ํ์ต ์๋์ ์ฑ๋ฅ์ด ๋์์ ๊ฐ์ ๋ ๊ฒ์ ๋ํ๋ ๋๋ค.
-
Performance on LRA tasks (LRA ๊ณผ์ ์์์ ์ฑ๋ฅ) :
-
๊ทธ๋ํ์ Y์ถ์ ํ ์คํธ ์ ํ๋๋ฅผ ๋ํ๋ด๋ฉฐ, ๊ฐ ์์๊ณผ ๊ธฐํธ๋ ์๋ก ๋ค๋ฅธ ๊ณผ์ ๋ฅผ ์๋ฏธํฉ๋๋ค.
- sCIFAR (์ค๋ ์ง์ ์), ListOps (์ด๋ก์ ์ผ๊ฐํ), PathFinder (๊ฐ์ ์ฌ๊ฐํ), PathX (๋ ธ๋์ ๋ค์ด์๋ชฌ๋)
-
S4์ S5 ๋ชจ๋ธ์ ์ฑ๋ฅ์ ๋์ ์ ์ผ๋ก ๋ํ๋ด์ด, ๊ฐ ๊ณผ์ ์์ LRU๊ฐ S4/S5์ ๋๋ฑํ ์ฑ๋ฅ์ ๋ฐํํ๊ณ ์์์ ๋ณด์ฌ์ค๋๋ค.
-
๊ฐ 3. Designing Performant Deep RNNs์ ๊ฐ ์ ์์๋ ๋ ผ๋ฌธ์์ ์ ์ํ๋ Linear Recurrent Unit (LRU)์ ํต์ฌ์ ์ธ ์ค๊ณ ์์๋ค์ ์ค๋ช ํฉ๋๋ค.
3.1 Linear RNN layers are performant (์ ํ RNN ๋ ์ด์ด๋ ์ฑ๋ฅ์ด ๋ฐ์ด๋จ)
์ด ์ ์์ ์ฐ๊ตฌ์ง์ RNN์ ๋น์ ํ์ฑ์ ์ ๊ฑฐํ๋ ๊ฒ์ด ์ฑ๋ฅ ํฅ์์ ์ด๋ป๊ฒ ๊ธฐ์ฌํ๋์ง ์ค๋ช ํฉ๋๋ค. RNN์ ์ ํต์ ์ธ ๊ตฌ์กฐ๋ tanh๋ ReLU์ ๊ฐ์ ๋น์ ํ ํ์ฑํ ํจ์๊ฐ ํฌํจ๋์ด ์์ง๋ง, ๋ ผ๋ฌธ์์๋ ์ด๋ฌํ ๋น์ ํ์ฑ์ ์ ๊ฑฐํ ์ ํ RNN (Linear RNN)์ด ๋งค์ฐ ์ข์ ์ฑ๋ฅ์ ๋ฐํํ ์ ์๋ค๋ ์ ์ ๋ฐ๊ฒฌํ์ต๋๋ค.
- ๋ฐ๊ฒฌ๋ ์ฑ๋ฅ ๊ฐ์ : ์คํ ๊ฒฐ๊ณผ์ ๋ฐ๋ฅด๋ฉด RNN์ ๋น์ ํ์ฑ์ ์ ๊ฑฐํ ํ, Long Range Arena(LRA) ๋ฒค์น๋งํฌ์์ ์ฑ๋ฅ์ด ํฌ๊ฒ ๊ฐ์ ๋์์ต๋๋ค. ํนํ, ํ ์คํธ ์ฒ๋ฆฌ๋ ์ ๋ณด ๊ฒ์๊ณผ ๊ฐ์ ํน์ ๊ณผ์ ์์๋ ๋น์ ํ RNN๋ณด๋ค๋ ๋ ๋์ ์ฑ๋ฅ์ ๋ณด์์ต๋๋ค.
- ๋น์ ํ์ฑ ์ ๊ฑฐ์ ํจ๊ณผ: ๋น์ ํ์ฑ์ ์ ๊ฑฐํ๋ฉด RNN์ด ํจ์ฌ ๋ ํจ์จ์ ์ผ๋ก ์๋ํ๋ฉฐ, ํนํ ๊ธด ์ํ์ค์์์ ๊ธฐ์ธ๊ธฐ ์์ค(vanishing gradient) ๋ฌธ์ ๋ฅผ ์ํํ ์ ์์ต๋๋ค. ๋ํ, ๋น์ ํ์ฑ์ด ์๋ ์ํ์์ RNN์ ์์ ์ฌ๋ฆฌ๋ ๊ฒ์ด ๋น์ ํ์ฑ์ ์ ์งํ๋ ๊ฒฝ์ฐ๋ณด๋ค ๋ ์ฝ๊ฒ ๋ณ๋ ฌํํ ์ ์์ต๋๋ค.
์ด ์น์ ์์์ ์ฃผ์ ๊ฒฐ๋ก ์, RNN์์ ๋น์ ํ์ฑ์ ์ ๊ฑฐํ๋ ๊ฒ์ด ์ฑ๋ฅ์ ๋ถ์ ์ ์ธ ์ํฅ์ ๋ฏธ์น์ง ์์ผ๋ฉฐ, ์คํ๋ ค ์ ํ ๋ฐ๋ณต(Linear Recurrence)์ด ๋ณต์กํ ์ํ์ค-์ํ์ค ๋งต์ ๋ชจ๋ธ๋งํ ์ ์๋ค๋ ์ ์ ๋๋ค. ์ด๋ SSM์ด ์ ํ ๋ฐ๋ณต์ ํตํด ์ข์ ์ฑ๋ฅ์ ๋ด๋ ์ด์ ์๋ ๊ด๋ จ์ด ์์ต๋๋ค.
3.2 Using complex diagonal recurrent matrices is efficient (๋ณต์ ๋๊ฐ ๋ฐ๋ณต ํ๋ ฌ ์ฌ์ฉ์ ํจ์จ์ฑ)
๋ค์ ๋จ๊ณ์์๋ RNN์ ๋ฐ๋ณต ํ๋ ฌ์ ๋ณต์์ ๋๊ฐ ํ๋ ฌ(complex diagonal recurrent matrices)๋ก ์ฌ๊ตฌ์ฑํ์ฌ ํจ์จ์ฑ์ ํฌ๊ฒ ๊ฐ์ ํ๋ ๋ฐฉ๋ฒ์ ์ค๋ช ํฉ๋๋ค.
- ๋ณต์์ ๋๊ฐ ํ๋ ฌ์ ์ฅ์ : RNN์ ์ ํ ๋ ์ด์ด๋ฅผ ๋ณต์์ ๋๊ฐ ํ๋ ฌ๋ก ๋ณํํ๋ฉด, ์ด ํ๋ ฌ์ ๋ ์ฝ๊ฒ ๋ณ๋ ฌํํ ์ ์์ต๋๋ค. ๋ณต์ ๋๊ฐ ํ๋ ฌ์ ํ๋ ฌ์ ๊ณ ์ ๊ฐ ๋ถํด(eigen decomposition)๋ฅผ ํตํด ์ฝ๊ฒ ์ฒ๋ฆฌํ ์ ์์ผ๋ฉฐ, ๊ฐ ํ์์คํ ์์์ ๊ณ์ฐ์ ๋ณ๋ ฌ๋ก ์ํํ ์ ์์ต๋๋ค. ์ด๋ ํ์ต ์๋๋ฅผ ํฌ๊ฒ ๋์ด๋ฉฐ, ํนํ ๊ธด ์ํ์ค๋ฅผ ์ฒ๋ฆฌํ๋ ๊ฒฝ์ฐ ์ ๋ฆฌํฉ๋๋ค.
- ๋ณ๋ ฌํ์ ์๋ ํฅ์: ๋ณต์์ ๋๊ฐ ํ๋ ฌ์ ์ฌ์ฉํ๋ฉด, ์ ํ RNN์ ํ์ต๊ณผ ์ถ๋ก ์ ๋ณ๋ ฌํํ ์ ์์ด ๊ณ์ฐ ๋น์ฉ์ ์ค์ผ ์ ์์ต๋๋ค. ๊ธฐ์กด์ ๋น์ ํ RNN์ ํ์ต ์ ์์ฐจ์ ์ผ๋ก ๊ณ์ฐ์ด ์ด๋ฃจ์ด์ ธ์ผ ํ์ง๋ง, ๋ณต์์ ๋๊ฐ ํ๋ ฌ์ ์ฌ์ฉํ๋ฉด ๋ณ๋ ฌ ๊ณ์ฐ์ด ๊ฐ๋ฅํ๋ฏ๋ก ํ์ต ์๋์ ์ถ๋ก ์๋๊ฐ ํฌ๊ฒ ํฅ์๋ฉ๋๋ค.
๋ ผ๋ฌธ์์๋ ์ด ๊ณผ์ ์ด SSM์์๋ ์์ฃผ ์ฌ์ฉ๋๋ ๊ธฐ์ ์์ ๊ฐ์กฐํ๋ฉฐ, ๋ณต์ ๋๊ฐ ํ๋ ฌ์ด ๋ณต์กํ ์ํ์ค ๋ชจ๋ธ๋ง์์ ํจ์จ์ฑ์ ๋์ด๋ ๋ฐ ํฐ ์ญํ ์ ํ๋ค๊ณ ์ค๋ช ํฉ๋๋ค.
3.3 Stable Exponential Parameterization (์์ ์ ์ธ ์ง์ ๋งค๊ฐ๋ณ์ํ)
๋ค์์ผ๋ก ๋ ผ๋ฌธ์์๋ ์ง์ ๋งค๊ฐ๋ณ์ํ(exponential parameterization)๋ฅผ ์ฌ์ฉํ์ฌ ๋ฐ๋ณต ํ๋ ฌ์ ์์ ์ฑ(stability)์ ๋ณด์ฅํ๋ ๋ฐฉ๋ฒ์ ์ค๋ช ํฉ๋๋ค.
- ์ง์ ๋งค๊ฐ๋ณ์ํ์ ์ด์ : ๋ฐ๋ณต ํ๋ ฌ์ ์ง์ ํจ์๋ก ๋งค๊ฐ๋ณ์ํํ๋ฉด, ํ์ต ์ค ๋ชจ๋ธ์ ์์ ์ฑ์ ์ ์งํ ์ ์์ต๋๋ค. ์ด๋ ํนํ ๊ธด ์ํ์ค์์ ๋ฐ๋ณต์ ์ผ๋ก ์ํ๋ฅผ ์ ๋ฐ์ดํธํ ๋, ๊ธฐ์ธ๊ธฐ์ ์์ค์ด๋ ํญ๋ฐ ๋ฌธ์ ๋ฅผ ๋ฐฉ์งํ๋ ๋ฐ ๋งค์ฐ ์ ์ฉํฉ๋๋ค.
- ๊ณ ์ ๊ฐ ๋ถํฌ์ ์ฑ๋ฅ ํฅ์: ์คํ ๊ฒฐ๊ณผ์ ๋ฐ๋ฅด๋ฉด, ๋ฐ๋ณต ํ๋ ฌ์ ๊ณ ์ ๊ฐ ๋ถํฌ๋ฅผ ์ ์ ํ ์ด๊ธฐํํ๋ฉด ๋ชจ๋ธ์ด ๊ธด ์ํ์ค ์์กด์ฑ์ ๋ ์ ํ์ตํ ์ ์์ผ๋ฉฐ, ์ด๋ SSM์์ ์ฑ๋ฅ์ด ๋ฐ์ด๋ ์ด์ ์ค ํ๋์ ๋๋ค. SSM์์ ๋ณต์กํ ์ด๊ธฐํ ๊ท์น์ ์ฌ์ฉํ๋ ๋์ , ๊ณ ์ ๊ฐ์ ๋ถํฌ๋ฅผ ์กฐ์ ํ์ฌ ํ์ต ์ฑ๋ฅ์ ๋์ผ ์ ์๋ค๋ ์ ์ ๋ณด์ฌ์ค๋๋ค.
3.4 Normalization (์ ๊ทํ)
๋ง์ง๋ง์ผ๋ก ๋ ผ๋ฌธ์์๋ ์ ๊ทํ(Normalization)์ ์ค์์ฑ์ ์ค๋ช ํฉ๋๋ค. ๊ธด ์ํ์ค ๋ฐ์ดํฐ๋ฅผ ํ์ตํ ๋, RNN์ ์๋ ์ํ๋ฅผ ์ ์ ํ๊ฒ ์ ๊ทํํ๋ ๊ฒ์ด ๋งค์ฐ ์ค์ํฉ๋๋ค.
- ์ ๊ทํ์ ์ญํ : RNN์ ์๋ ์ํ๋ฅผ ์์ฐจ์ ์ผ๋ก ์ ๋ฐ์ดํธํ๋ฉฐ ํ์ต์ ์งํํ๋๋ฐ, ๊ธด ์ํ์ค๋ฅผ ์ฒ๋ฆฌํ๋ ๊ณผ์ ์์ ์๋ ์ํ๊ฐ ๊ณผ๋ํ๊ฒ ์ปค์ง๊ฑฐ๋ ์์์ง ์ ์์ต๋๋ค. ์ด๋ก ์ธํด ๋ชจ๋ธ์ด ํ์ตํ๊ธฐ ์ด๋ ค์์ง ์ ์์ผ๋ฏ๋ก, forward pass์์ ์๋ ์ํ๋ฅผ ์ ๊ทํํ๋ ๊ฒ์ด ํ์์ ์ ๋๋ค.
- ์ ๊ทํ์ ์ฑ๋ฅ ํฅ์: ์ ๊ทํ๋ฅผ ์ ์ ํ ์ ์ฉํ๋ฉด ๊ธด ์ํ์ค์์ RNN์ด ๋ ์์ ์ ์ผ๋ก ํ์ตํ ์ ์์ผ๋ฉฐ, ์ด๋ ๋ ผ๋ฌธ์์ ์ ์ํ๋ LRU ๊ตฌ์กฐ๊ฐ SSM๊ณผ ์ ์ฌํ ์ฑ๋ฅ์ ๋ฐํํ ์ ์๋ ์ด์ ์ค ํ๋์ ๋๋ค. ๋ํ, ์ ๊ทํ๋ S4 ๋ชจ๋ธ์์ ์ฌ์ฉ๋๋ ๊ตฌ์กฐ์๋ ์ฐ๊ฒฐ๋๋ฉฐ, ์ด ๊ณผ์ ์ด SSM์์ ์ฑ๋ฅ์ ๊ทน๋ํํ๋ ๋ฐ ๊ธฐ์ฌํฉ๋๋ค.
๋ ผ๋ฌธ์์ ์ธ๊ธ๋๋ S4 ๋ชจ๋ธ๊ณผ ๊ทธ ๋ณํ๋ค(S4 and Variants)์ ๋ํ ์ธ์ฌ์ดํธ๋ ์ฃผ๋ก S4 ๋ชจ๋ธ์ด ๋ฐ์ด๋ ์ฑ๋ฅ์ ๋ฐํํ๋ ์์ธ๊ณผ ์ด ๋ชจ๋ธ์ ํน์ง์ ๋ํ ์ดํด๋ฅผ ๋ฐํ์ผ๋ก ํ๊ณ ์์ต๋๋ค.
1. S4์ ๊ตฌ์กฐ์ ํจ์จ์ฑ
์ธ์ฌ์ดํธ
: S4์ ๊ทธ ๋ณํ ๋ชจ๋ธ๋ค(DSS, S4D, S5 ๋ฑ)์ ์ํ์ค ๊ธธ์ด์ ๋ฐ๋ผ ์ฐ์ฐ ๋ณต์ก๋๋ฅผ ๋ฎ์ถ๋ ํจ์จ์ ์ธ ๊ตฌ์กฐ๋ฅผ ๊ฐ์ง๊ณ ์์ต๋๋ค.์ด์
: S4๋ Transformer์ attention ๋ ์ด์ด๊ฐ ๊ฐ์ง๋ O(L2)O(L^2)O(L2) ๋ฉ๋ชจ๋ฆฌ์ ๊ณ์ฐ ๋ณต์ก๋๋ฅผ ํผํ ์ ์์ต๋๋ค. S4๋ ์์ฐจ์ ์ผ๋ก ํ ํฐ ๊ฐ์ ์ํธ์์ฉ์ ๋ชจ๋ธ๋งํ๋ฉฐ, ์ด๋ฅผ ํตํด ๊ธด ์ํ์ค๋ฅผ ํจ์จ์ ์ผ๋ก ์ฒ๋ฆฌํ ์ ์์ต๋๋ค. ๋ํ, ์์ฐจ์ ๋ชจ๋ธ์์๋ ๋ถ๊ตฌํ๊ณ ํ๋ จ ์ ๋ณ๋ ฌ ์ฒ๋ฆฌ๊ฐ ๊ฐ๋ฅํ์ฌ ํ์ต ์๋๊ฐ ๋น ๋ฆ ๋๋ค.
2. ์ ํ ์ฌ๊ท ๊ตฌ์กฐ์ ์ฅ์
์ธ์ฌ์ดํธ
: S4๋ ์ฌ๊ท์ ์ฐ์ฐ์ ์ ํ์ ์ผ๋ก ์ฒ๋ฆฌํ๋ฏ๋ก RNN์ฒ๋ผ ๊ธด ์ํ์ค์ ๋ํ ์ ๋ณด ์ ํ๋ฅผ ๋น ๋ฅด๊ฒ ์ํํ ์ ์์ต๋๋ค.์ด์
: S4์ ์ฌ๊ท์ ๋ ์ด์ด๋ ๋น์ ํ์ฑ์ด ์๋ ์ ํ ์์คํ ์ผ๋ก ๊ตฌ์ฑ๋๋ฉฐ, ์ด๋ ํ์ต์ด ๋ณด๋ค ์์ ์ ์ด๊ณ ๋น ๋ฅด๊ฒ ์ด๋ฃจ์ด์ง๊ฒ ํฉ๋๋ค. ์ด ์ ํ์ฑ์ RNN๊ณผ ์ ์ฌํ ์ฑ์ง์ ๊ฐ์ง์ง๋ง, ๋ณ๋ ฌ ํ๋ จ์ด ๊ฐ๋ฅํ๋ค๋ ์ ์์ ํฐ ์ฅ์ ์ ๊ฐ์ง๋๋ค.
3. ๋ณต์ ๋๊ฐ ํ๋ ฌ ์ฌ์ฉ์ ์ด์
์ธ์ฌ์ดํธ
: S4๋ ๋ณต์์ ๋๊ฐ ํ๋ ฌ์ ์ฌ์ฉํจ์ผ๋ก์จ ๋ชจ๋ธ์ ์์ ์ฑ๊ณผ ์ฑ๋ฅ์ ๋์ผ ์ ์์ต๋๋ค.์ด์
: ๋ณต์์ ๋๊ฐ ํ๋ ฌ์ ์ฌ์ฉํ๋ฉด ๋ชจ๋ธ์ด ๋ ๋ง์ ์ ๋ณด๋ฅผ ํ์ตํ ์ ์์ผ๋ฉฐ, ์ฐ์ฐ ํจ์จ๋ ํฌ๊ฒ ํฅ์๋ฉ๋๋ค. ํนํ ๋ณต์์ ํ๋ผ๋ฏธํฐ๋ฅผ ์ด์ฉํ์ฌ ๋ชจ๋ธ ์ด๊ธฐํ๋ฅผ ์์ ์ ์ผ๋ก ์ค์ ํ๊ณ , ํ๋ จ ์ ์ฅ๊ธฐ ์์กด์ฑ ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ ์ ์์ต๋๋ค.
4. ํน์ ์ด๊ธฐํ์ ์ค์์ฑ
์ธ์ฌ์ดํธ
: S4 ๋ชจ๋ธ์ ์ฑ๋ฅ ํฅ์์ ํน์ ํ ์ด๊ธฐํ ๋ฐฉ๋ฒ์์ ๊ธฐ์ธํ์ง๋ง, ์ด๋ฌํ ์ด๊ธฐํ๊ฐ ํญ์ ๊ฒฐ์ ์ ์ธ ๊ฒ์ ์๋๋๋ค.์ด์
: ์ด๊ธฐํ๋ ๋ชจ๋ธ์ด ์ฅ๊ธฐ์ ์ธ ์์กด์ฑ์ ํ์ตํ๋ ๋ฐ ์ค์ํ์ง๋ง, ๋ ผ๋ฌธ์์๋ ์ด ์ด๊ธฐํ ๊ท์น์ด ํญ์ ์ด๋ก ์ ์ผ๋ก ์ต์ ์ ์๋ ์ ์์ผ๋ฉฐ, ๋ค๋ฅธ ๋ฐฉ๋ฒ์ผ๋ก๋ ์ ์ฌํ ์ฑ๋ฅ์ ์ป์ ์ ์๋ค๊ณ ์ฃผ์ฅํฉ๋๋ค.
5. ํ๋ จ ์๋ ๋ฐ ์ฑ๋ฅ์ ๊ท ํ
์ธ์ฌ์ดํธ
: S4 ๋ชจ๋ธ์ ๊ธด ์ํ์ค๋ฅผ ๋ค๋ฃฐ ๋๋ ๋น ๋ฅธ ํ์ต ์๋๋ฅผ ์ ์งํ๋ฉฐ, ์ฑ๋ฅ์ ์์ง ์๋ ๊ท ํ์ ์ ์ฐพ์์ต๋๋ค.์ด์
: S4์ ์ค๊ณ๋ RNN๊ณผ ์ ์ฌํ์ง๋ง ํ๋ จ ์ ๋ณ๋ ฌ ์ฒ๋ฆฌ๋ฅผ ํตํด ํ์ต ์๋๋ฅผ ๋์ผ ์ ์์ผ๋ฉฐ, ๋ณต์กํ ๋ชจ๋ธ์ ์ฑ๋ฅ์ ์ ์งํ๋ฉด์๋ ๋น ๋ฅธ ํ๋ จ์ ๊ฐ๋ฅํ๊ฒ ํฉ๋๋ค.
๊ฒฐ๋ก ์ ์ผ๋ก, ์ด ๋ ผ๋ฌธ์ S4 ๋ชจ๋ธ๊ณผ ๊ทธ ๋ณํ๋ค์ด ๋์ ์ฑ๋ฅ์ ๋ฌ์ฑํ๋ ๋ฐ ํ์ํ ๋ค์ํ ๊ธฐํํ์ ๋ฐ ๊ณ์ฐ์ ์์๋ค์ ๋ถ์ํ๊ณ , ๊ธฐ์กด์ ๊ฐ์ ์ด ๋ ๊ฐ๋จํ ๋ฐฉ๋ฒ์ผ๋ก๋ ๊ฒ์ฆ๋ ์ ์์์ ์ ์ํ๊ณ ์์ต๋๋ค.
์ด๋ฌํ ์ธ์ฌ์ดํธ๋ ํฅํ RNN ๋ฐ SSM ๋ชจ๋ธ ๊ฐ๋ฐ์ ์ค์ํ ๋ฐฉํฅ์ฑ์ ์ ๊ณตํ ๊ฒ์ผ๋ก ๊ธฐ๋๋๋ค๊ณ ํ๋ฉฐ ๊ธ์ ๋ง๋ฌด๋ฆฌํฉ๋๋ค.