Probing for Knowledge Attribution in Large Language Models
TL;DR Highlight
A simple linear classifier over an LLM's internal hidden states can distinguish whether the model answered from the provided context or from parametric memory, reaching 0.96 Macro-F1.
Who Should Read
Backend/ML engineers debugging why a RAG pipeline's model ignores retrieved context and hallucinates, and developers building LLM response-reliability or hallucination-detection systems.
Core Mechanics
- When an LLM generates an answer, whether it "read from the prompt/context" vs "pulled from internal parameter memory" is linearly encoded in middle-to-upper layer hidden states
- Trainable layer-wise aggregation followed by plain logistic regression reaches up to 0.96 Macro-F1 on Llama-3.1-8B, Mistral-7B, and Qwen2.5-7B
- The probe transfers to out-of-domain datasets (SQuAD, WebQuestions) without retraining, at 0.94-0.99 accuracy
- Attribution mismatch (wrong source usage) increases error rate by up to 70% — especially severe when misleading context overrides parametric memory
- More complex MLP classifiers latch onto lexical shortcuts such as entity repetition, which actually reduces robustness; simple linear probes are more stable
- The paper releases AttriWiki, an automated data pipeline that uses Wikipedia + GPT-4o-mini to generate "context-only" vs "parameter-only" labelled examples automatically
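The layer-weighted aggregation idea can be sketched on synthetic data. This is a toy illustration, not the paper's implementation: instead of jointly training the softmax layer weights, it scores each layer with its own probe and uses those scores as weights, and the array sizes and signal-injection scheme are invented for the demo.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

rng = np.random.default_rng(0)
L, H, N = 8, 16, 200  # layers, hidden size, examples (toy sizes)

# Synthetic per-layer hidden states: the label signal lives in upper layers only,
# mimicking the finding that attribution is encoded in middle-to-upper layers.
y = rng.integers(0, 2, N)
X = rng.normal(size=(N, L, H))
X[:, L // 2:, 0] += 3.0 * (2 * y - 1)[:, None]  # inject signal into upper layers

def blend(X, logits):
    """Softmax over layer logits -> weighted average across layers."""
    alpha = np.exp(logits) / np.exp(logits).sum()
    return (alpha[None, :, None] * X).sum(axis=1)  # (N, H)

# Proxy for learning alpha: score each layer with its own linear probe,
# then sharpen the scores so informative layers dominate the softmax.
layer_scores = np.array([
    LogisticRegression(max_iter=1000).fit(X[:, l, :], y).score(X[:, l, :], y)
    for l in range(L)
])
alpha_logits = 10.0 * layer_scores

Xb = blend(X, alpha_logits)
probe = LogisticRegression(max_iter=1000).fit(Xb, y)
print(f"blended-probe training accuracy: {probe.score(Xb, y):.2f}")
```

On this toy data the blended probe easily separates the classes because the softmax weights concentrate on the layers where the signal was injected; the paper's version instead learns the weights end-to-end.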
Evidence
- Layer-weighted logistic regression achieves Macro-F1 0.961 on Mistral-7B (LTE token) vs 0.904 for Final-Layer alone (+5.7pp)
- On SQuAD (contextual) and WebQuestions (parametric) external benchmarks, Qwen2.5-7B achieves 0.997 and 0.999 accuracy respectively — without retraining
- Attribution mismatch in the misleading-context condition raises error rates by up to 70%; in the parametric-memory-priority condition, by 30%
- Text-based BoW/embedding classifiers only reach F1 0.65-0.68, suggesting attribution is hard to judge from surface text without access to hidden states
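For contrast, the weaker text-only baseline can be sketched as a bag-of-words classifier. The four sentences and labels below are invented for illustration; the actual baseline was trained on AttriWiki examples.

```python
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline

# Toy answers phrased to hint at their knowledge source (illustrative only).
texts = [
    "Based on the document, the capital is Lyon.",  # contextual
    "From what I know, the capital is Paris.",      # parametric
    "The passage states the founder was Ada.",      # contextual
    "I recall the founder was Grace.",              # parametric
]
labels = [1, 0, 1, 0]  # 1 = contextual, 0 = parametric

# Bag-of-words features + logistic regression: sees only surface text,
# so it can only pick up lexical cues, not the model's internal state.
bow_clf = make_pipeline(CountVectorizer(), LogisticRegression(max_iter=1000))
bow_clf.fit(texts, labels)
print(bow_clf.predict(["The document says the answer is 42."]))
```

A classifier like this can only exploit phrasing cues ("based on the document" vs "I recall"), which is exactly the kind of lexical shortcut the paper found to be unreliable.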
How to Apply
- During RAG response generation, extract the hidden state at the first generated token (FTG) and run it through the attribution probe to detect in real time whether the model actually used the retrieved context; if context-ignoring is detected, trigger re-retrieval or surface a warning.
- Reproduce the AttriWiki pipeline with your own domain data (legal, medical, etc.) on Llama/Mistral/Qwen family models to build a domain-specific attribution classifier that catches hallucination patterns early.
- When surfacing attribution results in a chatbot UI, displaying "This answer is based on provided documents" vs "Based on model's internal knowledge" empowers users to judge answer reliability themselves.
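The RAG-guard idea above can be wired up as a small hook. Everything here is a hypothetical sketch: the function names, the `GuardResult` type, and the stubbed `retrieve`/`generate`/`attribute` callables are made up, with the real `attribute` implemented by a hidden-state probe.

```python
from dataclasses import dataclass
from typing import Callable

@dataclass
class GuardResult:
    answer: str
    attribution: str        # "contextual" or "parametric"
    needs_reretrieval: bool

def guarded_answer(question: str,
                   retrieve: Callable[[str], str],
                   generate: Callable[[str], str],
                   attribute: Callable[[str], str]) -> GuardResult:
    """Generate with retrieved context, then check which knowledge source was used."""
    context = retrieve(question)
    prompt = f"Based on this document: {context}\nQ: {question}\nA:"
    answer = generate(prompt)
    attribution = attribute(prompt)  # probe over the first-token hidden state
    # If the probe says "parametric" despite provided context, flag for re-retrieval.
    return GuardResult(answer, attribution,
                       needs_reretrieval=(attribution == "parametric"))

# Stubbed components for illustration only.
result = guarded_answer(
    "What is X?",
    retrieve=lambda q: "[DOC]",
    generate=lambda p: "X is ...",
    attribute=lambda p: "parametric",  # pretend the probe flagged context-ignoring
)
print(result.needs_reretrieval)  # True with this stubbed probe
```

In a real pipeline the `needs_reretrieval` flag would feed the re-retrieval or user-warning path described above.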
Code Example
# Example of hidden-state-based attribution probe inference (Llama-3.1-8B)
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import numpy as np

model_name = "meta-llama/Llama-3.1-8B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, output_hidden_states=True)

def get_first_token_hidden_states(prompt: str):
    inputs = tokenizer(prompt, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    # Tuple of num_layers + 1 tensors, each of shape (batch, seq_len, hidden_size)
    hidden_states = outputs.hidden_states
    # Each layer's representation at the last input token (just before the first generated token);
    # hidden_states[1:] skips the embedding layer
    layer_reps = torch.stack([h[0, -1, :] for h in hidden_states[1:]])  # (L, H)
    return layer_reps.cpu().numpy()

# Probe inference (assumes alpha and probe were trained on AttriWiki data and loaded)
# alpha: softmax weights per layer; probe: trained sklearn LogisticRegression
def predict_attribution(prompt: str, alpha: np.ndarray, probe) -> str:
    layer_reps = get_first_token_hidden_states(prompt)  # (L, H)
    blended = (alpha[:, None] * layer_reps).sum(axis=0)  # layer-weighted average, (H,)
    pred = probe.predict([blended])[0]
    return "contextual" if pred == 1 else "parametric"

# Usage example
context_prompt = "Based on this document: [DOC]. Q: What is X? A:"
result = predict_attribution(context_prompt, alpha, probe)
if result == "parametric":
    print("⚠️ Model is ignoring the provided context and answering from internal knowledge")
Original Abstract
Large language models (LLMs) often generate fluent but unfounded claims, or hallucinations, which fall into two types: (i) faithfulness violations - misusing user context - and (ii) factuality violations - errors from internal knowledge. Proper mitigation depends on knowing whether a model's answer is based on the prompt or its internal weights. This work focuses on the problem of contributive attribution: identifying the dominant knowledge source behind each output. We show that a probe, a simple linear classifier trained on model hidden representations, can reliably predict contributive attribution. For its training, we introduce AttriWiki, a self-supervised data pipeline that prompts models to recall withheld entities from memory or read them from context, generating labelled examples automatically. Probes trained on AttriWiki data reveal a strong attribution signal, achieving up to 0.96 Macro-F1 on Llama-3.1-8B, Mistral-7B, and Qwen-7B, transferring to out-of-domain benchmarks (SQuAD, WebQuestions) with 0.94-0.99 Macro-F1 without retraining. Attribution mismatches raise error rates by up to 70%, demonstrating a direct link between knowledge source confusion and unfaithful answers. Yet, models may still respond incorrectly even when attribution is correct, highlighting the need for broader detection frameworks.