[Paper Review] LLaVA-PruMerge: Adaptive Token Reduction for Efficient Large Multimodal Models

๋ ผ๋ฌธ ์ ๋ณด
- ์ ๋ชฉ: LLaVA-PruMerge: Adaptive Token Reduction for Efficient Large Multimodal Models
- ์ ์: Yuzhang Shang, Mu Cai, Bingxin Xu, Yong Jae Lee, Yan Yan
- ์์: Illinois Institute of Technology, University of WisconsinโMadison
- ๋ฐํ: ICCV 2025 (arXiv: 2024๋ 3์)
- ๋ ผ๋ฌธ ๋งํฌ: https://arxiv.org/abs/2403.15388
- GitHub: https://github.com/42Shawn/LLaVA-PruMerge
- Project Page: https://llava-prumerge.github.io
-
Introduction: LMM ํจ์จ์ฑ์ ์๋ก์ด ํจ๋ฌ๋ค์
1.1 Large Multimodal Models (LMMs)์ ๋ฑ์ฅ
- Large Language Models (LLMs)์ GPT-4, LLaMA, Mistral ๋ฑ์์ ๋ณด๋ฏ ๊ฐ๋ ฅํ ์ถ๋ก ๋ฅ๋ ฅ์ ๋ณด์ฌ์ฃผ๊ณ ์์ต๋๋ค. ์ด๋ฌํ LLM์ ๋๊ท๋ชจ ํ ์คํธ ์ฝํผ์ค๋ก ์ฌ์ ํ์ต๋ ๊ณ ์ฉ๋ Transformer ์ํคํ ์ฒ์ ๋๋ค.
- Large Multimodal Models (LMMs)์ LLM์ ํ ์คํธ ์์ฑ ๋ฅ๋ ฅ์ ๊ณ์นํ๋ฉด์, CLIP-ViT ๊ฐ์ Vision Encoder๋ฅผ ์ถ๊ฐํ์ฌ ์ด๋ฏธ์ง ํจ์น๋ฅผ visual tokens์ผ๋ก ๋ณํํฉ๋๋ค. ์ด visual tokens์ LLM์ prefix context๋ก ์ ๋ ฅ๋์ด ์๊ฐ์ ์ถ๋ก ์ ๊ฐ๋ฅํ๊ฒ ํฉ๋๋ค.
1
2
[Vision Encoder] โ Visual Tokens (prefix) โ [LLM] โ ํ
์คํธ ์๋ต
(CLIP-ViT) (576๊ฐ) (Vicuna/LLaMA)
1.2 LMM์ ๊ณ์ฐ ๋น์ฉ ๋ฌธ์
LMM์ ์ถ๋ก (inference)์ ์๋นํ ๊ณ์ฐ ๋น์ฉ์ด ํ์ํฉ๋๋ค.
์ด ๋น์ฉ์ ๊ตฌ์กฐ๋ฅผ ๋ถ์ํ๋ฉด:
| ๊ตฌ์ฑ์์ | ํ๋ผ๋ฏธํฐ ์ | ๋น๊ณ |
|---|---|---|
| Vision Encoder (ViT-L) | ~0.3B | ์๋์ ์ผ๋ก ์์ |
| LLM (LLaMA/Vicuna) | 7B~13B | ์ฃผ์ ๋น์ฉ ์์ธ |
๐ ํต์ฌ ํต์ฐฐ: Vision Encoder๋ LLM์ ๋นํด ๋งค์ฐ ์์ผ๋ฏ๋ก, LLM์ ์ถ๋ก ๋น์ฉ์ ์ค์ด๋ ๊ฒ์ด ์ ์ฒด LMM ํจ์จํ์ ํต์ฌ์ ๋๋ค.
1.3 ๊ธฐ์กด ์ ๊ทผ๋ฒ๊ณผ ํ๊ณ
์ด์ ์ฐ๊ตฌ๋ค์ ์ด๋ฌํ LLM ๋น์ฉ์ ์ค์ด๊ธฐ ์ํด ์๋์ ๊ฐ์ ์๋๋ค์ ์ํํ์์ต๋๋ค.
| ์ ๊ทผ๋ฒ | ๋ฐฉ๋ฒ | ํ๊ณ |
|---|---|---|
| Small LLM ์ฌ์ฉ | Phi-2 ๊ธฐ๋ฐ MobileVLM, TinyGPT-V | LLM ์ถ๋ก ๋ฅ๋ ฅ ํฌ์, VQAv2/MMBench์์ ํฐ ์ฑ๋ฅ ๊ฒฉ์ฐจ |
| Quantization | 4-bit, 8-bit ์์ถ | ํ๋ผ๋ฏธํฐ ์๋ ์ค์ง๋ง ๋ค๋ฅธ ๋ฌธ์ ๋ฏธํด๊ฒฐ |
1.4 ๊ฐ๊ณผ๋ ๋น์ฉ ์์ฒ: Input Context Length
ํ์ง๋ง, ์ ์ฐ๊ตฌ๋ค์์ ๊ฐ๊ณผํ ๋ด์ฉ์ผ๋ก โLLM์ ๋น์ฉ์ ํ๋ผ๋ฏธํฐ ์๋ฟ๋ง ์๋๋ผ ์ ๋ ฅ ์ปจํ ์คํธ ๊ธธ์ด์์๋ ๋ฐ์ํ๋ค.โ๋ผ๋ ์ฌ์ค์ ์ง์ ํฉ๋๋ค.
LLM = Transformer ์ํคํ
์ฒ:
LLM์ Transformer ๊ธฐ๋ฐ์ด๋ฉฐ, ํต์ฌ ์ฐ์ฐ์ธ Self-Attention์ ์
๋ ฅ๋ ๋ชจ๋ ํ ํฐ ์ ๊ฐ์ ๊ด๊ณ๋ฅผ ๊ณ์ฐํฉ๋๋ค.
Attention(Q,K,V)=softmax(QKTdk)โ V\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) \cdot VAttention(Q,K,V)=softmax(dkโโQKTโ)โ V
์ฌ๊ธฐ์ QKTQK^TQKT ์ฐ์ฐ์ N ร N ์ดํ ์ ๋งคํธ๋ฆญ์ค๋ฅผ ์์ฑํฉ๋๋ค (N = ์ ๋ ฅ ํ ํฐ ์). ๋ฐ๋ผ์ Self-Attention์ ๊ณ์ฐ ๋ณต์ก๋๋ ์ ๋ ฅ ๊ธธ์ด์ ๋ํด O(Nยฒ)์ ๋๋ค.
LMM์์์ ๋ฌธ์ :
- LMM์ ๊ณ ์ ๋ ๋๋์ visual tokens์ prefix๋ก ์ฌ์ฉ
- LLaVA-1.5: 576 visual tokens
- Video-LLaVA: 2,048+ tokens (๊ณ ํด์๋/๋น๋์ค ์ฒ๋ฆฌ ์)
- ์์ ๊ฐ์ ๊ตฌ์กฐ๋ก ์ธํ์ฌ Visual tokens ์๊ฐ ๋์ด๋ ์๋ก LLM์ ์ดํ ์ ์ฐ์ฐ๋์ด ์ ๊ณฑ์ผ๋ก ์ฆ๊ฐ
๐ ํต์ฌ ์ง๋ฌธ: Prefix visual tokens์ ์๋ฅผ ์ค์ด๋ฉด์๋ ์ฑ๋ฅ์ ์ ์งํ ์ ์๋๊ฐ?
1.5 ํต์ฌ ๊ด์ฐฐ: Visual Tokens์ Redundancy
๋ณธ ์ฐ๊ตฌ์์ ๋ฐ๊ฒฌํ ์ค์ํ ํ์:
๊ด์ฐฐ 1: Sparse Attention Distribution
Vision Encoder์ self-attention์์ [CLS] ํ ํฐ๊ณผ spatial patches ๊ฐ์ ์ดํ ์ ์ด sparseํฉ๋๋ค. ์ด๋ ์์์ visual tokens๋ง์ด ํต์ฌ ์๊ฐ ์ ๋ณด์ ์ฐ๊ด๋จ์ ์๋ฏธํฉ๋๋ค.
๊ด์ฐฐ 2: ๋๋ถ๋ถ์ Visual Tokens์ Redundant
๊ธฐ์กด ์ฐ๊ตฌ(Bolya et al., 2023; Liu et al., 2022)์ ์ผ๊ด๋๊ฒ, ๋๋ถ๋ถ์ visual tokens์ ์ฑ๋ฅ ์ ํ ์์ด ์ ๊ฑฐ(prune)๋ ์ ์์ต๋๋ค.
1.6 ์ ์ ๋ฐฉ๋ฒ: PruMerge ๊ฐ์
์ด๋ฌํ sparse similarity๋ฅผ ํ์ฉํ์ฌ ์ค์ํ visual tokens์ ์ ์์ ์ผ๋ก ์ ํํ๋ ๋ฐฉ๋ฒ์ ์ ์ํฉ๋๋ค.

