Diffusion LLM을 위한 Attention 기반 Dependency-Aware Parallel Decoding
Dependency-Aware Parallel Decoding via Attention for Diffusion LLMs
TL;DR Highlight
Diffusion 언어 모델의 병렬 토큰 생성 시 self-attention으로 토큰 간 의존성을 파악해 서로 독립적인 토큰만 묶어 동시에 생성하는 훈련 불필요 기법
Who Should Read
LLaDA, Dream 같은 Diffusion LLM을 추론 서버에 올려야 하는 ML 엔지니어, 또는 AR(자동회귀) 모델 대비 dLLM의 병렬 디코딩 이점을 실제로 활용하고 싶은 연구자
Core Mechanics
- dLLM(Diffusion 기반 LLM)은 토큰을 동시에 여러 개 생성할 수 있지만, 서로 의존적인 토큰을 같이 뽑으면 'France의 수도는 London' 같은 전역 불일치가 생기는 문제가 있음
- DAPD는 모델의 self-attention 가중치를 이용해 마스킹된 토큰들 사이의 의존성 그래프(MRF)를 만들고, 엣지가 없는(독립적인) 토큰 집합만 골라 병렬로 생성함
- 토큰 선택은 Welsh-Powell 그래프 컬러링 휴리스틱으로 구현 — 연결이 많은 '허브' 토큰을 먼저 처리해 이후 단계 그래프를 점점 희소하게 만드는 전략
- 마스크 비율 50% 이하가 되면 대부분 토큰이 독립적이므로 confidence > 0.9인 토큰을 일괄 생성하는 빠른 완료 모드로 자동 전환
- 기존 방법들은 confidence가 높은 인접 토큰을 왼쪽→오른쪽으로 순차 생성하는 AR스러운 패턴을 보이는 반면, DAPD는 전체 시퀀스에 걸쳐 분산된 위치를 동시에 생성함
- 추가 학습이나 보조 모델 없이 추론 시점 attention만 사용하므로 기존 LLaDA-8B-Instruct, Dream-7B-Instruct에 바로 적용 가능
Evidence
- TriviaQA 5-question 묶음 실험에서 DAPD는 52.08% 정확도를 유지하면서 평균 66.2 스텝 — 기존 step-by-step(256 스텝) 대비 3.87×, Fast-dLLM(124.4 스텝) 대비 약 1.88× 빠름
- 합성 MRF 데이터셋 실험에서 attention 기반 엣지 스코어의 AUC 0.928, 엣지/비엣지 점수 비율 2.204 — attention이 실제 의존성 구조를 신뢰성 있게 포착함을 검증
- 노드 차수 순위 추정의 Order Violation Rate(OVR) 평균 0.04 — attention 합산값이 실제 MRF 차수 순서와 거의 일치
- LLaDA + MBPP/IFEval 벤치마크에서 DAPD는 block-wise 디코딩 없이 single-block으로도 4-block 기준선과 동등하거나 더 높은 정확도 달성
How to Apply
- LLaDA 또는 Dream 모델 추론 코드에서 각 디코딩 스텝마다 상위 레이어(전체의 상위 25%)의 attention map을 추출하고, 마스킹된 토큰 쌍의 대칭 점수 sij = (aij + aji)/2를 계산해 τ 임계값으로 의존성 그래프를 구성하면 됨
- 독립 집합 선택은 ˜di(attention 합산) × confi(최대 예측 확률) 기준으로 내림차순 정렬 후 Welsh-Powell 그리디로 구현 — 마스크 비율 50% 이하가 되면 confidence > 0.9 토큰을 일괄 언마스킹으로 전환
- 멀티-턴 또는 배치 QA처럼 독립적인 여러 질문을 하나의 프롬프트로 묶는 경우 효과가 가장 크게 나타남 — 서로 독립적인 질문들을 전체 시퀀스에 걸쳐 동시에 생성하므로 NFE(forward pass 횟수)가 대폭 줄어듦
Code Example
import torch
def compute_edge_scores(attention_maps, masked_indices):
"""
attention_maps: [num_layers, num_heads, seq_len, seq_len]
masked_indices: list of masked token positions
Returns symmetric edge score matrix for masked tokens
"""
# 상위 25% 레이어만 사용 (예: 32레이어 모델이면 마지막 8개)
top_layers = attention_maps[-len(attention_maps)//4:]
# 모든 헤드/레이어 평균
avg_attn = top_layers.mean(dim=(0, 1)) # [seq_len, seq_len]
# 대칭 edge score: sij = (aij + aji) / 2
n = len(masked_indices)
edge_scores = torch.zeros(n, n)
for ii, i in enumerate(masked_indices):
for jj, j in enumerate(masked_indices):
if ii != jj:
edge_scores[ii, jj] = (avg_attn[i, j] + avg_attn[j, i]) / 2
return edge_scores
def welsh_powell_independent_set(edge_scores, masked_indices, tau, confidences):
"""
Welsh-Powell 기반 독립 집합 선택
"""
n = len(masked_indices)
# confidence-weighted degree 계산
degrees = (edge_scores > tau).float().sum(dim=1) # [n]
weighted_degrees = degrees * confidences
# 내림차순 정렬
order = torch.argsort(weighted_degrees, descending=True)
selected = []
selected_set = set()
for idx in order.tolist():
# 이미 선택된 노드와 엣지가 없으면 추가
conflict = any(
edge_scores[idx, s] > tau for s in selected_set
)
if not conflict:
selected.append(masked_indices[idx])
selected_set.add(idx)
return selected
# 사용 예시
# mask_ratio = unmasked_count / total_length
# if mask_ratio > 0.5:
# tokens_to_unmask = welsh_powell_independent_set(...)
# else: # 후반부: confidence 기반 빠른 완료
# tokens_to_unmask = [i for i in masked_indices if confidences[i] > 0.9]Terminology
Related Resources
Original Abstract (Expand)
Parallel decoding for diffusion LLMs (dLLMs) is difficult because each denoising step provides only token-wise marginal distributions, while unmasking multiple tokens simultaneously requires accounting for inter-token dependencies. We propose Dependency-Aware Parallel Decoding (DAPD), a simple, training-free decoding method that uses self-attention to induce a conditional dependency graph over masked tokens. At each iteration, edges in this graph capture strong token interactions, while non-edges indicate weak dependence. Parallel decoding is then reduced to selecting an independent set on the graph and unmasking the selected tokens in parallel. This avoids co-updating strongly coupled tokens without auxiliary models or retraining. Experiments on LLaDA and Dream show that DAPD improves the accuracy-steps trade-off over existing methods and enables more globally distributed parallel updates that better exploit the any-order generation capability of dLLMs.