AL1 few-shot에서 7단계→4단계 축소 시 LLaMA3.1-70B 정확도 99.80%→99.20%, GPT-4o-mini 98.00%→98.80% 오히려 소폭 향상, 반면 랜덤 제거는 94.40%로 하락
NBC 태스크 12단계→9단계 축소 시 Ours(merge) GPT-4o-mini 95.80%→97.80%, 랜덤 제거는 91.60%로 하락 — 6%p 이상 차이
강한 모델(LLaMA3-8B)의 perplexity로 약한 모델(LLaMA2-7B, Qwen1.5-7B) 파인튜닝 데이터를 정제했을 때, 해당 약한 모델 자체 perplexity로 정제한 것보다 오히려 성능이 더 좋음
How to Apply
Few-shot 프롬프트 최적화: 기존 CoT 데모 예시에서 각 스텝을 하나씩 빼보고 calibration 샘플들의 perplexity 변화를 측정 — 가장 적게 변하는 스텝부터 제거 또는 인접 스텝과 병합하면 토큰을 줄이면서 정확도 유지 가능
CoT 파인튜닝 데이터 정제: GSM8K 같은 수학 추론 데이터셋의 각 샘플에서 SPIRIT-FT 알고리즘으로 불필요한 추론 스텝을 제거한 뒤 LoRA SFT/ORPO로 학습하면 생성 토큰 수를 줄이면서 정확도 트레이드오프를 조절 가능
모델 접근 제한 상황 (GPT-4o 등 closed 모델 파인튜닝): LLaMA3.1-70B 같은 오픈소스 모델로 perplexity를 계산해서 스텝 선택 후, 그 결과를 GPT 계열 모델의 few-shot 데모로 활용 — cross-model transferability 덕분에 성능 유지됨
Code Example
snippet
Terminology
Perplexity (PPL)모델이 텍스트를 얼마나 '예상 못했는지'를 나타내는 수치. 낮을수록 모델이 해당 텍스트를 자연스럽게 받아들인다는 뜻. 낯선 단어가 많은 문장일수록 PPL이 높아짐.
Chain-of-Thought (CoT)LLM이 바로 답을 내지 않고 '1단계: ... 2단계: ...' 식으로 풀이 과정을 쭉 쓰면서 추론하게 하는 방법. 수학 문제 풀 때 중간 계산 과정을 적는 것과 같음.
Few-shot CoT프롬프트에 풀이 예시를 2~5개 넣어주고 모델이 그 패턴을 따라 새 문제를 풀게 하는 방식. 시험 전에 예제 문제 몇 개 보여주는 것과 비슷.
SFT (Supervised Fine-Tuning)모범답안 데이터를 보여주고 따라하게 학습시키는 방법. 학교에서 선생님이 풀이 과정을 보여주고 학생이 비슷하게 풀도록 연습시키는 것.
LoRA모델 전체 파라미터를 다 학습하지 않고 작은 어댑터 레이어만 추가해서 학습하는 기법. 전체 옷을 새로 맞추는 대신 패치만 붙이는 것처럼 효율적.
ORPO (Odds Ratio Preference Optimization)좋은 응답과 나쁜 응답 쌍을 보여줘서 좋은 쪽을 선호하도록 학습시키는 방법. 별도의 보상 모델 없이 SFT와 preference 학습을 동시에 함.
Calibration Set알고리즘 튜닝에 쓰이는 소규모 검증용 샘플 집합. 실제 테스트 전에 설정값이 잘 작동하는지 확인하는 용도로, 일종의 파라미터 조정용 데이터셋.
Original Abstract (Expand)
Chain-of-Thought (CoT) reasoning, which breaks down complex tasks into intermediate reasoning steps, has significantly enhanced the performance of large language models (LLMs) on challenging tasks. However, the detailed reasoning process in CoT often incurs long generation times and high computational costs, partly due to the inclusion of unnecessary steps. To address this, we propose a method to identify critical reasoning steps using perplexity as a measure of their importance: a step is deemed critical if its removal causes a significant increase in perplexity. Our method enables models to focus solely on generating these critical steps. This can be achieved through two approaches: refining demonstration examples in few-shot CoT or fine-tuning the model using selected examples that include only critical steps. Comprehensive experiments validate the effectiveness of our method, which achieves a better balance between the reasoning accuracy and efficiency of CoT.
# SPIRIT-FT 핵심 로직 간단 구현 예시
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
def compute_ppl(model, tokenizer, question, reasoning_steps):
"""추론 스텝을 포함한 전체 시퀀스의 perplexity 계산"""
text = question + " " + " ".join(reasoning_steps)
inputs = tokenizer(text, return_tensors="pt")
with torch.no_grad():
outputs = model(**inputs, labels=inputs["input_ids"])
return torch.exp(outputs.loss).item()
def find_least_important_step(model, tokenizer, question, steps, t2=1.2):
"""
perplexity가 가장 적게 오르는 스텝 찾기
t2: 이 비율 이상 perplexity가 오르면 중요한 스텝으로 간주해서 중단
"""
orig_ppl = compute_ppl(model, tokenizer, question, steps)
best_step_idx = None
best_ppl = float('inf')
for i, step in enumerate(steps):
remaining = steps[:i] + steps[i+1:]
ppl_after_removal = compute_ppl(model, tokenizer, question, remaining)
# perplexity가 가장 적게 변하는 스텝 선택
if ppl_after_removal < best_ppl:
best_ppl = ppl_after_removal
best_step_idx = i
# t2 기준: 제거 후 perplexity가 너무 오르면 중단
if best_ppl > t2 * orig_ppl:
return None, orig_ppl
return best_step_idx, best_ppl
def spirit_ft_refine(model, tokenizer, question, steps, t1=0.95, t2=1.2):
"""
SPIRIT-FT: 파인튜닝 데이터의 추론 스텝 정제
t1: 이 비율보다 낮으면 바로 제거 (병합 불필요)
t2: 이 비율보다 높으면 중단 (스텝이 너무 중요함)
"""
refined_steps = steps.copy()
while len(refined_steps) > 1:
orig_ppl = compute_ppl(model, tokenizer, question, refined_steps)
idx, new_ppl = find_least_important_step(
model, tokenizer, question, refined_steps, t2
)
if idx is None: # 더 제거할 스텝 없음
break
if new_ppl < t1 * orig_ppl:
# perplexity가 충분히 낮음 -> 바로 제거
refined_steps.pop(idx)
else:
# perplexity 변화가 크면 -> LLM으로 인접 스텝과 병합
# 실제 구현시 GPT-4o 등으로 병합 프롬프트 사용
print(f"Step {idx} needs merging: '{refined_steps[idx]}'")
# merged = merge_with_llm(refined_steps, idx) # 병합 로직
refined_steps.pop(idx) # 여기선 단순 제거로 대체
return refined_steps
# 사용 예시
# model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
# tokenizer = AutoTokenizer.from_pretrained(model_name)
# model = AutoModelForCausalLM.from_pretrained(model_name)
#
# question = "A store has 100 apples. 20% are sold. How many remain?"
# steps = [
# "20% of 100 apples = 20 apples sold.",
# "Remaining apples = 100 - 20 = 80.",
# "The answer is 80."
# ]
# refined = spirit_ft_refine(model, tokenizer, question, steps)
# print("Refined steps:", refined)