PruMerge์ ํต์ฌ ์์ด๋์ด:
-
Adaptive Token Selection (Prune)
- Interquartile Range (IQR) ๊ธฐ๋ฐ outlier detection
- [CLS] ์ดํ ์ ๊ฐ์ด ๋์ ํ ํฐ์ ์ค์ ํ ํฐ์ผ๋ก ์ ๋ณ
-
Token Merging (Merge)
- IQR๋ก 32๊ฐ ํ ํฐ๋ง ์ ํํ๋ฉด, ๋๋จธ์ง 544๊ฐ ํ ํฐ์ ์ ๋ณด๊ฐ ์์ ํ ์์ค๋ ์ ์์
- ์ด๋ฅผ ๋ฐฉ์งํ๊ธฐ ์ํด k-nearest neighbor ๊ธฐ๋ฐ ํด๋ฌ์คํฐ๋ง
- ์ ํ๋ ํ ํฐ์ weighted averaging์ผ๋ก ์ ๋ฐ์ดํธ
- ์ ๊ฑฐ๋ ํ ํฐ์ ์ ๋ณด๋ฅผ ๋ณด์กด์ ๋ชฉ์ ์ผ๋ก ํจ
-
PruMerge+ (ํ์ฅ)
- ๋๋ฌด ๊ณต๊ฒฉ์ ์ธ ์์ถ์ผ๋ก ์ธํ ์ฑ๋ฅ ์ ํ ๊ฐ๋ฅ์ฑ ์กด์ฌ
- Spatial-uniform sampling ์ถ๊ฐ
- ๋ ํฌ๊ด์ ์ด๊ณ ๋ํ์ฑ ์๋ ํ ํฐ ์ ํ ๋ณด์ฅ
๐ท ๋ณธ ๋ ผ๋ฌธ์ ๊ธฐ์ฌ์
- Visual token redundancy ๋ถ์: [CLS]-spatial attention์ sparsity ๊ด์ฐฐ
- PruMerge ์ ์: ์ ์์ ํ ํฐ ์ ํ ๋ฐ ๋ณํฉ ์ ๋ต
- Plug-and-play ์ ์ฉ: ๊ธฐ์กด LMM์ ์ถ๊ฐ ํ์ต ์์ด ์ ์ฉ ๊ฐ๋ฅ
- ๋ค์ํ ๋ชจ๋ฌ๋ฆฌํฐ ํ์ฅ: ์ด๋ฏธ์ง๋ฟ ์๋๋ผ ๋น๋์ค(Video-LLaVA)์๋ ์ ์ฉ ๊ฐ๋ฅ
2.1 Efficient Large Multimodal Models
1. Compact Architecture (์์ ๋ชจ๋ธ ์ฌ์ฉ)
MobileVLM / MobileVLM-v2
๋ชฉํ: ๋ชจ๋ฐ์ผ ๋๋ฐ์ด์ค์์ ์คํ ๊ฐ๋ฅํ LMM
1
2
3
4
์ผ๋ฐ LMM: Vision Encoder โ [Vicuna-7B] โ ์๋ต
MobileVLM: Vision Encoder โ [MobileLLaMA-1.4B] โ ์๋ต
โ
5๋ฐฐ ์์ LLM
ํน์ง:
- ๋ชจ๋ฐ์ผ ์ต์ ํ LLM backbone ์ฌ์ฉ
- ๊ฒฝ๋ํ๋ projector ์ค๊ณ
TinyGPT-V
๋ชฉํ: ์์ LLM์ผ๋ก๋ ์ข์ ์ฑ๋ฅ ๋ฌ์ฑ
1
2
3
4
๊ธฐ์กด LLaVA: Vision Encoder โ [Vicuna-7B] โ ์๋ต
TinyGPT-V: Vision Encoder โ [Phi-2 (2.7B)] โ ์๋ต
โ
Microsoft์ ์ํ LLM
ํน์ง:
- Phi-2์ ๊ฐ๋ ฅํ reasoning ๋ฅ๋ ฅ ํ์ฉ
- 7B ๋๋น ์ฝ 3๋ฐฐ ์์ ๋ชจ๋ธ
LLaVA-Phi
๋ชฉํ: Phi ๊ธฐ๋ฐ ํจ์จ์ LMM
1
Vision Encoder โ [Phi-2] โ ์๋ต
ํน์ง:
- ์์ backbone + ํฅ์๋ vocabulary
- ๋ ๋์ ์ผ๋ฐํ ์ฑ๋ฅ ์ถ๊ตฌ
TinyLLaVA
๋ชฉํ: ์ํคํ ์ฒ ์ ํ๊ณผ ํ์ต ์ต์ ํ ์ฐ๊ตฌ
ํ๊ตฌ ๋ด์ฉ:
- ์ด๋ค Vision Encoder๊ฐ ์ต์ ์ธ๊ฐ?
- ์ด๋ค Projector๊ฐ ์ต์ ์ธ๊ฐ?
- ์ด๋ค ํ์ต ์ ๋ต์ด ์ต์ ์ธ๊ฐ?
๊ฒฐ๋ก : ์์ ๋ชจ๋ธ๋ ์ต์ ํํ๋ฉด ํฐ ๋ชจ๋ธ๊ณผ ์ ์ฌํ ์ฑ๋ฅ ๊ฐ๋ฅ
MoE-LLaVA
๋ชฉํ: Mixture of Experts๋ก ํจ์จ์ฑ ํฅ์
1
2
3
4
5
์ผ๋ฐ LLM: ๋ชจ๋ ํ๋ผ๋ฏธํฐ๊ฐ ํญ์ ํ์ฑํ
MoE-LLM: Expert 1 Expert 2 Expert 3 Expert 4
โ โ
Router๊ฐ ์ ํํ Expert๋ง ํ์ฑํ (sparse)
ํน์ง:
- ์ ์ฒด ํ๋ผ๋ฏธํฐ๋ ๋ง์ง๋ง, ์ถ๋ก ์ ์ผ๋ถ๋ง ์ฌ์ฉ
- ๊ณ์ฐ๋ ๊ฐ์ + ์ฑ๋ฅ ์ ์ง
2. Quantization & Compression
4/8-bit Quantization
๋ชฉํ: ํ๋ผ๋ฏธํฐ ์ ๋ฐ๋๋ฅผ ๋ฎ์ถฐ ๋ฉ๋ชจ๋ฆฌ/์ฐ์ฐ ์ ์ฝ
1
2
3
4
5
6
7
8
๊ธฐ์กด (FP16):
W = [0.1234, -0.5678, 0.9012, ...] โ ๊ฐ ์ซ์๊ฐ 16 bits
INT8 Quantization:
W = [0.12, -0.57, 0.90, ...] โ ๊ฐ ์ซ์๊ฐ 8 bits๋ก ๊ทผ์ฌ
INT4 Quantization:
W = [0.1, -0.6, 0.9, ...] โ ๊ฐ ์ซ์๊ฐ 4 bits๋ก ๊ทผ์ฌ
ํ๊ณ:
- ์์ถ ๋์: ๋ชจ๋ธ์ ํ๋ผ๋ฏธํฐ (weights)
- ํ ํฐ ์๋ ๊ทธ๋๋ก โ attention ์ฐ์ฐ๋ ๋์ผ
3. Vision-Language Connectors
Vision Encoder ์ถ๋ ฅ์ LLM ์ ๋ ฅ์ผ๋ก ๋ณํํ๋ ๋ชจ๋๋ค์ ๋๋ค.
MLP Projector (LLaVA)
1
Visual Token (1024-dim) โ [Linear โ GELU โ Linear] โ LLM Token (4096-dim)
ํน์ง:
- ๊ฐ์ฅ ๋จ์ํ ๊ตฌ์กฐ
- ํ ํฐ ์ ๋ณํ ์์ (576 โ 576)
Q-Former (BLIP-2)
1
2
3
4
5
Visual Tokens (576๊ฐ)
โ
[Q-Former] โ Learnable Query Tokens (32๊ฐ)์ cross-attention
โ
Query Outputs (32๊ฐ)
ํน์ง:
- Learnable query๊ฐ visual ์ ๋ณด๋ฅผ โ์ง์โ
- ํ ํฐ ์ ๊ฐ์ (576 โ 32)
- ํ์ง๋ง ๊ณ ์ ๋ ์์ query ์ฌ์ฉ (adaptive ์๋)
Resampler (Flamingo)
1
2
3
4
5
Visual Tokens (๊ฐ๋ณ)
โ
[Perceiver Resampler] โ Latent Tokens์ cross-attention
โ
Fixed-size Output (64๊ฐ)
ํน์ง:
- ๋ค์ํ ํด์๋ ์ ๋ ฅ ์ฒ๋ฆฌ ๊ฐ๋ฅ
- ๊ณ ์ ๋ ์์ ์ถ๋ ฅ ํ ํฐ
Connector ๋น๊ต
| Connector | ์ ๋ ฅ ํ ํฐ | ์ถ๋ ฅ ํ ํฐ | Adaptive? |
|---|---|---|---|
| MLP (LLaVA) | 576 | 576 | โ |
| Q-Former (BLIP-2) | 576 | 32 | โ (๊ณ ์ ) |
| Resampler (Flamingo) | ๊ฐ๋ณ | 64 | โ (๊ณ ์ ) |
| PruMerge | 576 | ์ฝ 32(์ ๋์ ) | โ (์ ์์ ) |
2.2 Token Reduction Methods
Sparse Attention
Linformer
- ๋ฌธ์ : Self-attention์ O(Nยฒ) ๋ณต์ก๋
- ํด๊ฒฐ: Key, Value๋ฅผ ์ ์ฐจ์์ผ๋ก projection
1
2
3
4
5
6
7
๊ธฐ์กด Attention:
Q (Nรd) @ K^T (dรN) = NรN ํ๋ ฌ โ O(Nยฒ)
Linformer:
K' = E @ K (kรd, where k << N)
V' = F @ V (kรd)
Q @ K'^T = Nรk ํ๋ ฌ โ O(Nรk) โ O(N)

