토큰이 아닌 Feature를 맞춰라: Language Model의 Energy-Based Fine-Tuning
Matching Features, Not Tokens: Energy-Based Fine-Tuning of Language Models
TL;DR Highlight
SFT처럼 토큰 단위로 학습하는 대신, 모델 출력의 feature 통계를 ground-truth와 맞추는 새로운 파인튜닝 방법(EBFT)으로 SFT를 능가하고 RLVR과 동등한 성능을 달성했다.
Who Should Read
LLM 파인튜닝 파이프라인에서 SFT의 한계(긴 시퀀스 품질 저하, 분포 불일치)를 겪고 있는 ML 엔지니어. 특히 검증 가능한 reward 없이도 RL 수준의 성능을 내고 싶은 팀.
Core Mechanics
- SFT(Supervised Fine-Tuning)는 teacher forcing 방식으로 학습해서, 모델이 실제로 생성할 때의 분포(rollout distribution)와 달라지는 distribution shift 문제가 있음 — 생성이 길어질수록 점점 더 나빠짐
- EBFT는 frozen feature network(학습 초기 모델 복사본)를 써서 모델 생성물과 ground-truth의 중간 레이어 임베딩을 비교하는 reward를 만들고, REINFORCE로 policy를 업데이트함 — 별도 verifier나 reward 모델 불필요
- Qwen2.5-1.5B 기준으로 EBFT가 SFT보다 validation cross-entropy도 낮게 달성(0.207 vs 0.289) — SFT가 CE를 직접 최적화하는데도 EBFT에 역전당하는 역설적 결과
- RLVR(강화학습 기반 검증 보상)은 downstream 정확도는 올리지만 cross-entropy와 feature-matching loss를 base model보다 오히려 악화시킴 — EBFT는 이 trade-off 없이 모든 지표를 동시에 개선
- 검증 reward가 없는 비정형 코드(GitHub raw code) 학습에서도 EBFT는 SFT 대비 큰 성능 향상 달성 — RLVR 적용 불가 영역에서도 쓸 수 있음
- Qwen2.5-1.5B/3B/7B 모두에서 일관된 성능 향상 확인, 스케일이 커질수록 개선폭도 유지됨
Evidence
- Q&A 코딩에서 HumanEval greedy 정확도: EBFT 0.548 vs SFT 0.483 vs RLVR 0.535 — EBFT가 verifier 없이도 RLVR 이상
- Validation cross-entropy: EBFT 0.207 vs SFT 0.289 vs RLVR 0.774 (base 0.338) — RLVR은 base보다 2배 이상 악화
- 번역 태스크(Llama-3.2-1B, MTNT): EBFT greedy COMET 0.737 vs SFT 0.703 vs RLVR 0.705 — out-of-distribution에서 EBFT가 가장 강함
- 비정형 코딩(unstructured code)에서 pass@1: EBFT 0.524 vs SFT 0.467 — RLVR 적용 불가 설정에서도 12% 개선
How to Apply
- SFT 1 epoch 후 EBFT를 이어서 적용하는 'warm-start' 전략: EBFT는 초기화 품질에 robust하지만, SFT warm-start 후 EBFT 적용 시 추가 이득을 얻을 수 있음. RLVR은 warm-start 없으면 성능이 크게 떨어지므로 EBFT가 더 실용적
- verifier를 만들기 어려운 도메인(raw 코드, 비정형 텍스트, 창작물 등)에서 SFT 대신 EBFT를 써보면 됨 — 정답 레이블 외에 추가 reward 설계 없이 feature-matching만으로 학습 가능
- feature network는 학습 시작 시 generator의 frozen 복사본을 그대로 쓰면 됨(별도 큰 모델 불필요). 레이어 25%, 50%, 75% 위치의 last-token 임베딩을 concat하고 whitening 적용하는 것이 핵심 설정
Code Example
# EBFT 핵심 reward 계산 의사코드 (Algorithm 1 기반)
import torch
import torch.nn.functional as F
def compute_ebft_rewards(rollout_features, gt_feature, n_samples):
"""
rollout_features: [n_samples, feature_dim] - 모델 생성물의 feature
gt_feature: [feature_dim] - ground-truth의 feature
n_samples: 프롬프트당 샘플 수
"""
# Whitening: second-moment matrix 계산
Sigma = rollout_features.T @ rollout_features / n_samples # [d, d]
# pseudo-inverse의 square root
U, S, Vh = torch.linalg.svd(Sigma)
S_inv_sqrt = torch.where(S > 1e-6, 1.0 / S.sqrt(), torch.zeros_like(S))
whitening_mat = U @ torch.diag(S_inv_sqrt) @ Vh # [d, d]
# Whitened features
w_rollouts = rollout_features @ whitening_mat.T # [n, d]
w_gt = gt_feature @ whitening_mat.T # [d]
# Normalized alignment term (Variant i)
w_rollouts_norm = F.normalize(w_rollouts, dim=-1)
w_gt_norm = F.normalize(w_gt.unsqueeze(0), dim=-1)
alignment = 2 * (w_rollouts_norm * w_gt_norm).sum(dim=-1) # [n]
# Diversity term
rewards = []
for j in range(n_samples):
other_idx = [k for k in range(n_samples) if k != j]
diversity = 2 / (n_samples - 1) * sum(
(w_rollouts[j] * w_rollouts[k]).sum()
for k in other_idx
)
r_j = alignment[j] - diversity
rewards.append(r_j)
rewards = torch.stack(rewards) # [n]
# RLOO baseline: 다른 샘플들의 reward 평균으로 baseline 설정
baseline = torch.stack([
rewards[torch.arange(n_samples) != j].mean()
for j in range(n_samples)
])
return rewards - baseline # advantageTerminology
Related Resources
Original Abstract (Expand)
Cross-entropy (CE) training provides dense and scalable supervision for language models, but it optimizes next-token prediction under teacher forcing rather than sequence-level behavior under model rollouts. We introduce a feature-matching objective for language-model fine-tuning that targets sequence-level statistics of the completion distribution, providing dense semantic feedback without requiring a task-specific verifier or preference model. To optimize this objective efficiently, we propose energy-based fine-tuning (EBFT), which uses strided block-parallel sampling to generate multiple rollouts from nested prefixes concurrently, batches feature extraction over these rollouts, and uses the resulting embeddings to perform an on-policy policy-gradient update. We present a theoretical perspective connecting EBFT to KL-regularized feature-matching and energy-based modeling. Empirically, across Q&A coding, unstructured coding, and translation, EBFT matches RLVR and outperforms SFT on downstream accuracy while achieving a lower validation cross-entropy than both methods.