Speculative Sampling으로 Large Language Model 디코딩 가속화
Accelerating Large Language Model Decoding with Speculative Sampling
TL;DR Highlight
작은 draft 모델이 토큰을 미리 예측하고 큰 모델이 한 번에 검증해서 LLM 추론 속도를 최대 25배 높이는 기법
Who Should Read
LLM 서빙 비용과 응답 레이턴시를 줄여야 하는 ML 엔지니어나 인프라 개발자. 특히 70B 이상 대형 모델을 프로덕션에서 운영 중인 팀.
Core Mechanics
- 핵심 아이디어: 작은 draft 모델이 K개 토큰을 미리 생성하고, 큰 target 모델이 이를 병렬로 한 번에 채점(scoring)해서 여러 토큰을 한 번의 forward pass로 확정
- Modified Rejection Sampling 기법으로 draft 토큰을 수락/거절하는데, 수학적으로 target 모델 분포를 그대로 보존함 — 출력 품질 손실 없음
- Chinchilla 70B + 4B draft 모델 조합에서 HumanEval(코드 생성)은 최대 2.46배, XSum(요약)은 최대 2배 속도 향상 달성
- 코드 생성 태스크에서 특히 효과적 — for i in range(len(arr)): 같은 반복 패턴을 작은 모델이 정확히 예측해서 acceptance rate가 높음
- target 모델 구조 수정 없이 적용 가능하고, 양자화(quantization)나 multi-query attention 같은 다른 최적화 기법과 병행 사용 가능
- K(draft 길이) 값이 너무 크면 오히려 역효과 — XSum에서는 K=3이 최적이고, K가 커질수록 레이턴시 분산도 증가
Evidence
- Chinchilla 70B 기준 HumanEval에서 14.1ms/token → 5.73ms/token, 속도 2.46배 향상, ROUGE/pass@100 지표는 동일 수준 유지
- XSum 요약 태스크에서 nucleus sampling 기준 14.1ms/token → 7.52ms/token (1.92배), greedy 기준 → 7.00ms/token (2.01배)
- HumanEval과 greedy XSum에서 달성한 속도는 하드웨어 메모리 대역폭이 설정한 auto-regressive 샘플링의 이론적 상한선을 초과
- draft 모델(4B)은 1.8ms/token, target 모델(Chinchilla 70B)은 14.1ms/token — 약 8배 속도 차이로 drafting 오버헤드 충분히 상쇄
How to Apply
- vLLM, TGI 같은 서빙 프레임워크에서 speculative decoding 옵션을 활성화할 때 — draft 모델은 target 모델과 같은 tokenizer를 쓰는 작은 버전(보통 1/10~1/20 크기)으로 설정하면 된다
- 코드 자동완성, SQL 생성처럼 반복 패턴이 많은 태스크에 우선 적용하면 효과가 크다 — acceptance rate가 자연어 요약보다 훨씬 높으므로 K=4~5로 설정해볼 것
- K 값은 도메인마다 튜닝이 필요하다 — 자연어 태스크는 K=3~4, 코드 태스크는 K=4~6 범위에서 실제 레이턴시를 측정해서 결정해야 한다
Code Example
# HuggingFace Transformers에서 Speculative Decoding 사용 예시
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
# Target 모델 (큰 모델)
target_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-70b-hf")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-70b-hf")
# Draft 모델 (작은 모델, 같은 tokenizer 사용)
draft_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
inputs = tokenizer("def fibonacci(n):", return_tensors="pt")
# speculative decoding 활성화: assistant_model 파라미터에 draft 모델 전달
outputs = target_model.generate(
**inputs,
assistant_model=draft_model, # speculative decoding 핵심 파라미터
max_new_tokens=200,
do_sample=True,
temperature=0.8,
top_p=0.95,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
# vLLM에서 사용하는 경우
# vllm serve meta-llama/Llama-2-70b-hf \
# --speculative-model meta-llama/Llama-2-7b-hf \
# --num-speculative-tokens 4Terminology
Related Resources
Original Abstract (Expand)
We present speculative sampling, an algorithm for accelerating transformer decoding by enabling the generation of multiple tokens from each transformer call. Our algorithm relies on the observation that the latency of parallel scoring of short continuations, generated by a faster but less powerful draft model, is comparable to that of sampling a single token from the larger target model. This is combined with a novel modified rejection sampling scheme which preserves the distribution of the target model within hardware numerics. We benchmark speculative sampling with Chinchilla, a 70 billion parameter language model, achieving a 2-2.5x decoding speedup in a distributed setup, without compromising the sample quality or making modifications to the model itself.