Rethinking with Retrieval: 외부 지식 검색으로 LLM 추론 정확도 높이기
Rethinking with Retrieval: Faithful Large Language Model Inference
TL;DR Highlight
CoT로 생성한 추론 단계별로 외부 지식을 검색해서, 가장 사실에 충실한 답변을 고르는 post-processing 기법
Who Should Read
LLM이 틀린 사실을 자신있게 말하는 hallucination 문제를 해결하고 싶은 개발자. 파인튜닝 없이 GPT 계열 모델의 추론 정확도를 높이고 싶은 AI 서비스 개발자.
Core Mechanics
- CoT(Chain-of-Thought, 단계별 추론 프롬프트)로 여러 개의 추론 경로를 샘플링한 뒤, 각 추론 단계를 쿼리로 삼아 외부 지식(Wikipedia, Wikidata 등)을 검색함
- 검색된 외부 지식과 각 추론 경로를 NLI(자연어 추론) 모델로 비교해 '사실 충실도 점수'를 매기고, 가장 점수 높은 예측을 최종 답으로 선택
- 추론 경로 전체가 아니라 '분해된 각 추론 단계'로 검색하는 게 핵심 - 원래 질문으로 검색하면 성능이 훨씬 낮음 (commonsense: 73.36% vs 77.73%)
- 파인튜닝이나 추가 학습 없이 GPT-3에 post-processing으로 붙이는 방식이라 어떤 LLM에도 적용 가능
- GPT-3가 틀리는 주요 원인 두 가지 확인: 잘못된 supporting fact(예: Lil Jon 최고 빌보드 곡을 'Get Low'라고 잘못 암기)와 올바른 사실에서 잘못된 추론
- 작은 OPT 모델(1.3B~30B)에도 적용하면 CoT보다 일관되게 높은 정확도와 사실 충실도를 보임
Evidence
- commonsense reasoning(StrategyQA): Self-consistency 73.36% → RR 77.73% (+4.37%p)
- temporal reasoning(TempQuestions): Self-consistency 37.28% → RR 39.05% (+1.77%p)
- tabular reasoning(INFOTABS): Self-consistency 84.00% → RR 84.83% (+0.83%p)
- 설명 사실 충실도(Faithfulness): CoT 38.73% → RR Variant II 54.54% (+15.81%p)
How to Apply
- RAG 파이프라인에서 사용자 질문을 그대로 검색 쿼리로 쓰는 대신, CoT로 먼저 추론 단계를 분해하고 각 단계를 개별 검색 쿼리로 사용하면 더 관련성 높은 문서를 찾을 수 있음
- LLM이 여러 답변 후보를 생성할 때(temperature > 0 샘플링), 각 후보의 추론 과정을 외부 KB와 NLI 모델로 검증해서 가장 사실에 충실한 답변을 자동 선택하는 검증 레이어를 추가할 수 있음
- 금융·법률·의료처럼 hallucination이 치명적인 도메인에서 GPT-4 응답을 그대로 쓰기 불안할 때, BM25로 신뢰할 수 있는 내부 문서를 검색해 NLI 기반 사실 검증 스코어링을 붙이면 됨
Code Example
from sentence_transformers import SentenceTransformer, util
from pyserini.search.lucene import LuceneSearcher
import torch
# 1. CoT로 여러 추론 경로 샘플링 (temperature=0.7)
reasoning_paths = [
"Aristotle died in 2000. The first laptop was invented in 1980. So the answer is yes.",
"Aristotle died in 322BC. The first laptop was invented in 2000. So the answer is no.",
"Aristotle died in 322BC. The first laptop was invented in 1980. So the answer is no."
]
# 2. 각 추론 단계(문장)를 쿼리로 BM25 검색
searcher = LuceneSearcher.from_prebuilt_index('wikipedia-dpr')
def retrieve_for_sentence(sentence, top_k=10):
hits = searcher.search(sentence, k=top_k)
return [hit.raw for hit in hits]
# 3. MPNet으로 가장 유사한 문단 선택
model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')
def get_most_similar_para(sentence, paragraphs):
sent_emb = model.encode(sentence, convert_to_tensor=True)
para_embs = model.encode(paragraphs, convert_to_tensor=True)
scores = util.cos_sim(sent_emb, para_embs)[0]
best_idx = scores.argmax().item()
return paragraphs[best_idx], scores[best_idx].item()
# 4. NLI 모델로 faithfulness 점수 계산
from transformers import pipeline
nli = pipeline('text-classification', model='cross-encoder/nli-deberta-v3-base')
def faithfulness_score(sentence, premise):
result = nli(f"{premise} [SEP] {sentence}")[0]
if result['label'] == 'ENTAILMENT':
return result['score']
elif result['label'] == 'CONTRADICTION':
return -result['score']
return 0
# 5. 가장 faithful한 추론 경로의 예측 선택
path_scores = {}
for path in reasoning_paths:
sentences = path.split('. ')
prediction = sentences[-1] # 'So the answer is ...'
score = 0
for sent in sentences[:-1]:
paras = retrieve_for_sentence(sent)
best_para, sim = get_most_similar_para(sent, paras)
score += faithfulness_score(sent, best_para)
path_scores[path] = (score, prediction)
best_path = max(path_scores, key=lambda x: path_scores[x][0])
print(f"최종 예측: {path_scores[best_path][1]}")Terminology
Related Resources
Original Abstract (Expand)
Despite the success of large language models (LLMs) in various natural language processing (NLP) tasks, the stored knowledge in these models may inevitably be incomplete, out-of-date, or incorrect. This motivates the need to utilize external knowledge to assist LLMs. Unfortunately, current methods for incorporating external knowledge often require additional training or fine-tuning, which can be costly and may not be feasible for LLMs. To address this issue, we propose a novel post-processing approach, rethinking with retrieval (RR), which retrieves relevant external knowledge based on the decomposed reasoning steps obtained from the chain-of-thought (CoT) prompting. This lightweight approach does not require additional training or fine-tuning and is not limited by the input length of LLMs. We evaluate the effectiveness of RR through extensive experiments with GPT-3 on three complex reasoning tasks: commonsense reasoning, temporal reasoning, and tabular reasoning. Our results show that RR can produce more faithful explanations and improve the performance of LLMs.