AdaFuse: Test-Time Scaling을 활용한 LLM 적응형 Ensemble Decoding
AdaFuse: Adaptive Ensemble Decoding with Test-Time Scaling for LLMs
TL;DR Highlight
여러 오픈소스 LLM을 추론 중에 동적으로 합쳐서 재학습 없이 단일 모델보다 평균 6.88% 더 좋은 성능을 내는 inference-time 앙상블 기법
Who Should Read
오픈소스 LLM 여러 개를 조합해 단일 모델 한계를 넘고 싶은 ML 엔지니어. 특히 fine-tuning 예산 없이 QA·번역·수학 추론 파이프라인 성능을 올리려는 백엔드/AI 개발자.
Core Mechanics
- 기존 앙상블(토큰/스팬/샘플 단위)은 융합 단위가 고정돼 있어 생성 도중 전략을 바꿀 수 없음 — AdaFuse는 단어(word) 단위로 앙상블하면서 매 스텝마다 '얼마나 묶을지'를 동적으로 결정
- 모델이 확신할 때(top-1과 top-2 확률 차 ≥ 0.7)는 최대 3단어까지 그냥 이어서 생성하고, 불확실할 때만 앙상블 스코어링 실행 → 불필요한 연산 최소화
- 불확실한 상황에서는 Diversity-Aware Scaling이 발동: 첫 토큰 후보 B개를 각각 greedy로 완성해 다양한 단어 후보를 만들고 NLL(모델이 그 단어를 얼마나 틀리다고 느끼는지) 평균으로 최적 선택
- LLaMA-3.1-8B-Instruct + Mistral-7B-Instruct-v0.3 조합을 기본 페어로 쓰고, Qwen3-8B·InternLM3-8B-Instruct까지 포함해 6개 벤치마크에서 모든 기존 앙상블 방식 압도
- 단어 수/라운드 분포가 태스크마다 다름: NQ는 80%가 1단어, De→En 번역은 46%가 3단어 — 고정 길이 전략이 왜 열위인지 보여주는 근거
- closed API(GPT-4 등)는 token-level probability를 주지 않으므로 오픈소스 모델에만 적용 가능 — 한계로 명시
Evidence
- 6개 벤치마크 평균 63.23점 vs 기존 최강 SweetSpan 59.16점 → 상대적 6.88% 향상
- NQ +10.01%, GSM8K +17.03%, SQuAD +4.12%, Flores En→De +6.04% 개선 (모두 기존 앙상블 최강 대비)
- GSM8K에서 가장 강한 두 모델(Qwen3-8B + LLaMA-3.1-8B) 조합 시 90.25 정확도 — 기존 최강 앙상블 대비 +11.52pt
- branching factor B=1→5로 늘릴수록 NQ·TriviaQA·GSM8K 전반에서 정확도 우상향 확인; beam-search 기반 스케일링은 B 증가 시 오히려 성능 하락
How to Apply
- HuggingFace에서 서로 강점이 다른 오픈소스 모델 2개 로드 후 GitHub(https://github.com/CCM0111/AdaFuse) 코드의 AdaFuse 디코딩 루프로 교체 — confidence threshold τ=0.7, 최대 단어 수 M=3이 기본값
- 같은 계열 모델(예: LLaMA-8B + LLaMA-70B)보다 아키텍처가 다른 모델(LLaMA + Mistral)을 조합해야 상호 보완 효과가 큼 — 개별 모델 점수 격차가 너무 크면(GSM8K의 Mistral 사례) 앙상블 이득 감소
- 불확실한 구간에서만 다양성 탐색(Diversity-Aware Scaling)을 켜면 추가 정확도 확보 가능 — branching factor B=3이 품질/속도 균형점; GPU 여유 있으면 B=5까지 시도
Code Example
# AdaFuse 핵심 디코딩 루프 (Algorithm 1 기반 의사코드)
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
def adafuse_decode(prompt, models, tokenizers, tau=0.7, M=3, B=3):
"""
models: list of AutoModelForCausalLM (오픈소스, token prob 접근 필수)
tau: 확신도 threshold (top1 - top2 확률 차)
M: 라운드당 최대 단어 수
B: 불확실할 때 탐색할 후보 수
"""
prefix_ids = tokenizers[0](prompt, return_tensors='pt').input_ids
while not is_terminated(prefix_ids):
candidate_pool = []
for model, tokenizer in zip(models, tokenizers):
span_ids = []
word_count = 0
while word_count < M:
# 다음 토큰 분포 계산
with torch.no_grad():
logits = model(torch.cat([prefix_ids, torch.tensor([span_ids])], dim=1)).logits[0, -1]
probs = torch.softmax(logits, dim=-1)
top2 = torch.topk(probs, 2)
p1, p2 = top2.values[0].item(), top2.values[1].item()
confidence_margin = p1 - p2
if confidence_margin >= tau:
# 확신: greedy로 단어 완성 후 계속
word_ids = gen_word_greedy(model, prefix_ids, span_ids)
span_ids += word_ids
word_count += 1
else:
# 불확실: top-B 초기 토큰으로 다양한 후보 생성
top_b_tokens = torch.topk(probs, B).indices
for start_token in top_b_tokens:
word_ids = gen_word_greedy(model, prefix_ids, span_ids, force_first=start_token)
candidate_pool.append(span_ids + word_ids)
break
else:
candidate_pool.append(span_ids) # 확신 상태로 M단어 완성
# 모든 모델에서 NLL 평균으로 최적 span 선택
best_span = min(
candidate_pool,
key=lambda span: avg_nll_across_models(span, models, prefix_ids)
)
prefix_ids = torch.cat([prefix_ids, torch.tensor([best_span])], dim=1)
return tokenizers[0].decode(prefix_ids[0])
def avg_nll_across_models(span_ids, models, prefix_ids):
"""수식 (4),(5): 각 모델의 NLL 평균"""
nlls = []
for model in models:
log_probs = get_log_probs(model, prefix_ids, span_ids) # token-level
nlls.append(-log_probs.mean().item())
return sum(nlls) / len(nlls)Terminology
Related Resources
Original Abstract (Expand)
Large language models (LLMs) exhibit complementary strengths arising from differences in pretraining data, model architectures, and decoding behaviors. Inference-time ensembling provides a practical way to combine these capabilities without retraining. However, existing ensemble approaches suffer from fundamental limitations. Most rely on fixed fusion granularity, which lacks the flexibility required for mid-generation adaptation and fails to adapt to different generation characteristics across tasks. To address these challenges, we propose AdaFuse, an adaptive ensemble decoding framework that dynamically selects semantically appropriate fusion units during generation. Rather than committing to a fixed granularity, AdaFuse adjusts fusion behavior on the fly based on the decoding context, with words serving as basic building blocks for alignment. To be specific, we introduce an uncertainty-based criterion to decide whether to apply ensembling at each decoding step. Under confident decoding states, the model continues generation directly. In less certain states, AdaFuse invokes a diversity-aware scaling strategy to explore alternative candidate continuations and inform ensemble decisions. This design establishes a synergistic interaction between adaptive ensembling and test-time scaling, where ensemble decisions guide targeted exploration, and the resulting diversity in turn strengthens ensemble quality. Experiments on open-domain question answering, arithmetic reasoning, and machine translation demonstrate that AdaFuse consistently outperforms strong ensemble baselines, achieving an average relative improvement of 6.88%. The code is available at https://github.com/CCM0111/AdaFuse.