DivPrune: 다양성 기반 Visual Token Pruning for Large Multimodal Models
DivPrune: Diversity-based Visual Token Pruning for Large Multimodal Models
TL;DR Highlight
이미지/비디오 멀티모달 모델의 visual token을 최대 90% 제거해도 성능을 거의 유지하는 plug-and-play 추론 최적화 기법
Who Should Read
LLaVA 같은 멀티모달 LLM을 서빙하면서 추론 속도와 GPU 메모리 비용을 줄이고 싶은 ML 엔지니어. 특히 이미지·비디오 처리 파이프라인에서 latency가 병목인 팀.
Core Mechanics
- 기존 attention score 기반 pruning(FastV, VTW)은 비슷한 토큰끼리 몰려서 선택하는 문제가 있음 → 고압축 시 성능 급락
- DivPrune은 MMDP(Max-Min Diversity Problem: 선택된 토큰 간 최소 거리를 최대화하는 조합 최적화 문제)로 pruning을 재정의해서 다양한 시각 정보를 고르게 보존
- Fine-tuning도, calibration 데이터도 필요 없는 plug-and-play — 기존 모델에 그냥 끼워 쓸 수 있음
- LLaVA 1.5-7B 기준 TFLOP 84% 절감 시 POPE F1이 오히려 원본보다 0.18% 향상, COCO CIDEr 손실은 12.7%에 불과 (FastV/VTW는 동일 조건에서 95% 손실)
- 이미지뿐 아니라 비디오(LLaVA-NeXT-Video-7B)에도 동일하게 적용되며, visual context가 클수록 효과가 더 좋음
- KV 캐싱 등 기존 추론 최적화 기법과 호환 가능
Evidence
- LLaVA 1.6-7B에서 TFLOP 89% 절감 시 POPE F1 기준 FastV/VTW -79% vs DivPrune -3.4%
- LLaVA-NeXT-Video-7B에서 ActivityNet accuracy FastV 33.91% vs DivPrune 45.90% (원본 48.10%)
- 비디오 모델 E2E latency 22% 단축, prefill time 55% 단축, GPU 메모리 400MB 절감
- LLaVA 1.5-13B POPE F1 기준 VTW 대비 83%, FastV 대비 53.4%, PruMerge 대비 15.2% 향상
How to Apply
- LLaVA 계열 모델 서빙 시 visual token을 LLM 첫 번째 레이어에 넘기기 전에 DivPrune으로 90% 제거 → TFLOP 84% 절감. GitHub 코드로 바로 적용 가능
- Attention score 기반 FastV를 쓰고 있다면 DivPrune으로 교체 검토. 특히 압축률이 높을수록(토큰을 많이 잘라낼수록) 성능 차이가 극적으로 벌어짐
- 비디오 입력처럼 frame당 토큰 수가 많은 경우(1152개 이상) 효과가 더 크므로, 비디오 이해 서비스에서 우선 도입 검토
Code Example
import torch
import torch.nn.functional as F
def divprune(visual_tokens: torch.Tensor, keep_ratio: float = 0.1) -> torch.Tensor:
"""
DivPrune: Max-Min Diversity 기반 visual token 선택
Args:
visual_tokens: [M, D] shape의 visual token 텐서
keep_ratio: 유지할 토큰 비율 (0.1 = 10% 유지)
Returns:
선택된 토큰 [M_kept, D]
"""
M, D = visual_tokens.shape
M_keep = max(1, int(M * keep_ratio))
# 코사인 거리 행렬 사전 계산 (1 - cosine_similarity)
normed = F.normalize(visual_tokens, dim=-1)
sim_matrix = normed @ normed.T # [M, M]
dist_matrix = 1.0 - sim_matrix # cosine distance
dist_matrix.fill_diagonal_(float('inf')) # 자기 자신 제외
selected_idx = []
remaining = list(range(M))
# Stage 1: 첫 번째 토큰 선택 (다른 모든 토큰과의 최소 거리가 가장 큰 것)
min_dists = dist_matrix[remaining][:, remaining].min(dim=1).values
first = remaining[min_dists.argmax().item()]
selected_idx.append(first)
remaining.remove(first)
# Stage 2: 선택된 집합과의 최소 거리가 가장 큰 토큰을 반복 추가
while len(selected_idx) < M_keep:
sel_tensor = torch.tensor(selected_idx)
rem_tensor = torch.tensor(remaining)
# 각 remaining 토큰 → 선택된 토큰들과의 최소 거리
dists_to_sel = dist_matrix[rem_tensor][:, sel_tensor].min(dim=1).values
next_idx = remaining[dists_to_sel.argmax().item()]
selected_idx.append(next_idx)
remaining.remove(next_idx)
return visual_tokens[torch.tensor(selected_idx)]
# 사용 예시 (LLaVA-style)
# visual_tokens: vision encoder 출력 [576, 4096]
# pruned = divprune(visual_tokens, keep_ratio=0.10) # 576 -> ~58개 토큰
# llm_input = torch.cat([text_tokens, pruned], dim=0)Terminology
Related Resources
Original Abstract (Expand)
Large Multimodal Models (LMMs) have emerged as powerful models capable of understanding various data modalities, including text, images, and videos. LMMs encode both text and visual data into tokens that are then combined and processed by an integrated Large Language Model (LLM). Including visual tokens substantially increases the total token count, often by thousands. The increased input length for LLM significantly raises the complexity of inference, resulting in high latency in LMMs. To address this issue, token pruning methods, which remove part of the visual tokens, are proposed. The existing token pruning methods either require extensive calibration and fine-tuning or rely on suboptimal importance metrics which results in increased redundancy among the retained tokens. In this paper, we first formulate token pruning as Max-Min Diversity Problem (MMDP) where the goal is to select a subset such that the diversity among the selected tokens is maximized. Then, we solve the MMDP to obtain the selected subset and prune the rest. The proposed method, DivPrune, reduces redundancy and achieves the highest diversity of the selected tokens. By ensuring high diversity, the selected tokens better represent the original tokens, enabling effective performance even at high pruning ratios without requiring fine-tuning. Extensive experiments with various LMMs show that DivPrune achieves state-of-the-art accuracy over 16 image- and video-language datasets. Additionally, DivPrune reduces both the end-to-end latency and GPU memory usage for the tested models. The code is available here⋄.