ํ๊ณ: LMM์ ์ง์ ์ ์ฉ ์ด๋ ค์ (prefix ๊ตฌ์กฐ)
ReFormer (Reformer)
- ๋ฌธ์ : ๊ธด ์ํ์ค์ attention ๋น์ฉ
- ํด๊ฒฐ: Locality-Sensitive Hashing (LSH) ์ฌ์ฉ
1
2
3
4
5
๊ธฐ์กด: ๋ชจ๋ ํ ํฐ ์์ attention ๊ณ์ฐ
ReFormer:
1. LSH๋ก ์ ์ฌํ ํ ํฐ๋ผ๋ฆฌ bucket ๋ถ๋ฅ
2. ๊ฐ์ bucket ๋ด์์๋ง attention ๊ณ์ฐ

1
2
3
4
[ํ ํฐ๋ค] โ [LSH Hashing] โ [Bucket 1] [Bucket 2] [Bucket 3]
โ โ โ
Attention Attention Attention
(๋ด๋ถ๋ง) (๋ด๋ถ๋ง) (๋ด๋ถ๋ง)
ํ๊ณ: ์ฌ์ ํ ๋ชจ๋ ํ ํฐ ์ ์ง, ์ฐ์ฐ ๋ฐฉ์๋ง ๋ณ๊ฒฝ
Token Merging
ToMe (Bolya et al., 2023)
- ๋ชฉํ: ViT ๋ด๋ถ์์ ์ ์ง์ ์ผ๋ก ํ ํฐ ์ ๊ฐ์
- ๋ฐฉ๋ฒ: Bipartite Matching์ผ๋ก ์ ์ฌํ ํ ํฐ ๋ณํฉ
1
2
3
4
5
6
7
8
ViT Block 1: 576 tokens
โ (merge)
ViT Block 2: 500 tokens
โ (merge)
ViT Block 3: 450 tokens
โ (merge)
...
์ต์ข
: 1 token (class token)
Bipartite Matching:
1
2
3
4
5
6
7
8
9
ํ ํฐ๋ค์ ๋ ๊ทธ๋ฃน์ผ๋ก ๋ถํ :
Group A: [T1, T3, T5, ...]
Group B: [T2, T4, T6, ...]
์ ์ฌํ ์ ๋งค์นญ ํ ๋ณํฉ:
T1 + T2 โ T'1
T3 + T4 โ T'2
...
๊ธฐ์กด Token Reduction vs PruMerge ๋น๊ต
| ํญ๋ชฉ | ToMe (๊ธฐ์กด) | PruMerge |
|---|---|---|
| ์ ์ฉ ์์น | ViT ๋ด๋ถ (layer-by-layer) | ViT ์ถ๋ ฅ ํ (ํ ๋ฒ์) |
| ๋ชฉํ | ViT ์ฐ์ฐ ๊ฐ์ | LLM ์ฐ์ฐ ๊ฐ์ |
| ์ถ๋ ฅ | Single [CLS] token | Multiple visual tokens |
| ๊ฐ์ ๋ฐฉ์ | ์ ์ง์ (576โ500โ450โโฆ) | ํ ๋ฒ์ (576โ32) |
| Adaptive | โ (๊ณ ์ ๋น์จ) | โ (์ด๋ฏธ์ง๋ณ ๋ค๋ฆ) |
-
Method: Token Pru-Merging
3.1 Preliminaries
Vision Transformers (ViTs)

๊ตฌ์กฐ:
1
2
3
4
5
6
7
8
9
10
11
12
Input Image
โ (Patch embedding)
Patch Tokens (576 tokens for 336ร336 image with 14ร14 patches)
+ Class Token ([CLS])
โ
Transformer Blocks (ร24 for ViT-L/14)
โโ Multi-head Self-Attention (MSA)
โโ Feed-Forward Network (FFN)
โโ Skip connections
โโ Layer Normalization
โ
Output Tokens
Self-Attention ๋ฉ์ปค๋์ฆ:
1
2
3
4
5
6
7
8
# Query, Key, Value ๊ณ์ฐ
Q = X ยท Wq
K = X ยท Wk
V = X ยท Wv
# Attention ๊ณ์ฐ
A = softmax(Q ยท K^T / โdk)
Y = A ยท V
Class Token Attention:
1
2
# [CLS] token๊ณผ visual tokens ๊ฐ์ attention
a_cls = softmax(q_cls ยท K^T / โdk)
ํต์ฌ ๊ด์ฐฐ:
a_cls์ ๋ถํฌ๊ฐ ๋งค์ฐ sparse- ์์์ visual tokens๋ง ๋์ attention ๊ฐ
- ๋๋ถ๋ถ์ tokens๋ near-zero attention
Large Multimodal Models (LMMs)

Pipeline:
1
2
3
4
5
Image X_v โ [Vision Encoder] โ Z_v โ [Projector W] โ H_v
โ
Text X_q โ [Tokenizer] โโโโโโโโโโโโโโโโโโโโโโโโโโโ H_q
โ
[LLM f_ฮธ] โ Response Y_a
Computational Cost:
- N tokens โ N ร N attention matrix
- Quadratic complexity: O(Nยฒ)
- Visual tokens๊ฐ ๋ง์์๋ก ๋น์ฉ ๊ธ์ฆ
๐ผ LLaVa PruMerge ๋ชฉํ: Visual tokens ์๋ฅผ ์ค์ฌ LLM์ computational cost ๊ฐ์

3.2 Adaptive Important Token Selection via Outlier Detection
ํต์ฌ ์ง๋ฌธ: โ๊ฐ visual token์ ์ค์๋๋ฅผ ์ด๋ป๊ฒ ํ๋จํ๋๊ฐ?โ
๋ ๊ฐ์ง ๊ทน๋จ์ ํจ๋ฌ๋ค์
| ํจ๋ฌ๋ค์ | ํ ํฐ ์ | ํน์ง |
|---|---|---|
| LMM | 576๊ฐ (์ ๋ถ ์ฌ์ฉ) | ์์ธํ ์๊ฐ ์ ๋ณด ํํ |
| CLIP | 1๊ฐ ([CLS]๋ง ์ฌ์ฉ) | ๊ฐ์ฅ ์์ถ๋ ์ ๋ณด ํํ |
๊ท ํ์ ํ์: [CLS]-Visual Attention ์กฐ์ฌ
์ด ๋ ๊ทน๋จ์ ๊ท ํ์ ์ ์ฐพ๊ธฐ ์ํด, [CLS] token๊ณผ visual tokens ๊ฐ์ attention์ ์กฐ์ฌํฉ๋๋ค.

๊ด์ฐฐ ๊ฒฐ๊ณผ (Figure 3a):

