# Handling OOM Issues
OOM errors are pretty common when you’re doing large-scale RL training. Here’s how to tackle them across generation, training, and weight updates in your AReaL workflows.
## Understanding Memory Usage
Before jumping into fixes, let’s understand which parameters actually matter for memory usage:
### Core Parameters

- `allocation_mode`: How you split inference and training across GPUs. For large models, tensor parallelism typically uses less memory per GPU than data parallelism.
- `train_dataset.max_length`: Your maximum prompt length. Longer prompts = more memory.
- `gconfig.max_new_tokens`: How many tokens to generate per prompt. This plus `max_length` gives you your total sequence length.
- `actor.mb_spec.max_tokens_per_mb`: Tokens per micro-batch during forward/backward passes. This is your main knob for controlling training memory. Can’t go below `max_length + max_new_tokens`.
- `max_concurrent_rollouts`: How many generation requests you run in parallel. More requests = better throughput but higher memory usage.
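To make these concrete, here is a minimal sketch of how these knobs might appear together in an experiment config. The values are illustrative only, and the exact nesting should follow your own config, but they show the constraint that `max_tokens_per_mb` must be at least `max_length + max_new_tokens`:

```yaml
allocation_mode: sglang:d4+fsdp:d4  # split GPUs between inference and training
max_concurrent_rollouts: 256        # parallel generation requests
train_dataset:
  max_length: 1024                  # longest prompt, in tokens
gconfig:
  max_new_tokens: 3072              # tokens generated per prompt
actor:
  mb_spec:
    max_tokens_per_mb: 4096         # >= max_length + max_new_tokens (1024 + 3072)
```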
### Engine-Specific Parameters

- **Inference Engine:** `sglang.mem_fraction_static` controls how much GPU memory SGLang uses. Check the SGLang docs for more tuning options.
- **Training Engine:** FSDP sharding and other PyTorch settings also impact memory usage. The FSDP docs have more details.
Don’t worry about `train_dataset.batch_size` - it doesn’t actually affect peak memory usage. Stick to the parameters above when troubleshooting OOM issues.
## Resolving Generation OOM Errors

When you hit generation OOM errors (you’ll see them in `llm_server.log`), here’s what to try:
### 1. Reduce Concurrent Rollouts (Most Effective)
Lower the number of parallel generation requests:
```yaml
max_concurrent_rollouts: 200  # Try reducing from default values like 256
```
This is usually your best bet since it directly reduces memory pressure on the inference servers.
### 2. Adjust Parallelism Strategy
Try increasing tensor parallelism to spread your model weights across more GPUs:
```yaml
# Before: sglang:d4+fsdp:d4 (4 data parallel processes)
# After: sglang:d2t2+fsdp:d4 (2 data parallel, 2 tensor parallel)
allocation_mode: sglang:d2t2+fsdp:d4
```
Just keep in mind that higher tensor parallelism will slow down your generation throughput.
### 3. Tune SGLang Parameters
You can also tweak how SGLang allocates memory:
```yaml
sglang:
  mem_fraction_static: 0.8  # Reduce from 0.9 to leave more memory headroom
```
Check out the SGLang docs for more advanced tuning options.
## Resolving Training OOM Errors
Training OOM errors are trickier - you need to reduce the memory footprint of gradient computation and model updates.
### 1. Optimize Micro-batch Size

Your first move: set `max_tokens_per_mb` as low as safely possible:
```yaml
actor:
  mb_spec:
    max_tokens_per_mb: 4096  # train_dataset.max_length + gconfig.max_new_tokens
```
For multi-turn conversations, calculate it like this:
```
max_tokens_per_mb = <longest_conversation_length> + gconfig.max_new_tokens
```
The exact value will depend on how your `RolloutWorkflow` is implemented.
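As a rough starting point, you can estimate the longest conversation from your own data before launching training. The sketch below is not part of AReaL; it assumes a Hugging Face tokenizer and a list of rendered multi-turn prompts (`conversations` is a placeholder you would fill in to match your workflow):

```python
from transformers import AutoTokenizer

# Placeholder inputs: substitute your actual model and rendered multi-turn prompts.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
conversations = ["<full multi-turn conversation text>"]
max_new_tokens = 1024  # keep in sync with gconfig.max_new_tokens

# The longest conversation (in tokens) plus the generation budget gives a lower
# bound for actor.mb_spec.max_tokens_per_mb.
longest = max(len(tokenizer.encode(text)) for text in conversations)
print("max_tokens_per_mb >=", longest + max_new_tokens)
```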
### 2. Enable Ulysses Sequence Parallelism

If you’re dealing with really long contexts and can’t reduce `max_tokens_per_mb` any further, try Ulysses sequence parallelism to spread sequences across multiple GPUs:
```yaml
# Before: sglang:d4+fsdp:d4 (4 data parallel processes)
# After: sglang:d4+fsdp:d2c2 (2 data parallel, 2 ulysses context parallel)
allocation_mode: sglang:d4+fsdp:d2c2
```
Just remember: Ulysses context parallel size needs to divide evenly into your model’s attention heads. For example, with 40 attention heads:

- These work: 1, 2, 4, 8
- These don’t: 16, 32
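A quick way to check candidate sizes (this snippet is just illustrative, not part of AReaL) is to test divisibility against the head count from your model’s config:

```python
num_attention_heads = 40  # e.g., read from the model's config.json

candidates = [1, 2, 4, 8, 16, 32]
valid = [size for size in candidates if num_attention_heads % size == 0]
print(valid)  # [1, 2, 4, 8]; 16 and 32 are rejected because they don't divide 40
```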
## Resolving Weight Update OOM Errors
Weight updates can eat up a lot of memory, especially when using NCCL synchronization (which is the default).
### 1. Switch to Disk-Based Updates
The easiest fix is switching from NCCL to disk-based weight synchronization:
```python
# Instead of NCCL-based updates
weight_update_meta = WeightUpdateMeta.from_disk(config.saver)
```
Check the “Transferring Weights to Inference Servers” section in the Weight Updates Guide for the full implementation details.
### 2. Reduce Memory Buffer Size
If you want to stick with NCCL, try reducing the memory buffer size for weight chunking:
```python
# In WeightUpdateMeta.from_fsdp_nccl() calls
WeightUpdateMeta.from_fsdp_nccl(
    ...,
    weight_chunked_mem_mb=512,  # Reduce from default (typically 1024+)
)
```