FastAV: Efficient Token Pruning for Audio-Visual Large Language Model Inference
TL;DR Highlight
A training-free token pruning framework for audio-visual multimodal LLMs that cuts inference computation by more than 40% while maintaining, and in some cases improving, task performance
Who Should Read
ML engineers deploying audio-visual LLMs like VideoLLaMA2 or video-SALMONN2 in production who want to reduce inference cost and latency. Also useful for infra/serving developers optimizing GPU memory and speed for multimodal models.
Core Mechanics
- In both VideoLLaMA2 and video-SALMONN2, after the middle layer (~14th), attention concentrates on earlier tokens — later tokens become effectively unnecessary
- 2-stage pruning strategy: attention rollout-based global pruning at the middle layer → fine pruning based on last query token at subsequent layers
- Compatible with FlashAttention — doesn't need the full attention map, uses only the last query
- Audio tokens reduced from 1,496 to 10 (99%+ reduction) while maintaining performance (AV matching actually +11%)
- Raw attention weights don't reveal important token patterns, but attention rollout (cumulative aggregation across layers) makes them clearly visible
- Inference-time only: requires no additional training or fine-tuning
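To make the rollout bullet concrete, here is a stdlib-only toy sketch of cumulative aggregation across layers. The matrices, sizes, and alpha are made up for illustration; they are not taken from either model:

```python
# Toy attention rollout: mix each layer's attention with the identity
# (residual path), then multiply per-layer matrices to get end-to-end
# token influence. All values below are illustrative.

def matmul(a, b):
    n = len(a)
    return [[sum(a[i][k] * b[k][j] for k in range(n)) for j in range(n)]
            for i in range(n)]

def rollout(layers, alpha=0.5):
    """Accumulate attention across layers (attention rollout)."""
    n = len(layers[0])
    result = None
    for attn in layers:
        mixed = [[alpha * attn[i][j] + (1 - alpha) * (1.0 if i == j else 0.0)
                  for j in range(n)] for i in range(n)]
        result = mixed if result is None else matmul(mixed, result)
    return result

# Two toy layers in which token 0 is attended heavily. Per-token importance
# is the column sum of the rollout matrix (total incoming attention).
layer1 = [[0.8, 0.1, 0.1], [0.7, 0.2, 0.1], [0.6, 0.2, 0.2]]
layer2 = [[0.9, 0.05, 0.05], [0.8, 0.1, 0.1], [0.7, 0.2, 0.1]]
R = rollout([layer1, layer2])
importance = [sum(R[i][j] for i in range(3)) for j in range(3)]
```

In this toy case token 0 ends up with by far the largest accumulated importance, which is the effect the bullet describes: rollout makes important tokens stand out even when single-layer attention weights do not.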
Evidence
- VideoLLaMA2: relative FLOPs 100→56 (44% reduction), inference latency 0.43 s→0.32 s, memory 22 GB→19 GB, AV matching 57.8→69.0 (+11.2 points)
- video-SALMONN2: relative FLOPs 100→58 (42% reduction), latency 0.44 s→0.29 s, memory 28 GB→21 GB, AVQA 57.6→58.4
- Global pruning comparison at equal FLOPs: random (69.0%), low-attention (70.5%), FastAV rollout-based (74.5%), the best of the three
- Fine pruning ratio P=20%: relative FLOPs 56, average accuracy 74.9% (the peak, higher even than P=30%)
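The ~44% figure can be sanity-checked with a back-of-envelope model that assumes per-layer cost is roughly linear in the number of surviving tokens. The layer count (28), input token count (2,000), and cumulative per-layer fine-pruning schedule below are my assumptions for illustration, not the paper's exact accounting:

```python
# Back-of-envelope FLOPs estimate for two-stage pruning. Assumes per-layer
# cost is linear in surviving tokens (ignores the quadratic attention term).
# n_layers, n_tokens, and the schedule are illustrative assumptions.

def relative_flops(n_layers, n_tokens, global_keep, fine_ratio):
    """Fraction of dense-model FLOPs after two-stage pruning."""
    mid = n_layers // 2
    cost, kept = 0.0, n_tokens
    for layer in range(n_layers):
        if layer == mid:
            kept = global_keep                    # stage 1: global pruning
        elif layer > mid:
            kept = int(kept * (1 - fine_ratio))   # stage 2: fine pruning
        cost += kept
    return cost / (n_layers * n_tokens)

# 28 layers, 2,000 tokens, keep 750 at the middle layer, then prune 20%
# per later layer: lands near the reported relative FLOPs of ~56.
ratio = relative_flops(n_layers=28, n_tokens=2000, global_keep=750, fine_ratio=0.2)
```

Under these assumptions the estimate comes out a little above 0.56, in the same ballpark as the reported numbers, which suggests most of the savings come from halving the token count for the back half of the network.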
How to Apply
- When serving AV-LLMs such as VideoLLaMA2, compute attention rollout at the middle layer (half of the total layer count) and bulk-remove low-importance tokens after position 750 for immediate memory and compute savings
- In subsequent layers, remove the bottom 20% by attention score of the last query token per layer — no full attention map needed, directly compatible with FlashAttention environments
- For models with many audio tokens (1,000+), consider aggressive pruning to just 10-20 — per the paper, no performance loss or even improvement
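The per-layer fine-pruning rule from the second bullet can be sketched without a framework. The query/key values here are toys; this is a minimal illustration of scoring keys against the last query and keeping the top (1 − P) fraction, not the models' real tensors:

```python
# Stdlib sketch of fine pruning: softmax-score all keys against the last
# query token, then keep the top (1 - prune_ratio) fraction. Toy values.
import math

def fine_prune_indices(query, keys, prune_ratio=0.2):
    dim = len(query)
    scale = math.sqrt(dim)
    logits = [sum(q * k for q, k in zip(query, key)) / scale for key in keys]
    m = max(logits)                       # subtract max for stable softmax
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    scores = [e / total for e in exps]
    keep = int(len(keys) * (1 - prune_ratio))
    order = sorted(range(len(keys)), key=lambda i: scores[i], reverse=True)
    return sorted(order[:keep])           # indices of tokens to keep

query = [1.0, 0.0]
keys = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [-1.0, 0.0], [0.5, 0.5]]
kept = fine_prune_indices(query, keys, prune_ratio=0.2)  # drops 1 of 5 keys
```

Because only the last query row is needed, this score can be computed alongside FlashAttention without ever materializing the full attention map, which is the compatibility point the bullet makes.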
Code Example
```python
# FastAV core logic pseudocode (PyTorch-based)
import torch

def attention_rollout(attention_matrices, alpha=0.5):
    """
    attention_matrices: list of [batch, heads, seq, seq] tensors (one per layer)
    """
    rollout = None
    for attn in attention_matrices:
        # Average over heads
        attn_mean = attn.mean(dim=1)  # [batch, seq, seq]
        # Mix in the identity to reflect the residual connection
        identity = torch.eye(attn_mean.size(-1), device=attn.device).unsqueeze(0)
        attn_mod = alpha * attn_mean + (1 - alpha) * identity
        rollout = attn_mod if rollout is None else torch.bmm(attn_mod, rollout)
    return rollout  # [batch, seq, seq]

def global_pruning(token_indices, rollout, keep_position=750):
    """Keep tokens before position 750; among later tokens, keep only those
    whose rollout importance exceeds the per-batch median.
    token_indices: [batch, seq] position indices; rollout: [batch, seq, seq]
    """
    importance = rollout.sum(dim=1)  # [batch, seq] column sums = incoming attention
    median = importance.median(dim=-1, keepdim=True).values  # per batch
    mask = (token_indices < keep_position) | (importance > median)
    return mask  # [batch, seq] boolean keep-mask

def fine_pruning(query_last, keys, prune_ratio=0.2):
    """
    query_last: [batch, heads, 1, dim] -- last query token
    keys:       [batch, heads, seq, dim]
    """
    scores = torch.softmax(
        torch.matmul(query_last, keys.transpose(-2, -1)) / (keys.size(-1) ** 0.5),
        dim=-1,
    ).mean(dim=1).squeeze(1)  # [batch, seq]
    k = int(scores.size(-1) * (1 - prune_ratio))
    _, keep_indices = scores.topk(k, dim=-1)
    return keep_indices
```
Original Abstract
In this work, we present FastAV, the first token pruning framework tailored for audio-visual large language models (AV-LLMs). While token pruning has been actively explored in standard large language models (LLMs) and vision-language models (LVLMs), its application to AV-LLMs has received little attention, even though multimodal integration substantially increases their token demands. To address this gap, we introduce a pruning strategy that utilizes attention weights to identify tokens emphasized at different stages and estimates their importance. Building on this analysis, FastAV applies a two-stage pruning strategy: (1) global pruning in intermediate layers to remove broadly less influential tokens, and (2) fine pruning in later layers considering the impact on next token generation. Notably, our method does not rely on full attention maps, which makes it fully compatible with efficient attention mechanisms such as FlashAttention. Extensive experiments demonstrate that FastAV reduces FLOPs by more than 40% on two representative AV-LLMs, while preserving or even improving model performance.