Automatic Generation of High-Performance RL Environments
TL;DR Highlight
A recipe for AI coding agents to automatically convert RL training environments to JAX/Rust, making them up to 22,320x faster — for under $10.
Who Should Read
RL researchers frustrated by slow Python-based training environments, and teams wanting to dramatically speed up their RL experimentation cycle without manual environment reimplementation.
Core Mechanics
- Proposed a pipeline where an AI coding agent automatically rewrites slow Python RL environments in JAX (for GPU parallelism) or Rust (for CPU speed)
- The conversion process costs under $10 in LLM API calls
- Achieved speedups up to 22,320x on certain environments compared to original Python implementations
- The agent handles the complexity of JAX's functional programming paradigm and vectorization automatically
- Resulting environments are compatible with standard RL training libraries
- Hierarchical verification (property, interaction, and rollout tests) confirms the converted environments are semantically equivalent to the originals
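The "compatible with standard RL training libraries" point usually comes down to exposing the batched step function behind a familiar reset/step interface. A minimal NumPy sketch of such an adapter — the class name, toy scalar dynamics, and auto-reset behavior are ours for illustration, not the paper's actual wrapper:

```python
import numpy as np

class BatchedEnvWrapper:
    """Hypothetical adapter exposing a batched, array-based step function
    behind a Gymnasium-style reset/step interface (illustrative only)."""

    def __init__(self, batch_size=8):
        self.batch_size = batch_size
        self.positions = None  # toy state: one scalar per environment

    def reset(self):
        self.positions = np.zeros(self.batch_size)
        return self.positions.copy(), {}

    def step(self, actions):
        # Vectorized transition: all environments advance in one array op,
        # which is exactly the property that makes JAX/Rust backends fast.
        self.positions = self.positions + actions
        rewards = np.where(self.positions >= 3.0, 1.0, 0.0)
        dones = self.positions >= 3.0
        # Auto-reset finished environments, as vectorized RL loops expect.
        self.positions = np.where(dones, 0.0, self.positions)
        return self.positions.copy(), rewards, dones, {}

env = BatchedEnvWrapper(batch_size=4)
obs, _ = env.reset()
obs, rewards, dones, _ = env.step(np.array([3.0, 1.0, 0.0, 5.0]))
```

Under this shape, the Python training loop never sees per-environment objects; it only exchanges arrays with the backend, so swapping the NumPy toy for a JAX or Rust implementation changes nothing upstream.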
Evidence
- Maximum observed speedup of 22,320x (PokeJAX versus its TypeScript reference)
- Conversion cost consistently under $10 in LLM API usage across tested environments
- Semantic equivalence verified through automated hierarchical tests against original implementations; cross-backend policy transfer shows zero sim-to-sim gap
- Compatible with Gymnasium/Brax standards for drop-in replacement
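The automated equivalence testing amounts to driving both implementations with identical states and actions and comparing outputs. A toy NumPy sketch of that pattern — both step functions here are illustrative stand-ins, not the paper's environments or its actual test suite:

```python
import numpy as np

# Reference single-instance step (stand-in for the original Python env).
def step_reference(pos, vel, action):
    vel = -vel if pos <= 0.0 else vel      # bounce at the wall
    pos = pos + vel + 0.1 * action
    return pos, vel

# "Converted" vectorized step (stand-in for the generated implementation).
def step_vectorized(pos, vel, action):
    vel = np.where(pos <= 0.0, -vel, vel)  # branchless bounce
    pos = pos + vel + 0.1 * action
    return pos, vel

rng = np.random.default_rng(0)
pos = rng.normal(size=1000)
vel = rng.normal(size=1000)
act = rng.integers(-1, 2, size=1000).astype(np.float64)

# Run the reference per-instance, the converted version in one batch.
ref = np.array([step_reference(p, v, a) for p, v, a in zip(pos, vel, act)])
vec = np.stack(step_vectorized(pos, vel, act), axis=1)
equivalent = np.allclose(ref, vec)
```

Randomized state/action sweeps like this catch the most common translation bug: a Python branch that was converted to the wrong branchless form.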
How to Apply
- Feed your Python RL environment code to the agent pipeline along with test cases to verify correctness
- The agent targets JAX for GPU-parallelizable environments and Rust for CPU-bound environments — specify your hardware constraints upfront
- Run the included equivalence tests after conversion to verify behavior matches the original before using in production training runs
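Before committing to production training, it also helps to measure throughput the way the paper reports it: steps per second (SPS) at a fixed batch size, after a warm-up call so JIT compilation is excluded. A rough, backend-agnostic harness — function and variable names are ours, and the toy step function is a placeholder:

```python
import time
import numpy as np

def measure_sps(step_fn, states, actions, n_steps=200):
    """Rough steps-per-second measurement for a batched step function
    (illustrative harness, not the paper's benchmarking code)."""
    step_fn(states, actions)  # warm-up: triggers JIT compilation if any
    start = time.perf_counter()
    for _ in range(n_steps):
        states = step_fn(states, actions)
    elapsed = time.perf_counter() - start
    # Each call advances every environment in the batch by one step.
    return n_steps * states.shape[0] / elapsed

batch = 4096
states = np.zeros(batch)
actions = np.ones(batch)
sps = measure_sps(lambda s, a: s + a, states, actions)
```

Comparing this number against the same harness run on the original Python environment gives the speedup figure directly.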
Code Example
# Core optimization patterns for JAX environment (based on Appendix C)
import jax
import jax.numpy as jnp
from functools import partial
# 1. Write a single-instance step function
def step_single(state, action, constants):
    # Branchless conditional via jnp.where (JIT-compatible, no Python if)
    ball_vy = jnp.where(state.wall_hit, -state.ball_vy, state.ball_vy)
    # Functional update of the fixed-size state (no in-place mutation)
    new_state = state.replace(ball_vy=ball_vy)
    reward = jnp.where(new_state.done, 1.0, 0.0)
    return new_state, reward, new_state.done
# 2. Batch parallelization with vmap
step_batch = jax.vmap(
    partial(step_single, constants=GAME_CONSTANTS),
    in_axes=(0, 0)  # batch over states and actions; constants are closed over via partial
)
# 3. Compile GPU kernel with jit
step_jit = jax.jit(step_batch)
# 4. Fuse entire rollout into a single GPU kernel with lax.scan
def scan_body(states, actions_t):
    states, rewards, dones = step_jit(states, actions_t)
    return states, (rewards, dones)
rollout = jax.jit(partial(jax.lax.scan, scan_body))
# Warm-up (trigger JIT compilation)
dummy_states = init_batch(batch_size=4096)
dummy_actions = jnp.zeros((4096,), dtype=jnp.int32)
_ = step_jit(dummy_states, dummy_actions)
# ---
# Rust + Rayon CPU parallelization pattern (EmuRust style)
# use rayon::prelude::*;
#
# self.emulators.par_iter_mut()
#     .zip(actions.iter())
#     .for_each(|(emu, &action)| emu.step(action));
Terminology
Related Resources
Original Abstract (Expand)
Translating complex reinforcement learning (RL) environments into high-performance implementations has traditionally required months of specialized engineering. We present a reusable recipe - a generic prompt template, hierarchical verification, and iterative agent-assisted repair - that produces semantically equivalent high-performance environments for <$10 in compute cost. We demonstrate three distinct workflows across five environments. Direct translation (no prior performance implementation exists): EmuRust (1.5x PPO speedup via Rust parallelism for a Game Boy emulator) and PokeJAX, the first GPU-parallel Pokemon battle simulator (500M SPS random action, 15.2M SPS PPO; 22,320x over the TypeScript reference). Translation verified against existing performance implementations: throughput parity with MJX (1.04x) and 5x over Brax at matched GPU batch sizes (HalfCheetah JAX); 42x PPO (Puffer Pong). New environment creation: TCGJax, the first deployable JAX Pokemon TCG engine (717K SPS random action, 153K SPS PPO; 6.6x over the Python reference), synthesized from a web-extracted specification. At 200M parameters, the environment overhead drops below 4% of training time. Hierarchical verification (property, interaction, and rollout tests) confirms semantic equivalence for all five environments; cross-backend policy transfer confirms zero sim-to-sim gap for all five environments. TCGJax, synthesized from a private reference absent from public repositories, serves as a contamination control for agent pretraining data concerns. The paper contains sufficient detail - including representative prompts, verification methodology, and complete results - that a coding agent could reproduce the translations directly from the manuscript.