AdaFuse: Adaptive Ensemble Decoding with Test-Time Scaling for LLMs
TL;DR Highlight
An inference-time ensemble technique that dynamically merges multiple open-source LLMs during decoding—without retraining—achieving an average 6.88% relative improvement over strong existing ensemble baselines
Who Should Read
ML engineers looking to combine multiple open-source LLMs to surpass single-model limitations. Especially backend/AI developers who want to boost QA, translation, or math reasoning pipeline performance without a fine-tuning budget.
Core Mechanics
- Existing ensemble methods (token/span/sample-level) have fixed fusion granularity and cannot change strategy mid-generation — AdaFuse ensembles at the word level and dynamically decides how many words to fuse at every round
- When the model is confident (top-1 minus top-2 probability gap ≥ 0.7), it generates up to 3 words greedily without ensemble scoring; ensemble scoring only triggers under uncertainty → minimizing unnecessary computation
- Under uncertainty, Diversity-Aware Scaling kicks in: B initial token candidates are each greedily completed to form diverse word candidates, and the optimal word is selected by averaging NLL (how 'surprising' the model finds each word) across models
- LLaMA-3.1-8B-Instruct + Mistral-7B-Instruct-v0.3 is used as the default pair, with Qwen3-8B and InternLM3-8B-Instruct also included, outperforming all existing ensemble methods across 6 benchmarks
- Word count/round distributions vary by task: NQ has 80% single-word rounds, De→En translation has 46% three-word rounds — evidence for why fixed-length strategies are inferior
- Closed APIs (e.g., GPT-4) do not expose token-level probabilities, so AdaFuse is only applicable to open-source models — explicitly noted as a limitation
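The confidence gate in the second bullet reduces to a top-1 minus top-2 probability check. A minimal sketch in PyTorch — the function name and the dummy logits are illustrative, not from the paper's code:

```python
import torch

def is_confident(logits: torch.Tensor, tau: float = 0.7) -> bool:
    """True when the top-1/top-2 probability gap meets the threshold tau."""
    probs = torch.softmax(logits, dim=-1)
    top2 = torch.topk(probs, 2).values
    return (top2[0] - top2[1]).item() >= tau

# A peaked distribution passes the gate (greedy generation continues);
# a near-flat one fails it and would trigger ensemble scoring.
peaked = torch.tensor([8.0, 1.0, 0.5])
flat = torch.tensor([1.0, 0.9, 0.8])
print(is_confident(peaked), is_confident(flat))  # → True False
```

On the peaked logits the top-1 probability is ≈0.998 so the margin far exceeds τ=0.7; on the flat logits the margin is ≈0.035, which is exactly the "uncertain" case where AdaFuse spends its ensemble budget.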
Evidence
- Average of 63.23 across 6 benchmarks vs. previous best SweetSpan at 59.16 → relative improvement of 6.88%
- NQ +10.01%, GSM8K +17.03%, SQuAD +4.12%, Flores En→De +6.04% improvement (all compared to the strongest existing ensemble baseline)
- Combining the two strongest models on GSM8K (Qwen3-8B + LLaMA-3.1-8B) achieves 90.25 accuracy — +11.52pt over the strongest existing ensemble
- Increasing branching factor B from 1 to 5 consistently improves accuracy on NQ, TriviaQA, and GSM8K; beam-search-based scaling degrades performance as B increases
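The headline 6.88% figure follows directly from the two benchmark averages in the first bullet; a quick check of the relative-improvement arithmetic:

```python
# Averages over the 6 benchmarks reported in the paper
adafuse_avg, sweetspan_avg = 63.23, 59.16
relative_gain = (adafuse_avg - sweetspan_avg) / sweetspan_avg * 100
print(f"{relative_gain:.2f}%")  # → 6.88%
```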
How to Apply
- Load 2 open-source models with complementary strengths from HuggingFace, then replace the decoding loop with AdaFuse from the GitHub repo (https://github.com/CCM0111/AdaFuse) — default values are confidence threshold τ=0.7 and max words per round M=3
- Models from different architectures (e.g., LLaMA + Mistral) yield greater complementary gains than same-family models (e.g., LLaMA-8B + LLaMA-70B) — if the performance gap between individual models is too large (e.g., Mistral on GSM8K), ensemble gains diminish
- Enabling Diversity-Aware Scaling only during uncertain spans provides additional accuracy gains — branching factor B=3 is the quality/speed sweet spot; try B=5 if GPU headroom allows
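Underlying all of these settings, the word-selection rule is simply an argmin over model-averaged NLL. A toy illustration — the candidate words and per-model NLL values below are invented for the example:

```python
# Hypothetical NLL of three candidate words under each of two ensemble models
candidate_nlls = {
    "Paris":  [0.4, 0.6],   # [model A, model B]
    "London": [0.9, 0.5],
    "Berlin": [1.2, 1.1],
}

# Average NLL across models, then pick the least "surprising" word
avg = {word: sum(v) / len(v) for word, v in candidate_nlls.items()}
best = min(avg, key=avg.get)
print(best)  # → Paris
```

A word that only one model favors (like "London" here) loses to a word both models find plausible, which is exactly the complementary-strengths effect the ensemble exploits.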
Code Example
# AdaFuse core decoding loop (pseudocode based on Algorithm 1;
# is_terminated, gen_word_greedy, and get_log_probs are left abstract)
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def adafuse_decode(prompt, models, tokenizers, tau=0.7, M=3, B=3):
    """
    models: list of AutoModelForCausalLM (token-probability access required,
            which is why only open-source models qualify)
    tau: confidence threshold on the top-1 minus top-2 probability gap
    M: max number of words per round
    B: branching factor -- number of candidates explored under uncertainty
    """
    prefix_ids = tokenizers[0](prompt, return_tensors='pt').input_ids
    while not is_terminated(prefix_ids):
        candidate_pool = []
        for model, tokenizer in zip(models, tokenizers):
            span_ids = []  # token ids of the words generated so far this round
            word_count = 0
            while word_count < M:
                # Next-token distribution given prefix + current span
                ctx = prefix_ids if not span_ids else torch.cat(
                    [prefix_ids, torch.tensor([span_ids], dtype=torch.long)], dim=1
                )
                with torch.no_grad():
                    logits = model(ctx).logits[0, -1]
                probs = torch.softmax(logits, dim=-1)
                top2 = torch.topk(probs, 2)
                confidence_margin = (top2.values[0] - top2.values[1]).item()
                if confidence_margin >= tau:
                    # Confident: complete the word greedily and keep going
                    word_ids = gen_word_greedy(model, prefix_ids, span_ids)
                    span_ids += word_ids
                    word_count += 1
                else:
                    # Uncertain: branch on the top-B initial tokens, each
                    # greedily completed into a full word candidate
                    top_b_tokens = torch.topk(probs, B).indices
                    for start_token in top_b_tokens:
                        word_ids = gen_word_greedy(
                            model, prefix_ids, span_ids, force_first=start_token
                        )
                        candidate_pool.append(span_ids + word_ids)
                    break
            else:
                # Completed all M words while confident: whole span is a candidate
                candidate_pool.append(span_ids)
        # Ensemble step: pick the span with the lowest NLL averaged over models
        best_span = min(
            candidate_pool,
            key=lambda span: avg_nll_across_models(span, models, prefix_ids),
        )
        prefix_ids = torch.cat(
            [prefix_ids, torch.tensor([best_span], dtype=torch.long)], dim=1
        )
    return tokenizers[0].decode(prefix_ids[0])

def avg_nll_across_models(span_ids, models, prefix_ids):
    """Equations (4), (5): NLL of the span averaged across all ensemble models"""
    nlls = []
    for model in models:
        log_probs = get_log_probs(model, prefix_ids, span_ids)  # per-token log-probs
        nlls.append(-log_probs.mean().item())
    return sum(nlls) / len(nlls)
Original Abstract
Large language models (LLMs) exhibit complementary strengths arising from differences in pretraining data, model architectures, and decoding behaviors. Inference-time ensembling provides a practical way to combine these capabilities without retraining. However, existing ensemble approaches suffer from fundamental limitations. Most rely on fixed fusion granularity, which lacks the flexibility required for mid-generation adaptation and fails to adapt to different generation characteristics across tasks. To address these challenges, we propose AdaFuse, an adaptive ensemble decoding framework that dynamically selects semantically appropriate fusion units during generation. Rather than committing to a fixed granularity, AdaFuse adjusts fusion behavior on the fly based on the decoding context, with words serving as basic building blocks for alignment. To be specific, we introduce an uncertainty-based criterion to decide whether to apply ensembling at each decoding step. Under confident decoding states, the model continues generation directly. In less certain states, AdaFuse invokes a diversity-aware scaling strategy to explore alternative candidate continuations and inform ensemble decisions. This design establishes a synergistic interaction between adaptive ensembling and test-time scaling, where ensemble decisions guide targeted exploration, and the resulting diversity in turn strengthens ensemble quality. Experiments on open-domain question answering, arithmetic reasoning, and machine translation demonstrate that AdaFuse consistently outperforms strong ensemble baselines, achieving an average relative improvement of 6.88%. The code is available at https://github.com/CCM0111/AdaFuse.