Running GRPO on GSM8K Dataset#
This guide walks you through how AReaL runs the GRPO algorithm on the GSM8K dataset.
We’ll use the example training script
examples/math/gsm8k_rl.py
and configuration file
examples/math/gsm8k_grpo.yaml
to explain the key concepts step by step.
Overview: How AReaL Works#
Single-Controller Architecture#
AReaL uses a single-controller architecture where the training script orchestrates remote workers via RPC:
Controller Process (Your Script)
    │
    ├─> RolloutController
    │     ├─> Manages rollout workers (SGLang/vLLM)
    │     ├─> Submits prompts to inference workers
    │     ├─> Collects trajectories
    │     └─> Returns: RTensor (distributed batch)
    │
    └─> TrainController
          ├─> Manages training workers (FSDP/Megatron)
          ├─> Dispatches RTensor via data_parallel_dispatch()
          ├─> Workers compute forward/backward
          ├─> Merges results via data_parallel_merge()
          └─> Returns: loss, metrics
Training Step Flow:
Rollout Phase: Controller loads data and passes it to RolloutController, which schedules and routes rollout requests to rollout workers (GPUs).
  Each rollout worker serves a complete model (may occupy multiple GPUs)
  Returns: RTensor with shards stored on rollout workers (controller holds only metadata)
Dispatch Phase: TrainController distributes work via data_parallel_dispatch()
  Uses FFD (First Fit Decreasing) to balance sequence lengths across workers
  Workers fetch their assigned shards directly from rollout workers
Training Phase: Each training worker processes its shard independently
  Supports 5D parallelism (data, tensor, pipeline, context, expert)
Weight Sync: Transfer updated weights to inference workers
  Via NCCL (fast, GPU-to-GPU) or disk (fallback)
Data Flow with RTensor#
Rollout Workers (GPUs 0-3)         Controller                 Training Workers (GPUs 4-7)
─────────────────────────          ──────────                 ───────────────────────────
Worker 0: Generates 16 samples
  └─> Shard 0 stored locally ──┐
Worker 1: Generates 16 samples │
  └─> Shard 1 stored locally ──┤
Worker 2: Generates 16 samples │   RTensor metadata
  └─> Shard 2 stored locally ──┼──> Controller ──> data_parallel_dispatch()
Worker 3: Generates 16 samples │          │
  └─> Shard 3 stored locally ──┘          │
                                ┌─────────┼───────────┐
                                ▼         ▼           ▼
                             Worker 4:  Worker 5:  Worker 6:
                             Fetch      Fetch      Fetch
                             Shards 0,1 Shard 2    Shard 3
                                │         │           │
                             Forward    Forward    Forward
                             Backward   Backward   Backward
                             Gradients  Gradients  Gradients
                                └─────────┴───────────┘
                                   NCCL AllReduce
                                          │
                             Worker 4:  Worker 5:  Worker 6:
                             Returns    Returns    Returns
                             RTensor    RTensor    RTensor
                                └─────────┴───────────┘
                                          │
                                data_parallel_merge()
                                          │
                                          ▼
                             Controller receives:
                               • loss (scalar)
                               • metrics (dict)
In the following sections, we’ll walk through the code to explain each component in detail.
Launching the Experiment#
AReaL supports launching experiments with different scheduler backends for different environments. As shown in the quickstart guide, you can launch experiments with:
# Local machine (using subprocesses)
python examples/math/gsm8k_rl.py --config examples/math/gsm8k_grpo.yaml scheduler.type=local
# Ray cluster
python examples/math/gsm8k_rl.py --config examples/math/gsm8k_grpo.yaml scheduler.type=ray
# Slurm cluster
python examples/math/gsm8k_rl.py --config examples/math/gsm8k_grpo.yaml scheduler.type=slurm
How Single-Controller Mode Works#
Training Script: Your experiment entry point (e.g., examples/math/gsm8k_rl.py)
that runs on the controller node.
Controller Responsibilities:
Controllers create worker processes (an HTTP or Ray server) via scheduler.create_workers()
After workers are created, controllers create engines (e.g., RemoteSGLangEngine, FSDPEngine) via scheduler.create_engine()
Controllers dispatch work via RPC and coordinate via PyTorch distributed primitives
Key Configuration:
scheduler.type: Determines which backend to use (local, ray, or slurm)
allocation_mode: Determines the number of GPUs for training/inference and the parallel strategies
Schedulers automatically handle worker placement, resource allocation, and lifecycle management
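For illustration, a minimal YAML fragment for these two options might look like the sketch below; the exact schema is defined in areal/api/cli_args.py, and the allocation string is the same one used in the 8-GPU example later in this guide:

scheduler:
  type: local                    # or: ray, slurm
allocation_mode: sglang:d4+d4    # 4 GPUs for SGLang rollout + 4 GPUs for training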
Configuration Files#
Configuration files are YAML files that specify options from
areal/api/cli_args.py.
You can override settings via CLI:
# Example: change model and attention backend
python examples/math/gsm8k_rl.py \
--config examples/math/gsm8k_grpo.yaml \
scheduler.type=local \
actor.path=Qwen/Qwen3-1.7B \
+sglang.attention_backend=triton
In your training script, parse the configuration:
config, _ = load_expr_config(args, GRPOConfig)
config: GRPOConfig
See CLI Reference for all available options.
The Training Script: Entry Point#
The training script
(examples/math/gsm8k_rl.py)
follows this pattern:
def main(args):
    # 1. Load config (YAML + CLI overrides)
    config, _ = load_expr_config(args, GRPOConfig)
    tokenizer = load_hf_tokenizer(config.tokenizer_path)

    # 2. Prepare datasets (loaded on controller)
    train_dataset = get_custom_dataset(
        split="train", dataset_config=config.train_dataset, tokenizer=tokenizer
    )
    valid_dataset = get_custom_dataset(
        split="test", dataset_config=config.valid_dataset, tokenizer=tokenizer
    )

    # 3. Define workflow configuration (imported on workers)
    workflow_kwargs = dict(
        reward_fn="areal.reward.gsm8k.gsm8k_reward_fn",
        gconfig=config.gconfig,
        tokenizer=config.tokenizer_path,
    )

    # 4. Train with PPOTrainer
    with PPOTrainer(
        config, train_dataset=train_dataset, valid_dataset=valid_dataset
    ) as trainer:
        trainer.train(
            workflow="areal.workflow.rlvr.RLVRWorkflow",
            workflow_kwargs=workflow_kwargs,
        )
Key Points:
Datasets are loaded on the controller, then distributed to workers by the controllers
Workflows are specified as import strings to enable dynamic instantiation on remote workers
PPOTrainer handles all infrastructure (scheduler, controllers, workers)
See CLI Reference for configuration options, and Customization: Dataset for custom datasets.
The PPOTrainer: Controller-Based Training#
The
PPOTrainer
orchestrates distributed training by initializing the scheduler and creating controllers
for actors (policy/critic) and rollout workers.
Controller Architecture#
PPOTrainer (Controller Process)
│
├── actor: PPOActorController (TrainController)
│ ├── scheduler.create_workers() → Training workers
│ ├── Remote engines: FSDPPPOActor instances
│ └── APIs: compute_logp(), compute_advantages(), ppo_update()
│
├── rollout: RolloutController
│ ├── scheduler.create_engine() → Inference workers (SGLang/vLLM)
│ ├── BatchTaskDispatcher → Async workflow execution
│ └── API: prepare_batch() → Returns batch tensors
│
└── ref: PPOActorController (optional)
└── Frozen reference model for KL penalty
Key Pattern: Engines use as_controller(config, scheduler) to wrap themselves in
controllers. The controller handles worker creation, RPC dispatch, and result merging.
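As a rough sketch of that wrapping pattern (the engine class names come from the diagram above, but the constructor arguments here are assumptions, not the exact API):

# Hedged sketch: engines wrap themselves in controllers that own
# worker creation, RPC dispatch, and result merging.
actor = FSDPPPOActor(config.actor).as_controller(config, scheduler)            # TrainController
rollout = RemoteSGLangEngine(config.rollout).as_controller(config, scheduler)  # RolloutController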
Rollout: Generating Training Data#
Workflow Specification#
In examples/math/gsm8k_rl.py, workflows are specified as strings to enable dynamic
importing on remote workers:
trainer.train(
    workflow="areal.workflow.rlvr.RLVRWorkflow",
    workflow_kwargs={
        "reward_fn": "areal.reward.gsm8k.gsm8k_reward_fn",
        "gconfig": config.gconfig,
        "tokenizer": config.tokenizer_path,
    },
)
RLVRWorkflow: Single-Turn Reward Learning#
The
RLVRWorkflow
defines how prompts become training samples. Each trajectory goes through these steps:
Tokenize input: Apply chat template to messages
Generate response: Call inference engine (SGLang/vLLM)
Compute reward: Evaluate completion against ground truth
Build training sample: Construct tensor dict with:
  input_ids: Full sequence (prompt + completion)
  loss_mask: 0 for prompt tokens, 1 for completion tokens
  logprobs: Log probabilities from generation
  versions: Model version for each token (-1 for prompt)
  rewards: Scalar reward
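Concretely, a single trajectory's tensor dict might be assembled as in this sketch; variable names such as prompt_ids and completion_logprobs are placeholders for illustration, not AReaL identifiers:

import torch

# One trajectory, shaped [1, seq_len]; prompt tokens are masked out of the loss
sample = {
    "input_ids": torch.tensor([prompt_ids + completion_ids]),
    "loss_mask": torch.tensor([[0] * len(prompt_ids) + [1] * len(completion_ids)]),
    "logprobs": torch.tensor([[0.0] * len(prompt_ids) + completion_logprobs]),
    "versions": torch.tensor([[-1] * len(prompt_ids) + completion_versions]),
    "rewards": torch.tensor([reward]),
}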
GSM8K Reward: Binary reward (1.0 for correct answer, 0.0 otherwise). See
gsm8k_reward_fn.
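A minimal reward function in this style might look like the sketch below; the real gsm8k_reward_fn lives in areal.reward.gsm8k and its exact signature may differ:

def gsm8k_reward_fn(completion: str, answer: str, **kwargs) -> float:
    # GSM8K ground truths end with "#### <number>"; compare the final numbers.
    pred = completion.split("####")[-1].strip().replace(",", "")
    gold = answer.split("####")[-1].strip().replace(",", "")
    return 1.0 if pred == gold else 0.0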
This workflow uses the low-level API of inference engines, the agenerate API. It is
preferable when you want fine-grained control over token IDs: agenerate sends token IDs
to the inference server and returns output token IDs for you to process.
We also provide a high-level API for convenient agentic workflow orchestration; see the
agentic RL guide.
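A rough sketch of that token-in/token-out flow (the ModelRequest name and the response field are assumptions about AReaL's request structs, not confirmed API):

# Tokenize the prompt ourselves, send raw token IDs, get raw token IDs back
input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=True)
resp = await engine.agenerate(ModelRequest(input_ids=input_ids, gconfig=gconfig))
completion_ids = resp.output_tokens  # decode or post-process however you like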
Asynchronous Rollout Collection#
Rollout in AReaL is fully asynchronous with three levels of concurrency that enable overlap between generation and training.
Three-Process Architecture#
Controller Process            Worker Process (RPC Server)        GPU Process
──────────────────            ───────────────────────────        ───────────
RolloutController             Flask HTTP Server (CPU)            SGLang/vLLM
    │                                  │                              │
    └─> BatchTaskDispatcher        /call endpoint                 Inference
        (background thread)            │                            Engine
            │                          └─> Engine Thread               │
            ├─ submit task 1               └─> RemoteInfEngine         │
            │  (HTTP POST)                     └─> submit() ──────────>│
            │                                                      Generate
            ├─ submit task 2                                        tokens
            │  (HTTP POST)                                             │
            │                                                          │
            ├─ submit task 3           HTTP Callback <─────────────────┘
            │                          (trajectory)
            │              ┌────────────────┘
            └─ collect <───┘

Meanwhile (on different GPUs)...

TrainController                     Training Worker
    │                                      │
    └─> ppo_update(batch) ────────────────>│ Forward/Backward

Key: Generation and training happen SIMULTANEOUSLY on different GPUs
Three Levels of Concurrency#
Level 1 - Controller Thread:
BatchTaskDispatcher
runs in a background thread, continuously submitting rollout requests to workers via
HTTP:
Submits tasks round-robin to rollout workers
Maintains 2+ batches of inflight requests to hide latency
Non-blocking: returns task_id immediately
Level 2 - Worker RPC Server: Each rollout worker runs a Flask HTTP server
(rpc_server.py)
on CPU:
Accepts concurrent HTTP requests (multi-threaded Flask)
Engine thread: Processes engine operations serially (NCCL compatibility)
Routes requests to RemoteInfEngine, which queues work to SGLang/vLLM
Level 3 - GPU Subprocess: SGLang/vLLM runs as a separate subprocess on GPU:
Launched via backend.launch_server() (separate from the RPC server)
Maintains its own request queue
Processes multiple concurrent generations with continuous batching
Sends HTTP callbacks when trajectories complete
Request Flow#
# 1. Controller calls prepare_batch
batch = rollout.prepare_batch(
    dataloader,
    workflow="areal.workflow.rlvr.RLVRWorkflow",
    workflow_kwargs=workflow_kwargs,
)

# 2. RolloutController delegates to BatchTaskDispatcher
#    Background thread submits tasks:
for data in dataloader:
    task = _RemoteRolloutTaskInput(data, workflow, workflow_kwargs, task_id)
    dispatcher.submit_task_input(task)  # Non-blocking HTTP POST

# 3. Worker RPC server receives HTTP POST /call (method="submit")
#    Engine thread executes:
workflow_instance = import_from_string(workflow)(**workflow_kwargs)
task_id = workflow_executor.submit(data, workflow_instance)
# Returns immediately (non-blocking)

# 4. WorkflowExecutor (on worker) runs in background:
result = await workflow_instance.arun_episode(engine, data)
# Sends HTTP callback to controller with trajectory

# 5. Controller collects results
#    BatchTaskDispatcher waits for batch_size accepted trajectories
results = dispatcher.wait_results(batch_size)
return concat_padded_tensors(results)  # Shape: [batch_size, seq_len]
Overlap with Training#
The key benefit of this architecture is that rollout and training happen simultaneously:
Timeline:
─────────────────────────────────────────────────────────────
Rollout GPUs:  [Generate Batch N+1] [Generate Batch N+2] ...
Training GPUs: [Train on Batch N]   [Train on Batch N+1] ...
                                  │
                       Weight sync happens here
                       (rollout paused briefly)
Staleness Control:
StalenessManager
limits concurrent inflight requests:
max_concurrent_rollouts: Maximum inflight trajectories
max_head_offpolicyness: Reject samples generated with weights that are too old
Version tracking: Each token is tagged with the model version used during generation
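Both limits are plain configuration options; a hedged YAML fragment, assuming they nest under the rollout section (the nesting may differ in your cli_args schema):

rollout:
  max_concurrent_rollouts: 64   # cap on inflight trajectories
  max_head_offpolicyness: 2     # reject samples more than 2 versions stale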
Pause/Resume: During weight sync, rollout is paused to avoid stale generations:
# In PPOTrainer.train() loop
rollout.pause() # Pause new submissions
actor.update_weights(...) # Sync weights to inference GPUs
rollout.set_version(step) # Update version tracker
rollout.resume() # Resume submissions
Training: Controller-Worker Pattern#
Training follows a standard controller-worker pattern. The controller dispatches algorithm operations to training workers via RPC, workers process their data shards, and results are merged back.
TrainController: Dispatch Mechanism#
TrainController
provides the core RPC dispatch:
_dispatch_inputs(): Splits batches using FFD load balancing across workers (a simplified sketch follows below)
RPC calls: Each worker receives its shard, processes it, and returns results
_merge_results(): Reconstructs results from data-parallel workers
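The balancing step can be pictured with this simplified greedy sketch, in the spirit of First Fit Decreasing; the real _dispatch_inputs() implementation will differ:

def balance_shards(seq_lens: list[int], n_workers: int) -> list[list[int]]:
    """Assign sequence indices to workers, longest sequences first,
    always placing the next one into the currently least-loaded worker."""
    assignment = [[] for _ in range(n_workers)]
    loads = [0] * n_workers
    for idx in sorted(range(len(seq_lens)), key=lambda i: -seq_lens[i]):
        w = min(range(n_workers), key=loads.__getitem__)
        assignment[w].append(idx)
        loads[w] += seq_lens[idx]
    return assignment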
Data Flow with RTensor:
Controller                         Worker 0                   Worker 1
    │                                  │                          │
    ├─ RTensor (metadata) ─────────────┼──────────────────────────┤
    │  • Shards 0,1,2,3                │                          │
    │                                  │                          │
    ├─ dispatch() ────────────────────>│                          │
    │  • Worker 0: Shards 0,1          │                          │
    │  • Worker 1: Shards 2,3          │                          │
    │                                  │                          │
    │                                  ├─> Fetch Shards 0,1       │
    │                                  │   from rollout workers   │
    │                                  │                          ├─> Fetch Shards 2,3
    │                                  │                          │   from rollout workers
    │                                  │                          │
    │                                  ├─> compute_logp()         ├─> compute_logp()
    │                                  │                          │
    │                                  ├─> RTensor (result)       ├─> RTensor (result)
    │<─ merge() ───────────────────────┴──────────────────────────┘
    │  • Reconstruct ordering
    │  • Return unified RTensor
    └─> batch["logp"] = result
Key Design: Controller only handles metadata (RTensor). Workers fetch actual tensors directly from rollout workers, avoiding controller memory overhead.
Training Workers: Algorithm Implementation#
On each training worker,
FSDPPPOActor
implements the GRPO/PPO algorithm:
Algorithm Methods:
compute_logp(batch): Forward pass through the model to compute log probabilities
compute_advantages(batch): Apply reward/advantage normalization (group or batch level)
ppo_update(batch): Policy update with mini-batch training and gradient accumulation
  Splits the batch into mini-batches
  Computes the PPO loss (clipped surrogate objective + optional KL penalty)
  Performs the backward pass and optimizer step
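For intuition, group-level advantage normalization (the "group" in GRPO) can be sketched as follows; this is an illustrative computation, not AReaL's compute_advantages() implementation:

import torch

def group_normalize_rewards(rewards: torch.Tensor, group_size: int, eps: float = 1e-6) -> torch.Tensor:
    """rewards: [batch], where consecutive entries are samples of the same prompt.
    Each sample's advantage is its reward standardized within its group."""
    grouped = rewards.view(-1, group_size)
    mean = grouped.mean(dim=1, keepdim=True)
    std = grouped.std(dim=1, keepdim=True)
    return ((grouped - mean) / (std + eps)).view(-1)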
Parallelism: The allocation_mode config determines GPU allocation:
allocation_mode=sglang:d4+d4, n_gpus=8
Rollout Workers:        Training Workers:
GPU 0: SGLang           GPU 4: FSDP rank 0 ─┐
GPU 1: SGLang           GPU 5: FSDP rank 1  ├─ Data Parallel
GPU 2: SGLang           GPU 6: FSDP rank 2  │  (DP size = 4)
GPU 3: SGLang           GPU 7: FSDP rank 3 ─┘
                                │
                    NCCL AllReduce for gradients
Each worker processes its shard, then synchronizes gradients via NCCL. For custom algorithms, see Customization: Algorithms.
The Training Loop#
The trainer.train() method orchestrates the complete loop. See
PPOTrainer.train()
for the full implementation:
for global_step in range(start_step, max_steps):
    # 1. Rollout
    rollout_batch = self.rollout.prepare_batch(
        train_dataloader, workflow, workflow_kwargs
    )

    # 2. Compute log-probs and advantages
    if config.actor.should_compute_prox_logp():
        rollout_batch["prox_logp"] = self.actor.compute_logp(rollout_batch)
    if self.ref:
        rollout_batch["ref_logp"] = self.ref.compute_logp(rollout_batch)
    adv_batch = self.actor.compute_advantages(rollout_batch)

    # 3. PPO update
    self.actor.ppo_update(adv_batch)
    self.actor.step_lr_scheduler()

    # 4. Weight sync
    self.rollout.pause()
    self.actor.update_weights(weight_update_meta)
    self.actor.set_version(global_step + 1)
    self.rollout.set_version(global_step + 1)
    self.rollout.resume()
All algorithm operations are controller method calls that dispatch to remote workers.
Weight Synchronization#
After each training step, updated weights must be synced to inference workers. AReaL supports two transfer methods:
Transfer Methods#
NCCL-based transfer (Recommended):
Direct GPU-to-GPU broadcast
Faster but uses more GPU memory
Requires training and inference GPUs on the same communication backend
Disk-based transfer:
Saves to shared storage, then loads on inference servers
Use when NCCL is unavailable or machines don’t share a backend
Weight Update Process#
The weight sync process in PPOTrainer.train() follows this pattern:
Pause rollout to avoid stale generations
Transfer weights via configured method (NCCL or disk)
Update version tracking for staleness management
Resume rollout with updated weights
See
PPOTrainer.train()
lines 861-874 for the full implementation.
Monitoring and Utilities#
AReaL provides utilities managed by PPOTrainer for checkpointing, evaluation, and
metrics tracking. These are automatically orchestrated during training.
Checkpointing#
The Saver
handles periodic checkpoint saving. Configure via config.saver (interval, format,
etc.). Called automatically in trainer.train().
Evaluation#
The
Evaluator
runs periodic evaluations on validation sets. Configure via config.evaluation. Called
automatically in trainer.train().
Metrics Tracking#
stats_tracker
(source):
Collects and aggregates training statistics across ranks.
scalar(key=value): Record simple metrics
stat(key=tensor, denominator=mask): Record tensor statistics with selective aggregation
record_timing(name): Context manager for timing
scope(name): Hierarchical metric keys
export(): Returns aggregated stats across all ranks
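Put together, a hedged usage sketch based on the calls listed above (the import path is an assumption):

from areal.utils import stats_tracker  # assumed import path

with stats_tracker.record_timing("rollout"):
    batch = rollout.prepare_batch(dataloader, workflow, workflow_kwargs)

with stats_tracker.scope("actor"):
    stats_tracker.scalar(lr=1e-6)  # simple scalar metric
    # tensor statistics aggregated only where loss_mask == 1
    stats_tracker.stat(logprobs=batch["logprobs"], denominator=batch["loss_mask"])

stats = stats_tracker.export()  # aggregated across all ranks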
StatsLogger
(source):
Sends metrics to logging backends (W&B, TensorBoard) from rank 0. Configure via
config.stats_logger.
Next Steps#
Now that you understand the basics, explore these advanced topics:
Tutorials:
Training Large MoE Models - Scale to massive models with Megatron integration
Agentic RL with OpenAI APIs - Build agents that use tools and APIs
Customization Guides:
Custom Datasets - Use your own data sources
Custom Workflows - Build agentic/RLVR workflows with custom reward functions
Custom Algorithms - Implement your own RL algorithms