PPO를 활용한 언어 모델의 Tree Search Distillation
Tree Search Distillation for Language Models Using PPO
TL;DR Highlight
AlphaZero처럼 MCTS로 더 강한 추론 경로를 탐색한 뒤 그 결과를 PPO로 모델에 다시 증류(distill)하는 방법을 실험한 글로, 표준 RL 방법(GRPO)보다 높은 성능을 보여줬다.
Who Should Read
LLM의 추론 능력을 강화하는 방법을 연구하거나, RL 기반 학습(GRPO, PPO 등)을 언어 모델에 적용해보려는 ML 엔지니어 및 연구자.
Core Mechanics
- AlphaZero 같은 게임 AI는 '정책(policy) + 탐색(search) + 증류(distillation)' 사이클로 성능을 높이는데, 언어 모델에서는 이 방법이 잘 쓰이지 않는다. DeepSeek-R1 팀도 MCTS를 시도했지만 효과가 제한적이었다고 밝혔는데, 이는 탐색 알고리즘으로 UCT 대신 pUCT를 썼어야 한다는 분석이 있다.
- 실험 모델은 Qwen-2.5-1.5B-Instruct이고, 태스크는 Countdown(주어진 정수 4개로 사칙연산을 이용해 목표 숫자를 만드는 조합 산술 게임)이다. GSM8K로 먼저 시도했지만 GRPO와 MCTS 간 차이가 미미해서 조합적 탐색이 더 유리한 Countdown으로 전환했다.
- 성능 결과: 증류된 모델(탐색 없이 단독 추론)이 mean@16 기준 11.3%를 달성했고, 비교 대상인 CISPO는 8.4%, best-of-N은 7.7%, 사전 RL 모델은 3.1%였다. 절대 수치가 낮은 이유는 1.5B라는 작은 모델 크기 때문이며, 향후 더 큰 모델로 실험을 이어갈 예정이다.
- 보상 함수 설계가 중요했다. 처음에 정답/오답만 주는 sparse reward(0/1)를 쓰면 학습이 불안정했다. 대신 예측값과 정답의 차이에 비례하는 dense reward($1.0 - 2 \cdot \min(|t-p|/t, 1.0)$)를 학습에 사용하고, 평가는 여전히 sparse reward로 진행했다.
- 토큰 단위가 아닌 추론 스텝(reasoning step) 단위로 MCTS를 적용했다. 토큰 단위로 분기하면 'but', 'however', 'yet' 같은 기능어에서도 가지가 갈라져 탐색 트리가 비효율적으로 커진다. 대신 Tree-of-Thoughts 방식처럼 `<step>...</step>` 태그로 묶인 추론 단계를 하나의 노드로 취급했다.
- 탐색 다양성을 높이기 위해 N개의 에이전트가 같은 샘플의 탐색 트리를 공유하는 parallel MCTS를 구현했다. 각 에이전트는 virtual loss를 써서 서로 다른 경로를 탐색하도록 유도했다. pUCT에서 필요한 행동(action) 사전 확률은 시퀀스 로그확률을 합산한 뒤 softmax를 취해 계산했다.
- 모델에 value head(MLP + tanh)를 추가해 현재 상태의 가치를 예측하게 했다. 이 value head는 학습 중 점차 개선되면서 MCTS가 더 좋은 탐색 경로를 찾도록 안내한다. MCTS로 찾은 더 강한 추론 경로를 PPO 루프를 통해 모델 가중치에 온라인으로 증류하는 구조다.
Evidence
- MCTS를 학습 시 탐색에만 쓰고 PPO로 증류하면, 추론 시에는 탐색 없이 모델만 쓰는 것이므로 추론 비용이 GRPO와 동일한 것 아니냐는 질문이 있었다. 원문에서 'MCTS는 샘플당 추론 컴퓨팅을 더 많이 쓰니까 당연히 성능이 좋다'는 표현이 있는데, 증류된 모델은 추론 시 탐색을 쓰지 않으므로 이 표현이 혼란스럽다는 지적이었다.
- MCTS를 증류 없이 test-time compute harness로만 쓸 때 같은 컴퓨팅 예산 기준으로 best-of-N과 비교해봤는지 묻는 댓글이 있었다. 이는 MCTS의 탐색 효율 자체를 검증하는 중요한 비교인데, 원문에서는 이 비교가 명시적으로 다뤄지지 않았다.
- MCTS가 test-time compute 방법으로 왜 더 많이 쓰이지 않는지 의문을 표하는 댓글이 있었다. 언어 모델에서 MCTS가 어려운 이유(토큰 단위 탐색의 비효율, value function 학습의 어려움 등)에 대한 관심이 있었고, 이 글이 그 가능성을 탐구하는 시도라는 점에서 긍정적인 반응이 있었다.
How to Apply
- 소규모 조합 최적화 또는 수학 추론 태스크를 다루는 경우, 토큰 단위가 아닌 추론 스텝 단위로 MCTS를 적용하면 탐색 트리의 크기를 제어하면서도 다양한 추론 경로를 탐색할 수 있다. `<step>...</step>` 같은 구조화된 태그를 프롬프트에 도입하고, 각 스텝 완성 시점을 MCTS 노드 전환점으로 삼으면 된다.
- RL 기반 파인튜닝 시 sparse reward(정답/오답 0/1)로 학습이 불안정하다면, 예측값과 정답 사이의 거리에 비례하는 dense reward 함수를 학습에 사용하고 평가만 sparse reward로 유지하는 방식을 고려해볼 수 있다. 이 글의 공식은 $1.0 - 2 \cdot \min(|t-p|/t, 1.0)$이다.
- MCTS 탐색 다양성이 부족하다면 virtual loss를 도입한 parallel MCTS를 적용할 수 있다. 여러 에이전트가 같은 탐색 트리를 공유하되 서로 다른 노드를 방문하도록 유도하면 같은 컴퓨팅 예산으로 더 넓은 탐색이 가능하다.
- pUCT 선택 확률 계산 시 raw 누적 시퀀스 확률 대신 로그확률 합산 후 softmax를 적용하면 수치적 언더플로 문제를 피할 수 있다. 특히 긴 시퀀스를 다루는 경우 이 처리가 학습 안정성에 중요하다.
Terminology
MCTSMonte Carlo Tree Search의 약자. 가능한 행동들을 트리 구조로 탐색하면서 시뮬레이션을 반복해 최선의 경로를 찾는 알고리즘. 바둑 AI AlphaGo에서 유명해졌다.
PPOProximal Policy Optimization의 약자. RL(강화학습)에서 기존 정책을 너무 크게 바꾸지 않도록 제한하면서 학습하는 안정적인 알고리즘. LLM 파인튜닝에 자주 쓰인다.
distillation증류. 더 강한 모델(또는 탐색 결과)의 행동을 더 작은/약한 모델이 모방하도록 학습시키는 방법. 교사의 노하우를 학생에게 전수하는 것과 비슷하다.
pUCTPredictor + Upper Confidence bounds for Trees. MCTS에서 각 노드를 선택할 때 사전 확률(prior)을 함께 고려하는 변형 알고리즘. 순수 탐색/활용 균형만 보는 UCT보다 언어 모델처럼 사전 확률 정보가 풍부한 경우에 더 적합하다.
value head트랜스포머 모델에 추가하는 작은 신경망 레이어로, 현재 상태가 얼마나 좋은지(가치)를 숫자로 예측하게 한다. MCTS가 탐색 방향을 결정할 때 이 값을 참고한다.
dense reward학습 중 매 스텝마다 연속적인 보상 신호를 주는 방식. 정답/오답만 알려주는 sparse reward와 달리 '얼마나 가까운지'도 알려줘서 학습이 더 안정적이다.