NanoGPT Slowrun: 무한 컴퓨팅으로 10배 데이터 효율 달성
NanoGPT Slowrun: 10x Data Efficiency with Infinite Compute
TL;DR Highlight
1.8B 파라미터 모델 앙상블을 100M 토큰만으로 학습시켜 기존 1B 토큰 학습 성능을 따라잡는 10배 데이터 효율을 몇 주 만에 달성했다는 실험 결과 공유. 컴퓨팅은 넘치는데 데이터가 병목이 되는 미래를 대비하는 접근법이다.
Who Should Read
LLM 사전학습(pretraining) 실험을 직접 돌려보는 ML 연구자나 엔지니어, 또는 제한된 데이터로 더 좋은 모델을 만들어야 하는 상황에 처한 AI 개발자.
Core Mechanics
- 핵심 주장은 '데이터 10배 효율' — 1.8B 파라미터 모델 18개를 앙상블(ensemble)해 총 18B 파라미터 규모로 만들었을 때, 100M 토큰으로 학습해도 일반적인 LM 베이스라인이 1B 토큰으로 달성하는 성능과 동등한 결과를 냈다.
- Chinchilla 스케일링 법칙에 따르면 100M 토큰에는 약 5M 파라미터 모델이 적합한데, 여기서는 1.8B 파라미터를 사용해 무려 360배 오버파라미터화(overparameterized)된 상태에서 실험을 진행했다. 기존 통념과 정반대 방향이다.
- 앙상블(여러 모델의 예측을 평균 내는 기법)이 핵심인데, 단일 모델은 12 에폭 이후 계속 학습하면 loss가 3.295 → 3.310으로 오히려 나빠졌지만, 앙상블 loss는 3.185 → 3.166으로 계속 개선됐다. 모델들이 개별 최적점을 넘어서도 서로 다른 것을 배우면서 앙상블 전체의 성능이 올라가는 구조다.
- Chain Distillation(순차 지식 증류)을 적용해 8개 모델을 직렬로 학습시켰다 — M1을 일반 학습 후, M2는 M1을 고정(frozen) 교사로 삼아 CE loss와 KL divergence를 α=0.5, T=1.0으로 혼합한 loss로 학습, 이런 식으로 이어간다. 이전 모델 하나만 교사로 쓰기 때문에 메모리는 일정하게 유지된다. 이 방식으로 데이터 효율이 7배에서 8배로 개선됐다.
- 정규화(regularization) 수준을 매우 공격적으로 높였다 — 표준 관행의 weight decay가 약 0.1인데 여기서는 1.6을 사용해 16배 높다. 오버파라미터화된 상태에서는 이 정도 정규화가 효과적이며, 모델이 클수록 더 강한 정규화가 필요하다는 것을 Kim et al. 연구와 일치하는 결과로 확인했다.
- Looped Transformer(반복 레이어 구조)를 도입했다 — 30레이어 트랜스포머를 학습 절반 지점부터 레이어 15~24를 4번 반복 실행하도록 바꿨다. 단순히 레이어를 한 번만 통과하는 게 아니라 중간 레이어를 여러 번 돌면서 표현을 정제하는 구조다. 마지막 몇 개 레이어는 반복하지 않는 것이 최적이라는 점도 발견했다.
- 이 연구의 큰 그림은 '컴퓨팅은 계속 늘어나지만 데이터는 한계가 있다'는 전제 하에, 데이터를 더 쓰는 게 아니라 컴퓨팅을 더 쓰는 방향으로 성능을 높이는 경로를 찾는 것이다. 앙상블 스케일링은 고정된 데이터로 컴퓨팅만 늘려도 계속 성능 향상이 가능하다는 점에서 의미가 있다.
Evidence
- 데이터 병목 전제에 대한 회의적 시각이 있었다 — '합성 데이터(synthetic data) 생성 능력이 좋아졌으니 컴퓨팅이 많으면 데이터도 더 만들 수 있다'는 반론이 제기됐다. 2023년 이후 모든 주요 AI 랩이 이미 이 방향을 쓰고 있어, 데이터가 병목이라는 전제 자체가 오래된 2022년 논문 기반이라는 지적이었다.
- Chinchilla 비교 기준에 대한 비판도 있었다 — 현업에서는 이미 소형 모델을 Chinchilla 최적 대비 10~400배 더 많은 데이터(1~40T 토큰)로 학습시키는 게 일반적인데, 이 연구는 정반대 방향(더 큰 모델, 더 적은 데이터)을 택했다. 산업계 트렌드와 역방향이라는 점에서 도입부 주장을 곧이곧대로 받아들이지 말라는 의견이었다.
- 자가 부트스트랩 가능성에 대한 기대도 나왔다 — 'LLM이 더 나은 LLM을 학습시키는 루프를 만들 수 있는 단계에 가까워지고 있다'는 흥미로운 시각도 등장했다. 이 연구의 체인 디스틸레이션 구조가 그런 방향의 초기 형태로 읽힐 수 있다는 맥락이었다.
- 초기 NanoGPT Slowrun 발표 관련 HN 토론(185포인트, 39댓글)이 이전에 있었으며, 이번 결과는 그 후속 연구임이 댓글에서 공유됐다.
How to Apply
- 고정된 데이터셋으로 모델 성능을 더 올려야 하는 상황이라면, 동일 데이터로 여러 모델을 독립적으로 학습한 뒤 inference 시 logit을 평균 내는 앙상블을 시도해볼 수 있다. 단일 모델을 더 오래 학습하는 것보다 앙상블 구성 후 각각 더 많은 에폭을 돌리는 게 효과적일 수 있다.
- 앙상블 학습 시 Chain Distillation을 적용하려면 위 알고리즘대로 M1을 먼저 학습 후 freeze, M2를 CE + KL(α=0.5, T=1.0) 혼합 loss로 학습하는 방식을 NanoGPT 같은 소규모 실험 환경에서 먼저 검증해볼 수 있다. 오픈소스 PR #31 구현이 공개되어 있으므로 코드 참고가 가능하다.
- 오버파라미터화된 환경에서 학습할 때는 weight decay를 표준값(0.1)보다 훨씬 높게 잡는 실험을 해볼 만하다. 이 연구에서는 1.6(16배)까지 올렸으며, 데이터 대비 모델이 클수록 더 강한 정규화가 도움이 된다는 점을 확인했으니 유사 세팅에서 hyperparameter search 범위를 넓혀볼 수 있다.
Code Example
snippet
# Chain Distillation Ensemble 학습 루프 (의사코드)
def train_chain_distillation_ensemble(data, num_models=8, alpha=0.5, T=1.0):
models = []
# 첫 번째 모델: 일반 cross-entropy loss로 학습
M1 = train_model(data, loss_fn='cross_entropy')
models.append(M1)
# 이후 모델들: 직전 모델을 teacher로 사용
for k in range(2, num_models + 1):
teacher = models[-1] # 직전 모델만 teacher로 사용 (메모리 효율)
freeze(teacher)
def distill_loss(student_logits, teacher_logits, labels):
ce_loss = cross_entropy(student_logits, labels)
kl_loss = T**2 * kl_divergence(
student_logits / T,
teacher_logits / T
)
return (1 - alpha) * ce_loss + alpha * kl_loss
M_k = train_model(data, loss_fn=distill_loss, teacher=teacher)
models.append(M_k)
del teacher # 메모리에서 teacher 제거
return models
def ensemble_inference(models, input_tokens):
# 모든 모델의 logit을 평균 내어 최종 예측
all_logits = [model(input_tokens) for model in models]
return sum(all_logits) / len(all_logits)
# Looped Transformer 설정 예시 (30레이어 기준)
# 레이어 0-14: 일반 통과
# 레이어 15-24: 4회 반복
# 레이어 25-29: 일반 통과 (마지막 레이어는 반복 안 함)
loop_config = {
'total_layers': 30,
'loop_start': 15,
'loop_end': 24,
'loop_count': 4
}Terminology
Ensemble여러 모델이 각자 예측한 결과를 합쳐 최종 답을 내는 방식. 여러 전문가 의견을 모아 다수결로 결정하는 것과 비슷하다.
Chain Distillation모델 A가 학습한 지식을 모델 B에게 전달하고, B의 지식을 C에게 전달하는 식으로 순차적으로 이어가는 학습 방식. 선배가 후배를 가르치고 후배가 그 다음 후배를 가르치는 도제 시스템과 유사하다.
Looped Transformer일반 트랜스포머가 레이어를 한 번만 통과하는 것과 달리, 중간 레이어를 여러 번 반복 통과하면서 생각을 다듬는 구조. 문제를 한 번에 풀지 않고 여러 번 검토하는 것과 같다.
Weight Decay학습 중 모델의 파라미터 값이 너무 커지지 않도록 패널티를 주는 정규화 기법. 모델이 특정 패턴에 과도하게 의존하지 않도록 적당히 억제하는 역할을 한다.
Chinchilla Scaling Law주어진 컴퓨팅 예산에서 최적 모델 크기와 데이터 양의 비율을 정의한 법칙. '모델 파라미터 수와 학습 토큰 수를 1:20 비율로 맞추는 게 효율적'이라는 게 핵심이다.
KL Divergence두 확률 분포가 얼마나 다른지를 측정하는 지표. Knowledge Distillation에서 학생 모델의 예측 분포가 교사 모델의 분포와 얼마나 차이 나는지 측정하는 데 사용된다.