- Y์ถ: Log(Class Attention Value)
- X์ถ: Visual Token Index (0-575)
๋ถํฌ ํน์ฑ:
- ๋๋ถ๋ถ์ tokens: near-zero attention
- ์์์ tokens: ๋งค์ฐ ๋์ attention (outliers)
ํต์ฌ ๋ฐ๊ฒฌ: Attention ๋ถํฌ๊ฐ ๋งค์ฐ sparseํจ
โ ์์์ visual tokens๋ง ํต์ฌ ์๊ฐ ์ ๋ณด์ ์ฐ๊ด๋จ
๊ด์ฐฐ ๊ฒฐ๊ณผ (Figure 3b):

PruMerge: ์ ๋ณด๊ฐ ์ค์ํ ๊ณณ๋ง ์ ํ โ ํจ์จ์ ์ด์ง๋ง ์ผ๋ถ ์ ๋ณด ์์ค ๊ฐ๋ฅPruMerge+: ์ค์ํ ๊ณณ + ๊ท ๋ฑ ์ํ๋ง โ ์ฝ๊ฐ์ ํ ํฐ ์ฆ๊ฐ๋ก ์ปค๋ฒ๋ฆฌ์ง ๋ณด์ฅ
IQR (Interquartile Range) ๊ธฐ๋ฐ Outlier Detection
Sparseํ attention ๋ถํฌ์์ outlier = ์ค์ํ ํ ํฐ์ผ๋ก ํ๋จํฉ๋๋ค.
์๊ณ ๋ฆฌ์ฆ:
1
2
3
4
5
6
7
8
9
10
11
12
# 1. Attention ๊ฐ์ quartiles ๊ณ์ฐ
Q1 = percentile(a_cls, 25) # 1์ฌ๋ถ์์
Q3 = percentile(a_cls, 75) # 3์ฌ๋ถ์์
# 2. IQR ๊ณ์ฐ
IQR = Q3 - Q1
# 3. Upper fence (threshold) ๊ณ์ฐ
upper_fence = Q3 + 1.5 * IQR
# 4. Outliers = Important tokens
important_indices = where(a_cls > upper_fence)

์ IQR์ธ๊ฐ?
- Attention score๋ ์์์ด๋ฏ๋ก upper fence๋ง ์ฌ์ฉ
- ๊ฐ ์ด๋ฏธ์ง์ ๋ถํฌ์ ๋ฐ๋ผ threshold๊ฐ ์๋ ์กฐ์
- ํต๊ณ์ ์ผ๋ก ๊ฒ์ฆ๋ robustํ outlier detection
Adaptive Selection์ ํน์ฑ
์ด๋ฏธ์ง ๋ณต์ก๋์ ๋ฐ๋ฅธ ์๋ ์กฐ์ :
| ์ด๋ฏธ์ง ์ ํ | ํน์ฑ | ์ ํ ํ ํฐ ์ |
|---|---|---|
| ๋ณต์กํ ์ด๋ฏธ์ง (ํ ์คํธ ๅค) | Attention outlier ๅค | ๋ง์ (40-50๊ฐ) |
| ๋จ์ํ ์ด๋ฏธ์ง (ํ๋+๊ฐํ) | Attention outlier ๅฐ | ์ ์ (10-20๊ฐ) |

