Google Titans: 테스트 타임에 기억하는 법을 배우는 Neural Long-Term Memory 아키텍처
Google Titans architecture, helping AI have long-term memory
TL;DR Highlight
Transformer의 quadratic 복잡도 없이 200만 토큰 컨텍스트를 처리하는 Google의 새 아키텍처 — attention 대신 '놀라운 정보만 기억하는' 신경망 메모리 모듈을 도입.
Who Should Read
긴 문서 처리나 대화 히스토리 관리에서 컨텍스트 한계를 느끼는 LLM 애플리케이션 개발자. Transformer 대체 아키텍처(Mamba, RWKV 등)를 검토 중인 ML 엔지니어.
Core Mechanics
- Attention(단기 기억) + Neural Long-Term Memory(장기 기억) + Persistent Memory(영구 지식) 세 가지 메모리를 조합한 새 아키텍처 Titans 제안
- 장기 메모리 모듈은 별도의 작은 신경망으로 구현 — 추론 중에도 파라미터가 업데이트되는 'test-time training' 방식
- 'Surprise(놀라움)' 지표로 뭘 기억할지 결정 — 예측 오차(gradient norm)가 클수록 중요한 정보로 간주해 더 강하게 기억
- 세 가지 아키텍처 변형 제공: 메모리를 컨텍스트로 주입하는 MAC, 게이팅으로 혼합하는 MAG, 레이어로 삽입하는 MAL
- 시퀀스 길이가 늘어도 선형 복잡도 유지 — 2M 토큰 처리 가능, Transformer의 quadratic 복잡도 문제 해결
- Recency Bias 없이 오래된 정보도 유지 — 기존 선형 RNN(Mamba 등)이 최근 토큰에 편향되는 문제를 Forgetting Gate로 보완
Evidence
- BABILong 벤치마크(초장문 추론)에서 Transformer, Mamba-2, TTT 등 기존 모델 대비 일관적으로 성능 우위
- 2M 토큰 컨텍스트에서도 선형 시간복잡도 유지 — Transformer 대비 메모리 사용량 대폭 절감
- MQAR(Multi-Query Associative Recall) 태스크에서 시퀀스 길이 4K~16K 구간 모두에서 Mamba-2, GLA보다 높은 정확도
- 언어 모델링(Wikitext-103, SlimPajama) 및 시계열, DNA 시퀀스 등 다양한 도메인에서 경쟁력 있는 perplexity 달성
How to Apply
- 긴 대화 히스토리를 처리하는 챗봇을 만들 때: Titans 기반 모델이 공개되면 KV Cache 크기 제한 없이 전체 대화를 메모리에 유지할 수 있어 RAG 없이도 긴 세션 처리 가능
- 문서 분석 파이프라인에서 청킹 전략을 쓰고 있다면: Titans 계열 모델은 2M 토큰을 한 번에 처리하므로 청크 분할 로직을 단순화하거나 제거하는 방향으로 아키텍처 재검토 가능
- 자체 모델을 학습하거나 파인튜닝하는 경우: MAC/MAG/MAL 변형 중 사용 케이스에 맞게 선택 — 검색/QA는 MAC(명시적 컨텍스트 주입), 생성 품질 중심은 MAG(게이트 혼합) 권장
Code Example
snippet
# Titans 아키텍처의 Surprise 기반 메모리 업데이트 핵심 로직 (개념 코드)
import torch
import torch.nn as nn
class NeuralMemory(nn.Module):
def __init__(self, dim, memory_lr=0.01):
super().__init__()
# 장기 메모리 = 작은 MLP (key-value 연상 기억)
self.memory_mlp = nn.Sequential(
nn.Linear(dim, dim * 2),
nn.SiLU(),
nn.Linear(dim * 2, dim)
)
self.memory_lr = memory_lr # 메모리 업데이트 속도
def compute_surprise(self, query, target):
"""Surprise = 예측 오차의 크기 (gradient norm)"""
pred = self.memory_mlp(query)
loss = nn.functional.mse_loss(pred, target)
grad = torch.autograd.grad(loss, self.memory_mlp.parameters())
surprise = sum(g.norm() for g in grad) # 놀라울수록 강하게 기억
return surprise, loss
def update_memory(self, key, value, forget_gate):
"""놀라운 정보를 메모리에 기록 (test-time update)"""
surprise, loss = self.compute_surprise(key, value)
# Forgetting Gate: 오래된 기억 decay + 새 정보 기록
effective_lr = self.memory_lr * surprise.item() * forget_gate
for param in self.memory_mlp.parameters():
if param.grad is not None:
param.data -= effective_lr * param.grad
def recall(self, query):
"""쿼리로 장기 메모리에서 정보 검색"""
return self.memory_mlp(query)
# MAC (Memory as Context) 변형 사용 예시
# memory_output = neural_memory.recall(current_query)
# context = torch.cat([memory_output, short_term_kv_cache], dim=1)
# output = attention(query, context) # 장기 + 단기 메모리 통합Terminology
Quadratic ComplexityTransformer가 토큰 수가 2배 늘면 연산량이 4배 늘어나는 문제. 1000토큰이면 괜찮지만 100만 토큰이면 사실상 불가능해짐.
Test-Time Training보통 AI는 학습 후 추론 때는 파라미터가 고정되는데, 이건 추론 중에도 메모리 모듈 파라미터를 업데이트함. 시험 보면서 공부하는 것과 비슷.
KV CacheTransformer가 이전 토큰들의 Key-Value를 저장해두는 공간. 길이가 길수록 메모리가 폭발적으로 늘어나는 단기 기억 저장소.
Forgetting Gate새 정보를 기억할 때 오래된 정보를 얼마나 지울지 결정하는 스위치. 뇌가 오래된 기억을 자연스럽게 흐리게 만드는 것과 유사.
Persistent Memory학습 데이터에서 굳어진 영구 지식 (가중치에 박힌 상식/언어 규칙 등). 사람으로 치면 태어날 때부터 가진 본능적 지식.
Linear RecurrenceMamba, RWKV처럼 이전 상태를 선형으로 압축해 넘기는 방식. 연산은 빠르지만 먼 과거 정보가 희미해지는 단점이 있음.
SurpriseTitans에서 '이 정보가 얼마나 예상 밖인가'를 수치화한 것. 예측 오차가 클수록 = 놀라울수록 더 강하게 장기 기억에 저장됨.