고성능 RL Environment 자동 생성: 코딩 에이전트 + 계층적 검증으로 $10 이하에 구현
Automatic Generation of High-Performance RL Environments
TL;DR Highlight
AI 코딩 에이전트가 RL 학습 환경을 JAX/Rust로 자동 변환해서 최대 22,320배 빠르게 만들어주는 레시피 ($10 이하)
Who Should Read
RL 학습에서 환경 시뮬레이션이 병목인 ML 엔지니어나 연구자. 특히 느린 Python/TypeScript 환경을 GPU 병렬로 바꾸고 싶은데 JAX 포팅에 엔지니어링 리소스가 없는 팀.
Core Mechanics
- 코딩 에이전트(Gemini 3 Flash Preview)가 100K+ 줄짜리 TypeScript Pokemon Showdown을 JAX로 변환 → PPO 학습 속도 22,320배 향상, 비용 $6
- 4단계 계층적 검증 구조: L1(개별 컴포넌트 단위 테스트) → L2(모듈 간 상호작용) → L3(전체 에피소드 롤아웃 비교) → L4(cross-backend 정책 이전)로 시맨틱 동등성 보장
- L3 롤아웃 테스트만 쓰면 복잡한 물리 환경(HalfCheetah)에서 42번 반복해도 수렴 실패, 계층적 검증은 5번 만에 수렴 — 검증 구조가 핵심
- 200M 파라미터 모델 기준 환경 오버헤드가 학습 시간의 4% 이하로 떨어짐 (기존 50~90% → 4%)
- 웹에서 규칙 추출해 아예 새 환경 생성도 가능: TCGJax(포켓몬 TCG)를 공개 레포 없이 웹 스펙만으로 만들어서 Python 대비 6.6배 빠른 학습 환경 구현
- 방법론 자체가 에이전트 무관(agent-agnostic): Claude Sonnet 4.6, Claude Opus 4.6으로 동일 프롬프트 재실행해도 동일하게 동작 확인
Evidence
- PokeJAX: TypeScript 681 SPS → JAX 15.2M SPS PPO, 22,320배 속도 향상, 비용 $6
- HalfCheetah JAX: Google MJX 대비 1.04배 동등 성능(1.66M vs 1.6M SPS), Brax 대비 5배 (batch 4K 기준), 비용 $3.26
- Puffer Pong: C PufferLib 대비 PPO 42배 속도 향상 (854K → 35.5M SPS), 비용 $0.05
- HalfCheetah 검증 실험: L3-only는 42 iteration 후에도 수렴 실패, 계층적 검증은 5 iteration 수렴 (8.4배 효율)
How to Apply
- 느린 Gym 환경이 있다면 논문 Appendix B의 프롬프트 템플릿 구조(소스 모듈 명세 → 타깃 언어 제약 → 인터페이스 계약 → 참조 동작 → L1 테스트 생성 지시)를 그대로 써서 Gemini/Claude에게 JAX 변환 요청하면 됨
- 복잡한 환경일수록 모듈을 100~500줄 단위로 쪼개서 의존성 순서대로 번역하고, 각 모듈마다 L1 property test 통과 확인 후 다음 모듈로 넘어가는 방식 적용
- JAX 환경은 fixed-size 배열 + jnp.where 브랜치리스 + vmap + jit + lax.scan 롤아웃 퓨전 순서로 최적화하면 되고, Appendix C의 체크리스트를 최적화 프롬프트로 그대로 에이전트에게 전달 가능
Code Example
# JAX 환경 핵심 최적화 패턴 (Appendix C 기반)
import jax
import jax.numpy as jnp
from functools import partial
# 1. 단일 인스턴스 step 함수 작성
def step_single(state, action, constants):
# jnp.where로 브랜치리스 조건 처리
ball_vy = jnp.where(state.wall_hit, -state.ball_vy, state.ball_vy)
# fixed-size 배열 업데이트
new_state = state.replace(ball_vy=ball_vy)
reward = jnp.where(new_state.done, 1.0, 0.0)
return new_state, reward, new_state.done
# 2. vmap으로 배치 병렬화
step_batch = jax.vmap(
partial(step_single, constants=GAME_CONSTANTS),
in_axes=(0, 0) # constants는 broadcast
)
# 3. jit으로 GPU 커널 컴파일
step_jit = jax.jit(step_batch)
# 4. lax.scan으로 롤아웃 전체를 단일 GPU 커널로 퓨전
def scan_body(states, actions_t):
states, rewards, dones = step_jit(states, actions_t)
return states, (rewards, dones)
rollout = jax.jit(partial(jax.lax.scan, scan_body))
# 워밍업 (JIT 컴파일 트리거)
dummy_states = init_batch(batch_size=4096)
dummy_actions = jnp.zeros((4096,), dtype=jnp.int32)
_ = step_jit(dummy_states, dummy_actions)
# ---
# Rust + Rayon CPU 병렬화 패턴 (EmuRust 스타일)
# use rayon::prelude::*;
#
# self.emulators.par_iter_mut()
# .zip(actions.iter())
# .for_each(|(emu, &action)| emu.step(action));Terminology
Related Resources
Original Abstract (Expand)
Translating complex reinforcement learning (RL) environments into high-performance implementations has traditionally required months of specialized engineering. We present a reusable recipe - a generic prompt template, hierarchical verification, and iterative agent-assisted repair - that produces semantically equivalent high-performance environments for <$10 in compute cost. We demonstrate three distinct workflows across five environments. Direct translation (no prior performance implementation exists): EmuRust (1.5x PPO speedup via Rust parallelism for a Game Boy emulator) and PokeJAX, the first GPU-parallel Pokemon battle simulator (500M SPS random action, 15.2M SPS PPO; 22,320x over the TypeScript reference). Translation verified against existing performance implementations: throughput parity with MJX (1.04x) and 5x over Brax at matched GPU batch sizes (HalfCheetah JAX); 42x PPO (Puffer Pong). New environment creation: TCGJax, the first deployable JAX Pokemon TCG engine (717K SPS random action, 153K SPS PPO; 6.6x over the Python reference), synthesized from a web-extracted specification. At 200M parameters, the environment overhead drops below 4% of training time. Hierarchical verification (property, interaction, and rollout tests) confirms semantic equivalence for all five environments; cross-backend policy transfer confirms zero sim-to-sim gap for all five environments. TCGJax, synthesized from a private reference absent from public repositories, serves as a contamination control for agent pretraining data concerns. The paper contains sufficient detail - including representative prompts, verification methodology, and complete results - that a coding agent could reproduce the translations directly from the manuscript.