๋ฒค์น๋งํฌ๋ณ ํ๊ท ํ ํฐ ์ (Table 4):
๋น๊ต ๋์ (๋์ผํ ํ ํฐ ์์์):
| ๋ฐฉ๋ฒ | ์ค๋ช |
|---|---|
| LLaVA-PruMerge | IQR ๊ธฐ๋ฐ adaptive selection |
| Sequential | ์์์๋ถํฐ ์์ฐจ์ ์ผ๋ก N๊ฐ ์ ํ |
| Spatial | ๊ณต๊ฐ์ ์ผ๋ก ๊ท ๋ฑํ๊ฒ N๊ฐ ์ ํ (์: 5ร8, 8ร5) |
1
2
3
4
5
6
7
8
9
10
Sequential ์ ํ:
[T1, T2, T3, ..., T40] โ ์์ชฝ 40๊ฐ๋ง ์ ํ
์ด๋ฏธ์ง ํจ์น ์์:
โโโโโโโโโโโโโโโโโโโโโโโโโโโโ
โ T1 T2 T3 ... T24 โ โ ์๋จ๋ง ์ ํ๋จ
โ T25 T26 T27 ... T48 โ
โ ... โ โ ํ๋จ์ ์์ ๋ฌด์
โ T553 ... T576 โ
โโโโโโโโโโโโโโโโโโโโโโโโโโโโ
1
2
3
4
5
6
7
8
Spatial ์ ํ (5ร8 = 40):
โโโโโโโโโโโโโโโโโโโโโโโโโโโโ
โ โ ยท ยท ยท โ ยท ยท ยท โ
โ ยท ยท ยท ยท ยท ยท ยท ยท โ
โ โ ยท ยท ยท โ ยท ยท ยท โ
โ ยท ยท ยท ยท ยท ยท ยท ยท โ
โ โ ยท ยท ยท โ ยท ยท ยท โ
โโโโโโโโโโโโโโโโโโโโโโโโโโโโ
1
2
3
4
5
6
7
8
PruMerge ์ ํ:
โโโโโโโโโโโโโโโโโโโโโโโโโโโโ
โ ยท ยท โ โ โ ยท ยท ยท โ
โ ยท ยท โ โ โ ยท ยท ยท โ โ ํ
์คํธ/๊ฐ์ฒด ์์ญ์ ์ง์ค
โ ยท ยท โ โ โ ยท ยท ยท โ
โ ยท ยท ยท ยท ยท ยท ยท ยท โ
โ ยท ยท ยท ยท ยท ยท ยท ยท โ
โโโโโโโโโโโโโโโโโโโโโโโโโโโโ
Task: TextVQA (40 tokens)
| Approach | Performance |
|---|---|
| LLaVA-PruMerge | 54.00 |
| Sequential | 42.72 |
| Spatial (5ร8) | 46.85 |
| Spatial (8ร5) | 47.42 |
Task: MME (40 tokens)
| Approach | Performance |
|---|---|
| LLaVA-PruMerge | 1250.07 |
| Sequential | 703.60 |
| Spatial (5ร8) | 1180.23 |
| Spatial (8ร5) | 1142.32 |
Task: POPE (35 tokens)
| Approach | Performance |
|---|---|
| LLaVA-PruMerge | 76.2 |
| Sequential | 11.7 |
| Spatial (5ร7) | 69.8 |
| Spatial (7ร5) | 71.1 |
Task: ScienceQA (16 tokens)
| Approach | Performance |
|---|---|
| LLaVA-PruMerge | 68.07 |
| Sequential | 64.20 |
| Spatial (4ร4) | 66.29 |
Penultimate Layer ์ฌ์ฉ
์ ๋ง์ง๋ง layer๊ฐ ์๋ penultimate (๋์์ ๋ ๋ฒ์งธ) layer?
- ๋ง์ง๋ง layer: Classification์ ํนํ
- Penultimate layer: ๋ richํ feature representation ๋ณด์
3.3 Token Supplement via Similar Key Clustering
โWhile pruned tokens may initially seem extraneous, they hold potential value for the perception capabilities of the LLM backbone.โ
๋ฌธ์ : Pruned Tokens์ ์ ๋ณด ์์ค
Pruned tokens๋ฅผ ์์ ํ ๋ฒ๋ฆฌ๋ฉด:
- ํฐ ๊ฐ์ฒด๊ฐ scene์ ์ง๋ฐฐํ๋ ๊ฒฝ์ฐ ์ ๋ณด ์์ค
- ๋ชจ๋ธ์ representation ๋ฅ๋ ฅ ์ ํ ๊ฐ๋ฅ
์์: ํฐ ๊ฐ์ฒด๊ฐ ํ๋ฉด์ ์ง๋ฐฐํ๋ ๊ฒฝ์ฐ
1
2
3
4
5
6
7
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
โ ๐ ๐ ๐ ๐ ๐ ๐ ๐ ๐ ๐ ๐ โ
โ ๐ ๐ ๐ ๐ ๐ ๐ ๐ ๐ ๐ ๐ โ
โ ๐ ๐ ๐ ๐ ๐ ๐ ๐ ๐ ๐ ๐ โ โ ์ฝ๋ผ๋ฆฌ๊ฐ ์ด๋ฏธ์ง ๋๋ถ๋ถ ์ฐจ์ง
โ ๐ ๐ ๐ ๐ ๐ ๐ ๐ ๐ ๐ ๐ โ
โ ๐ฟ ๐ฟ ๐ฟ ๐ฟ ๐ฟ ๐ฟ ๐ฟ ๐ฟ ๐ฟ ๐ฟ โ
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
IQR Outlier Selection ๊ฒฐ๊ณผ:
- ์ฝ๋ผ๋ฆฌ์ ๋, ๊ท ๋ฑ ํน์ง์ ์ธ ๋ถ๋ถ๋ง ์ ํ (5-10๊ฐ)
- ๋๋จธ์ง ์ฝ๋ผ๋ฆฌ ๋ชธํต ๋ถ๋ถ์ pruned๋จ
๋ฌธ์ : ์ฝ๋ผ๋ฆฌ ์ ์ฒด๋ฅผ ํํํ๊ธฐ์ ์ ๋ณด ๋ถ์กฑ
ํด๊ฒฐ์ฑ : Pruned tokens๋ฅผ ๋ฒ๋ฆฌ์ง ์๊ณ ์ ํ๋ ํ ํฐ์ ๋ณํฉ(merge)ํด์ฃผ๋ฉด, ๊ทธ ํน์ง์ ์ด๋ ค์ค ์ ์์ง ์์๊น?
Token Similarity ์ธก์ : Key Vector ํ์ฉ
โSince the key vector of each patch token already contains information summarized in the self-attention module, the final layerโs key vector serves as the representation.โ
์ Key vector์ธ๊ฐ?
- Self-attention์์ key vector๋ ์ด๋ฏธ ํด๋น ํ ํฐ์ ์ ๋ณด๋ฅผ ์์ฝ
- ๋ณ๋ ๊ณ์ฐ ์์ด ์ฌ์ฌ์ฉ ๊ฐ๋ฅ
Similarity ๊ณ์ฐ:
Sim(yi,yj)=kiโ kjT\text{Sim}(y_i, y_j) = \mathbf{k}_i \cdot \mathbf{k}_j^TSim(yiโ,yjโ)=kiโโ kjTโ
์ ์ฒด ํ ํฐ์ ๋ํด ๋ฒกํฐํ: KKT\mathbf{K}\mathbf{K}^TKKT
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
Similarity Matrix (576 ร 576):
# K: ๋ชจ๋ ํ ํฐ์ Key vectors [576, d_k]
# d_k: key dimension (์: 64)
T0 T1 T2 T3 ... T575
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
T0 โ 1.0 0.3 0.1 0.8 ... 0.2 โ
T1 โ 0.3 1.0 0.7 0.2 ... 0.4 โ
T2 โ 0.1 0.7 1.0 0.1 ... 0.3 โ
T3 โ 0.8 0.2 0.1 1.0 ... 0.5 โ
... โ ... ... ... ... ... ... โ
T575 โ 0.2 0.4 0.3 0.5 ... 1.0 โ
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
similarity_matrix[i][j] = token i์ token j์ ์ ์ฌ๋
K-Nearest Neighbor Clustering & Weighted Merge
๊ณผ์ :
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
def token_merge(K, Y, a_cls, unpruned_indices, k=32):
"""
K: Key vectors [576, d_k]
Y: Token features [576, d]
a_cls: Class attention [576]
unpruned_indices: IQR๋ก ์ ํ๋ ์ธ๋ฑ์ค [m]
k: neighbors ์
"""
# Step 1: ์ ์ฌ๋ ํ๋ ฌ ๊ณ์ฐ
similarity_matrix = K @ K.T # [576, 576]
# Step 2: ๊ฐ center์ ๋ํด merge
merged_tokens = []
for p in unpruned_indices:
# ์ ์ฌ๋ ๊ธฐ๋ฐ k-nearest neighbors
sims = similarity_matrix[p] # [576]
neighbor_idx = argsort(sims)[-k:] # top-k indices
# Class attention ๊ฐ์ค์น
weights = a_cls[neighbor_idx] # [k]
# Weighted sum
merged = (weights @ Y[neighbor_idx]) / weights.sum()
merged_tokens.append(merged)
return stack(merged_tokens) # [m, d]
ํต์ฌ:
- ์ ํ๋ ํ ํฐ = Cluster center
- Pruned tokens = ๊ฐ์ฅ ์ ์ฌํ center์ ๋ณํฉ
- Class attention์ ๊ฐ์ค์น๋ก ์ฌ์ฉ โ ์ค์ํ ์ ๋ณด ๋ ๋ง์ด ๋ฐ์
(๊ฐ์ธ์๊ฒฌ) ์ฝ๋ ๊ตฌํ์ฒด์์๋ k๋ฅผ 32๋ก ๊ณ ์ ํด๋๋๋ฐ, ์ด๋ฅผ dynamicํ๊ฒ ๋ฐ๊พธ๋๊ฒ ๋ ๋ง์ง ์์๊น? ๐ค (๋งํฌ)
1
2
3
4
5
6
7
# 1. Cosine Similarity ๊ณ์ฐ (KK^T ๋์ normalized dot product)
cos_sim_matrix = torch.bmm(key_others_norm, rest_Keys.transpose(1, 2))
## bmm : Batch ๋จ์๋ก ํ๋ ฌ ๊ณฑ์
์ ์ํํ๋ ํจ์
# 2. Top-k Nearest Neighbors ์ ํ โ ์ด๊ฒ KNN!
_, cluster_indices = torch.topk(cos_sim_matrix, k=int(32), dim=2, largest=True)
## topk: Tensor์์ ๊ฐ์ฅ ํฐ (๋๋ ์์) k๊ฐ์ ๊ฐ๊ณผ ์ธ๋ฑ์ค๋ฅผ ๋ฐํ
3.4 PruMerge+: Bridging the Efficiency-Performance Gap
๋ฌธ์ : PruMerge์ ์ฑ๋ฅ ๊ฒฉ์ฐจ
PruMerge๋ ~14๋ฐฐ ์์ถ (5.5% tokens)์ ๋ฌ์ฑํ์ง๋ง:
- ์๋ณธ LLaVA ๋๋น marginal performance drop ๋ฐ์
- ํน์ ์์ญ์ ํ ํฐ์ด ํธ์ค๋ ์ ์์
ํด๊ฒฐ์ฑ : Spatial Uniform Sampling ์ถ๊ฐ
PruMerge+ ์ ๋ต:
1
Final Tokens = Attention-based Outliers + Spatially-uniform Samples
์๊ณ ๋ฆฌ์ฆ:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# Step 1: IQR๋ก outlier ratio ๊ณ์ฐ
if if_adaptive:
reduction_ratio = outlier_dectection(cls_attn) # ์: 0.05 (5%)
# Step 2: Top-k๋ก outlier ์ ํ
_, idx = torch.topk(cls_attn, int(N * reduction_ratio), dim=1, largest=True)
# idx: [B, ~32] โ IQR ๊ธฐ๋ฐ ์ ํ๋ ์ธ๋ฑ์ค
# Step 3: Spatial Uniform Sampling
if if_adaptive:
step_length = int(1 / reduction_ratio) # ์: 1/0.05 = 20
# ๊ท ๋ฑ ๊ฐ๊ฒฉ์ผ๋ก ์ํ๋ง (step_length/3 ๊ฐ๊ฒฉ)
arithmetic_sequence = torch.arange(0, 575, int(step_length / 3))
# ์: step=20 โ step/3โ6 โ [0, 6, 12, 18, 24, ..., 570]
# ์ด๋ฏธ ์ ํ๋ ์ธ๋ฑ์ค ์ ์ธ (์ค๋ณต ์ ๊ฑฐ)
original_tensor_1d = idx.flatten()
filtered_sequence = [x for x in arithmetic_sequence if x not in original_tensor_1d]
# Step 4: Union (ํฉ์งํฉ)
concatenated_tensor = torch.cat((idx, filtered_sequence.unsqueeze(0)), dim=1)
idx = concatenated_tensor # ์ต์ข
์ธ๋ฑ์ค
ํจ๊ณผ:
- ๊ณต๊ฐ์ ์ผ๋ก underrepresented ์์ญ ๋ณด์
- ๋ comprehensiveํ visual representation
PruMerge vs PruMerge+ ๋น๊ต
| ํญ๋ชฉ | PruMerge | PruMerge+ |
|---|---|---|
| ์์ถ๋ฅ | ~14ร (5.5%) | ~4ร (25%) |
| ์ ํ ๋ฐฉ์ | IQR Outlier๋ง | Outlier + Spatial Uniform |
| ๊ณต๊ฐ ์ปค๋ฒ๋ฆฌ์ง | ํธ์ค ๊ฐ๋ฅ | ๊ท ๋ฑ ๋ณด์ฅ |
์ฑ๋ฅ ๋น๊ต (Vicuna-7B):
| Metric | LLaVA-1.5 | PruMerge | PruMerge+ |
|---|---|---|---|
| VQAv2 | 78.5 | 72.0 | 76.8 |
| ScienceQA | 66.8 | 68.5 | 68.3 |
| TextVQA | 58.2 | 56.0 | 57.1 |
| POPE | 85.9 | 76.3 | 84.0 |
| MME | 1510.7 | 1350.3 | 1462.4 |
| MMBench | 64.3 | 60.9 | 64.9 |
Trade-off:
- PruMerge: ์ต๋ ํจ์จ์ฑ (14ร ์์ถ), ์ฝ๊ฐ์ ์ฑ๋ฅ ์ ํ
- PruMerge+: ํจ์จ์ฑ + ์ฑ๋ฅ ๊ท ํ (4ร ์์ถ, ๊ฑฐ์ ์๋ณธ ์ฑ๋ฅ)
3.5 Algorithm Summary
Algorithm 1: Token PruMerge and PruMerge+
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
# Input: K, Q (penultimate layer), Y (output tokens), n (token count)
# Output: Y' (m tokens, m << n)
def token_prumerge(K, Q, Y, n):
# Step 1: Calculate class attention
a_cls = calculate_attention(Q_cls, K) # Eq 3.2
# Step 2: Adaptive token selection via IQR
indices = IQR_outlier_detection(a_cls) # Sec 3.2
m = len(indices)
selected_indices = indices
# Step 3 (Optional - PruMerge+): Spatial sampling
if PRUMERGE_PLUS:
r_o = m / n
spatial_indices = spatial_uniform_sample(r_o)
selected_indices = indices + spatial_indices
m = len(selected_indices)
# Step 4: Token merging via k-NN clustering
Y_prime = []
for p in selected_indices:
# Calculate similarity
similarities = cosine_similarity(
K[p],
K[others]
)
# Find k nearest neighbors
neighbor_indices = topk(similarities, k=32)
# Weighted averaging
weights = a_cls[neighbor_indices]
y_p_prime = weighted_average(
Y[neighbor_indices],
weights=weights
)
Y_prime.append(y_p_prime)
return Y_prime # m tokens
ํต์ฌ ๋จ๊ณ:
- AITS: IQR๋ก ์ค์ ํ ํฐ ์ ํ
- (Optional) Spatial sampling
- TS: k-NN clustering + weighted merging
-
Experiments
4.1 Main Results
์คํ ์ค์
Base Model: LLaVA-1.5 (7B, 13B)
- CLIP ViT-L/14 vision encoder
- Vicuna-7B / Vicuna-13B LLM
- 336ร336 resolution
- ์๋ณธ: 576 visual tokens
Training:
- LoRA fine-tuning (1 epoch)
- LLaVA-1.5 instruction data ์ฌ์ฉ
- Reduced visual tokens๋ก ํ์ต
Evaluation Benchmarks:
- VQAv2: Visual question answering
- ScienceQA (SQAI): Multimodal reasoning
- TextVQA (VQAT): OCR-based QA
- POPE: Hallucination evaluation
- MME: Perception & cognition
- MMBench (MMB): Comprehensive evaluation
์ฑ๋ฅ ๋น๊ต
Table 1: 6๊ฐ ๋ฒค์น๋งํฌ ๊ฒฐ๊ณผ
| Method | LLM | VQAv2 | SQAI | VQAT | POPE | MME | MMB |
|---|---|---|---|---|---|---|---|
| Existing Methods | ย | ย | ย | ย | ย | ย | ย |
| BLIP-2 | Vicuna-13B | 41.0 | 61.0 | 42.5 | 85.3 | 1293.8 | - |
| InstructBLIP | Vicuna-13B | - | 63.1 | 50.7 | 78.9 | 1212.8 | - |
| Qwen-VL-Chat | Qwen-7B | 78.2 | 68.2 | 61.5 | - | 1487.5 | 60.6 |
| LLaVA-1.5 Baselines | ย | ย | ย | ย | ย | ย | ย |
| LLaVA-1.5 | Vicuna-7B | 78.5 | 66.8 | 58.2 | 85.9 | 1510.7 | 64.3 |
| + PruMerge (5.5%) | Vicuna-7B | 72.0 | 68.5 | 56.0 | 76.3 | 1350.3 | 60.9 |
| + PruMerge+ (25%) | Vicuna-7B | 76.8 | 68.3 | 57.1 | 84.0 | 1462.4 | 64.9 |
| LLaVA-1.5 | Vicuna-13B | 80.0 | 71.6 | 61.3 | 85.9 | 1531.3 | 67.7 |
| + PruMerge (5.5%) | Vicuna-13B | 72.8 | 71.0 | 58.4 | 78.5 | 1428.2 | 62.3 |
| + PruMerge+ (25%) | Vicuna-13B | 77.8 | 71.0 | 58.6 | 84.4 | 1485.5 | 65.7 |
์ฃผ์ ๋ฐ๊ฒฌ:
-
PruMerge+ (25% tokens):
- VQAv2: 76.8 (์๋ณธ 78.5 ๋๋น -1.7)
- ScienceQA: 68.3 (์๋ณธ 66.8 ๋๋น +1.5)
- MME: 1462.4 (์๋ณธ 1510.7 ๋๋น -48.3)
- MMBench: 64.9 (์๋ณธ 64.3 ๋๋น +0.6)
- โ Comparable performance
-
PruMerge (5.5% tokens):
- ScienceQA: 68.5 (์๋ณธ ๋๋น +1.7)
- POPE: 76.3 (์๋ณธ 85.9 ๋๋น -9.6)
- โ ์ผ๋ถ ํ์คํฌ์์ ์ฑ๋ฅ ํฅ์!
-
vs. Previous Methods:
- BLIP-2, InstructBLIP ๋๋น ํจ์ฌ ์ฐ์
- Qwen-VL-Chat๊ณผ comparable
์ ์ผ๋ถ ํ์คํฌ์์ ์ฑ๋ฅ ํฅ์?
ScienceQA์์ ํฅ์ ์ด์ :
- ์ค์ํ ์๊ฐ ์ ๋ณด์ ์ง์ค
- Redundant tokens ์ ๊ฑฐ๋ก signal-to-noise ๋น์จ ํฅ์
- ์ถ๋ก ์ ํ์ํ ํต์ฌ features๋ง ์ ํ
POPE์์ PruMerge๊ฐ ์ฝํ ์ด์ :
- Object presence detection ํ์
- Spatial coverage ์ค์
- Aggressive pruning (5.5%)์ผ๋ก ์ผ๋ถ ๊ฐ์ฒด ์ ๋ณด ์์ค
- โ PruMerge+๊ฐ ์ด ๋ฌธ์ ํด๊ฒฐ (84.0)
4.2 Efficiency Analysis
Computational Cost (Table 2)
์คํ ํ๊ฒฝ: Tesla V100 GPU
๋ฐฉ๋ฒ๋ก : Roofline model ๊ธฐ๋ฐ theoretical analysis
LLaVA-1.5 (Vicuna-7B):
| Config | FLOPs (TB) | Prefill Time (ms) | Total Memory (GB) | Activation (GB) |
|---|---|---|---|---|
| FP16 | ย | ย | ย | ย |
| Original | 9.3 | 88.6 | 23.3 | 4.60 |
| + PruMerge | 0.91 | 15.3 | 13.7 | 0.28 |
| Speedup | 10.2ร | 5.8ร | 1.7ร | 16.4ร |
| INT4 | ย | ย | ย | ย |
| Original | 2.3 | 151.6 | 5.9 | 1.20 |
| + PruMerge | 0.28 | 14.9 | 3.5 | 0.07 |
| Speedup | 8.2ร | 10.2ร | 1.7ร | 17.1ร |
LLaVA-1.5 (Vicuna-13B):
| Config | FLOPs (TB) | Prefill Time (ms) | Total Memory (GB) | Activation (GB) |
|---|---|---|---|---|
| FP16 | ย | ย | ย | ย |
| Original | 18.2 | 170.5 | 41.6 | 7.30 |
| + PruMerge | 1.80 | 29.5 | 26.6 | 0.44 |
| Speedup | 10.1ร | 5.8ร | 1.6ร | 16.6ร |
| INT4 | ย | ย | ย | ย |
| Original | 4.6 | 294.9 | 10.5 | 1.80 |
| + PruMerge | 0.45 | 29.0 | 6.8 | 0.11 |
| Speedup | 10.2ร | 10.2ร | 1.5ร | 16.4ร |
ํต์ฌ ํจ์จ์ฑ ํฅ์:
-
FLOPs ๊ฐ์: ~10๋ฐฐ
- Quadratic complexity ํจ๊ณผ: O(nยฒ) โ O(mยฒ)
- 576ยฒ โ 40ยฒ โ 331,776 โ 1,600
-
Prefill Time: 5.8~10.2๋ฐฐ ๋นจ๋ผ์ง
- FP16: 88.6ms โ 15.3ms
- INT4: 151.6ms โ 14.9ms
- INT4 + PruMerge๊ฐ ๊ฐ์ฅ ๋น ๋ฆ!
-
Memory ์ ๊ฐ:
- Total: 1.5~1.7๋ฐฐ ๊ฐ์
- Activation: 16๋ฐฐ ์ด์ ๊ฐ์
-
Quantization๊ณผ์ ์๋์ง:
- INT4 quantization ์ ์ฉ ์ ๋ ๋น ๋ฅธ ์๋
- Orthogonal techniques๋ก ๊ฒฐํฉ ๊ฐ๋ฅ
Scenario Analysis
๊ฐ์ :
- Image: 336ร336 (576 visual tokens)
- Text prompt: 40 tokens
- PruMerge ์ ์ฉ ํ: 40 visual tokens
Token ์ ๋น๊ต:
1
2
3
4
Original: 576 (visual) + 40 (text) = 616 tokens
PruMerge: 40 (visual) + 40 (text) = 80 tokens
Reduction: 616 โ 80 (7.7ร fewer tokens)
Attention Computation:
1
2
3
4
Original: 616ยฒ = 379,456 operations
PruMerge: 80ยฒ = 6,400 operations
Speedup: 59.3ร in attention matrix computation
4.3 Generalization on Video-LLM
Video-LLaVA ํตํฉ
Video-LLaVA ํน์ฑ:
- 8 frames per video clip
- 16ร16 patches per frame
- 2048 visual tokens (8 ร 256)
- LLaVA-1.5 ๋๋น 4๋ฐฐ ๋ง์ tokens
PruMerge ์ ์ฉ (Training-free):
- Inference ์์๋ง ์ ์ฉ
- ์ถ๊ฐ ํ์ต ๋ถํ์
- ์ฆ์ ์ฌ์ฉ ๊ฐ๋ฅ
๊ฒฐ๊ณผ (Table 3)
Video QA Benchmarks:
| Method | LLM | MSVD-QA | ย | MSRVT-QA | ย | ActivityNet-QA | ย |
|---|---|---|---|---|---|---|---|
| ย | ย | Acc | Score | Acc | Score | Acc | Score |
| Baselines | ย | ย | ย | ย | ย | ย | ย |
| FrozenBiLM | 1B | 32.2 | - | 16.8 | - | 24.7 | - |
| VideoChat | 7B | 56.3 | 2.8 | 45.0 | 2.5 | - | 2.2 |
| LLaMA-Adapter | 7B | 54.9 | 3.1 | 43.8 | 2.7 | 34.2 | 2.7 |
| Video-LLaMA | 7B | 51.6 | 2.5 | 29.6 | 1.8 | 12.4 | 1.1 |
| Video-ChatGPT | 7B | 64.9 | 3.3 | 49.3 | 2.8 | 35.2 | 2.7 |
| Video-LLaVA | ย | ย | ย | ย | ย | ย | ย |
| Original | 7B | 70.7 | 3.9 | 59.2 | 3.5 | 45.3 | 3.3 |
| + PruMerge (12.5%) | 7B | 71.1 | 3.9 | 58.4 | 3.5 | 48.3 | 3.4 |
| + PruMerge+ (25%) | 7B | 71.1 | 3.9 | 59.3 | 3.6 | 47.7 | 3.4 |
๋๋ผ์ด ๋ฐ๊ฒฌ:
-
์ฑ๋ฅ ํฅ์:
- MSVD-QA: 70.7 โ 71.1 (+0.4)
- ActivityNet-QA: 45.3 โ 48.3 (+3.0)
- Token ๊ฐ์ํ๋๋ฐ ์ฑ๋ฅ ํฅ์!
-
ํ ํฐ ์์ถ:
- Original: 2048 tokens
- PruMerge: 256 tokens (12.5%)
- PruMerge+: 512 tokens (25%)
- 8๋ฐฐ / 4๋ฐฐ ์์ถ
-
Training-free:
- Video ๋ฐ์ดํฐ๋ก ์ฌํ์ต ๋ถํ์
- Inference ์์๋ง ์ ์ฉ
- ์ฆ์ ์ฌ์ฉ ๊ฐ๋ฅ
Insight:
- Video tokens์๋ significant redundancy ์กด์ฌ
- Temporal + spatial redundancy ๋ชจ๋ ํ์ฉ ๊ฐ๋ฅ
- Future direction: Temporal token reduction ํ๊ตฌ
4.4 Ablation Study
์ฉ์ด ์ ๋ฆฌ
PruMerge์ ๋ ๋ชจ๋
- AITS : Adaptive Important Token Selection
- IQR๋ก ์ค์ ํ ํฐ ์ ํ
- TS: Token Supplement
- KNN์ผ๋ก pruned ์ ๋ณด ๋ณํฉ
4.4.1 Token Sampling Strategy Analysis (Table 4)
๋น๊ต ์ ๋ต:
- LLaVA-PruMerge: IQR-based adaptive sampling
- Sequential: ์ฒ์ N๊ฐ ํ ํฐ ์ ํ
- Spatial: N๊ฐ ํ ํฐ์ ๊ณต๊ฐ์ ์ผ๋ก ๊ท ๋ฑ ๋ฐฐ์น
๊ฒฐ๊ณผ (๋์ผํ ํ ํฐ ์๋ก ๋น๊ต):
TextVQA (40 tokens):
- PruMerge: 54.00
- Sequential: 42.72
- Spatial 5ร8: 46.85
- Spatial 8ร5: 47.42
- โ PruMerge๊ฐ 11.3% ๋ ๋์
MME (40 tokens):
- PruMerge: 1250.07
- Sequential: 703.60
- Spatial 5ร8: 1180.23
- Spatial 8ร5: 1142.32
- โ PruMerge๊ฐ 77.7% ๋ ๋์
POPE (35 tokens):
- PruMerge: 76.2
- Sequential: 11.7 (!)
- Spatial 5ร7: 69.8
- Spatial 7ร5: 71.1
- Spatial 6ร6: 67.9
- โ PruMerge๊ฐ 6.5๋ฐฐ ๋์
ScienceQA (16 tokens):
- PruMerge: 68.07
- Sequential: 64.20
- Spatial 4ร4: 66.29
- โ PruMerge๊ฐ 3.87% ๋ ๋์
๋ถ์:
Sequential์ ๋ฌธ์ :
- ์ฒ์ N๊ฐ ํ ํฐ = ์ด๋ฏธ์ง ํน์ ์์ญ๋ง
- Spatial bias ์ฌ๊ฐ
- POPE์์ ๊ฑฐ์ random guess (11.7)
Spatial์ ์ฅ์ :
- ์ ์ฒด ์ด๋ฏธ์ง ์ปค๋ฒ๋ฆฌ์ง
- ๊ท ํ์กํ representation
- Sequential๋ณด๋ค ํจ์ฌ ์ฐ์
PruMerge์ ์ฐ์์ฑ:
- Attention-guided selection
- ์ ๋ณด ๋ฐ๋ ๋์ ์์ญ ์ง์ค
- Adaptive to image complexity
- ํนํ TextVQA (OCR)์์ ํฐ ์ฐจ์ด
- ํ ์คํธ ์์ญ์ ํ ํฐ ์ง์ค
- ์ธ๋ฐํ ์ ๋ณด ๋ณด์กด
4.4.2 Effectiveness of Each Module (Table 5)
์คํ ์ค์ :
- ๊ณ ์ : 40 tokens (6.9%)
- Vicuna-7B ๋ชจ๋ธ
- 4๊ฐ ๋ฒค์น๋งํฌ
Module ์กฐํฉ:
| Method | SQAI | VQAT | POPE | MME |
|---|---|---|---|---|
| LLaVA-1.5 (baseline) | 66.8 | 58.2 | 85.9 | 1510.7 |
| w. AITS only | 66.5 | 54.8 | 75.7 | 1221.6 |
| w. AITS & TS | 68.5 | 56.0 | 76.3 | 1350.3 |
๋ถ์:
AITS (Adaptive Important Token Selection) ๋จ๋ :
- SQA: 66.5 (baseline 66.8)
- TextVQA: 54.8 (baseline 58.2)
- POPE: 75.7 (baseline 85.9)
- MME: 1221.6 (baseline 1510.7)
- โ ํ ํฐ ์ ํ๋ง์ผ๋ก๋ ์ฑ๋ฅ ์ ํ
AITS + TS (Token Supplement):
- SQA: 68.5 (baseline ๋๋น +1.7)
- TextVQA: 56.0 (baseline ๋๋น -2.2)
- POPE: 76.3 (baseline ๋๋น -9.6)
- MME: 1350.3 (baseline ๋๋น -160.4)
- โ Token merging์ด ํ์์ !
TS์ ํจ๊ณผ:
- SQA: +2.0 (66.5 โ 68.5)
- TextVQA: +1.2 (54.8 โ 56.0)
- POPE: +0.6 (75.7 โ 76.3)
- MME: +128.7 (1221.6 โ 1350.3)
- โ ๋ชจ๋ ํ์คํฌ์์ ๊ฐ์
ํต์ฌ Insight:
- Token selection๋ง์ผ๋ก๋ ๋ถ์กฑ
- Merging์ด pruned tokens ์ ๋ณด ๋ณด์กด
- k-NN clustering + weighted averaging ํจ๊ณผ
4.4.3 Training Analysis (Table 6)
๋น๊ต:
- Training-free: PruMerge๋ง ์ ์ฉ, ํ์ต X
- LoRA fine-tuning: PruMerge + LoRA 1 epoch
๊ฒฐ๊ณผ (40 tokens, Vicuna-7B):
| Method | SQAI | VQAT | POPE | MME |
|---|---|---|---|---|
| LLaVA-1.5 (baseline) | 66.8 | 58.2 | 85.9 | 1510.7 |
| w.o. LoRA-FT | 68.0 | 54.0 | 76.2 | 1250.1 |
| w. LoRA-FT | 68.5 | 56.0 | 76.3 | 1350.3 |
๋ถ์:
Training-free ์ฑ๋ฅ:
- SQA: 68.0 (baseline ๋๋น +1.2)
- TextVQA: 54.0 (baseline ๋๋น -4.2)
- POPE: 76.2 (baseline ๋๋น -9.7)
- MME: 1250.1 (baseline ๋๋น -260.6)
- โ ์ผ๋ถ ํ์คํฌ๋ ์ฆ์ ์ฌ์ฉ ๊ฐ๋ฅ
Fine-tuning ํจ๊ณผ:
- SQA: +0.5 (68.0 โ 68.5)
- TextVQA: +2.0 (54.0 โ 56.0)
- POPE: +0.1 (76.2 โ 76.3)
- MME: +100.2 (1250.1 โ 1350.3)
- โ ๋ชจ๋ ํ์คํฌ์์ ๊ฐ์
Trade-off:
- Training-free: ๋น ๋ฅธ ์ ์ฉ, ์ผ๋ถ ์ฑ๋ฅ ์ ํ
- Fine-tuning: ์ต๊ณ ์ฑ๋ฅ, ์ถ๊ฐ ํ์ต ํ์ (1 epoch)
์ค์ฉ์ ์ ํ:
- Resource ์ถฉ๋ถ: Fine-tuning ๊ถ์ฅ
- ๋น ๋ฅธ ์ ์ฉ ํ์: Training-free๋ก ์์
-
์์ฝ
5.1 Adaptive Token Selection
ํต์ฌ ํ์ :
- IQR-based outlier detection: ํต๊ณ์ ์ผ๋ก ๊ฒ์ฆ๋ ๋ฐฉ๋ฒ
- Image-specific adaptation: ์ด๋ฏธ์ง๋ง๋ค ๋ค๋ฅธ ์์ ํ ํฐ
- Learned importance: ๋ชจ๋ธ์ด ํ์ตํ attention pattern ํ์ฉ
์ฅ์ :
- Manual threshold ๋ถํ์
- Robust to different image types
- Computation-efficient (๋จ์ ํต๊ณ ๊ณ์ฐ)
5.2 Token Merging via k-NN
ํต์ฌ ํ์ :
- Information preservation: Pruned tokens ์ ๋ณด ๋ณด์กด
- Similarity-based clustering: Semantic ์ ์ฌ๋ ๊ธฐ๋ฐ
- Weighted aggregation: Attention์ผ๋ก ๊ฐ์ค
์ฅ์ :
- Lossless์ ๊ฐ๊น์ด ์์ถ
- Semantic consistency ์ ์ง
- Large objects ์ ๋ณด ๋ณด์กด
5.3 PruMerge+ Hybrid Strategy
ํต์ฌ ํ์ :
- Attention + Spatial: ๋ ๊ฐ์ง ์์น ๊ฒฐํฉ
- Balanced coverage: ์ ์ฒด ์ด๋ฏธ์ง ์ปค๋ฒ๋ฆฌ์ง
- Performance-efficiency trade-off: ์ ํ ๊ฐ๋ฅ
์ฅ์ :
- Minimal performance drop
- Spatial bias ๋ฐฉ์ง
- Flexible deployment
5.4 Plug-and-Play Design
ํต์ฌ ํ์ :
- Vision encoder level: ์ํคํ ์ฒ ๋ ๋ฆฝ์
- Training-free option: ์ฆ์ ์ฌ์ฉ ๊ฐ๋ฅ
- Modular implementation: ์ฌ์ด ํตํฉ
์ฅ์ :
- LLaVA-1.5, Video-LLaVA ๋ฑ ์ฆ์ ์ ์ฉ
- Minimal code changes
- Research-friendly
-
Limitations ๋ฐ ํฅํ ๋ฐฉํฅ
ํ์ฌ ํ๊ณ (๋ ผ๋ฌธ ๊ธฐ์ค)
1. Not Entirely Lossless
- Visual token compression์ด ์์ ํ losslessํ์ง ์์
- ์๋ณธ LLaVA ๋๋น marginal performance gap ์กด์ฌ
- PruMerge+ (25%)๋ก ๋๋ถ๋ถ ํด๊ฒฐ๋๋ ์์ ํ์ง ์์
2. Large-Scale Model ๊ฒ์ฆ ๋ถ์กฑ
- Academic setting์ computational resources ํ๊ณ
- LLaVA-Next with Yi-34B ๋ฑ ๋๊ท๋ชจ ๋ชจ๋ธ์ ๋ํ ๊ฒ์ฆ ๋ฏธ์๋ฃ
ํฅํ ์ฐ๊ตฌ ๋ฐฉํฅ (๋ ผ๋ฌธ ๊ธฐ์ค)
1. Fully Lossless Compression
- ์์ ๋ฌด์์ค ํ ํฐ ์์ถ ์๊ณ ๋ฆฌ์ฆ ๊ฐ๋ฐ
- Performance gap ์์ ์ ๊ฑฐ ๋ชฉํ
2. Larger-Scale Models ํ์ฅ
- LLaVA-Next with Yi-34B backbone ๋ฑ ๋๊ท๋ชจ ๋ชจ๋ธ ์ ์ฉ
- Generalization ๋ฐ broader impact ๊ฒ์ฆ
-
Conclusion
LLaVA-PruMerge๋ Large Multimodal Models์ ํจ์จ์ฑ์ ํ๊ธฐ์ ์ผ๋ก ๊ฐ์ :
ํต์ฌ ๊ธฐ์ฌ:
- Adaptive token selection: IQR-based outlier detection
- Information-preserving merging: k-NN clustering + weighted averaging
- PruMerge+: Attention + spatial hybrid strategy
- 14๋ฐฐ / 4๋ฐฐ ์์ถ: ์ฑ๋ฅ ์ ์งํ๋ฉด์ ๋ํญ ์์ถ
์์:
- Visual token ์ ๊ด์ ์ ์ต์ด ํจ์จํ ์ฐ๊ตฌ
- Plug-and-play ๋ฐฉ์์ผ๋ก ์ฆ์ ์ ์ฉ ๊ฐ๋ฅ
- Training-free option์ผ๋ก ๋น ๋ฅธ ๋ฐฐํฌ
- Video-LLM์๋ ์ฆ์ ์ ์ฉ ๊ฐ๋ฅ
์ค์ฉ์ฑ:
- 10๋ฐฐ FLOPs ๊ฐ์
- 5.8~10.2๋ฐฐ ๋น ๋ฅธ prefill
- 50% ๋ฉ๋ชจ๋ฆฌ ์ ๊ฐ
- Quantization๊ณผ orthogonal (๊ฒฐํฉ ๊ฐ๋ฅ)
LLaVA-PruMerge๋ ํจ์จ์ฑ๊ณผ ์ฑ๋ฅ์ ๊ท ํ์ ์ด๋ฃจ๋ฉฐ, LMM์ ์ค์ฉ์ ๋ฐฐํฌ๋ฅผ ์ํ ์ค์ํ ๋จ๊ณ์ ๋๋ค.