# Checkpointing
This document describes AReaL’s checkpointing system, which handles model saving for evaluation and fault-tolerant recovery during distributed RL training.
## Overview
AReaL provides two complementary checkpointing mechanisms:
| Mechanism | Purpose | Format | Includes Optimizer/DataLoader State |
|---|---|---|---|
| Saver | Export models for evaluation or publishing | HuggingFace | No |
| RecoverHandler | Resume training after failures | DCP (Distributed Checkpoint) | Yes |
Both mechanisms are invoked automatically during training and can be configured via `config.saver` and `config.recover`, respectively.
## Checkpoint Formats

### HuggingFace Format

Used by Saver for model export:

- Standard HuggingFace model format (safetensors + `config.json`)
- Compatible with `transformers.AutoModel.from_pretrained()`
- Can be uploaded to HuggingFace Hub
- Does not include optimizer state
### DCP Format (Distributed Checkpoint)

Used by RecoverHandler for fault tolerance:

- Backend's native distributed checkpoint format (`torch.distributed.checkpoint` or Megatron distributed checkpoint)
- Sharded across all ranks for efficient parallel I/O
- Includes model weights, optimizer state, RNG state, etc.
- Backend-specific: checkpoints are only compatible with the same parallelism configuration
- Overwrites the previous checkpoint to save disk space
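For reference, this is roughly how a checkpoint in this format is written and read with PyTorch's native API (a minimal sketch of `torch.distributed.checkpoint` itself, independent of AReaL's wrappers; the path is a placeholder and `model`/`optimizer` are assumed to be already initialized under the desired parallelism):

```python
import torch.distributed.checkpoint as dcp

# Collective calls: every rank participates and writes/reads its own shard.
state = {"model": model.state_dict(), "optim": optimizer.state_dict()}
dcp.save(state, checkpoint_id="/path/to/recover_checkpoint")

# Loading restores the shards in place into the provided state_dict.
dcp.load(state, checkpoint_id="/path/to/recover_checkpoint")
```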
## Architecture

```
PPOTrainer.train()
│
├── Training loop
│   ├── Rollout, compute values, PPO update...
│   │
│   ├── _save_hf()                       # HuggingFace export
│   │   └── Saver.save()
│   │       └── engine.save(weight_format="hf")
│   │
│   └── _save_recover_checkpoint()       # Fault tolerance
│       └── RecoverHandler.dump()
│           └── engine.save(weight_format="dcp", with_optim=True)
│
└── On restart
    └── RecoverHandler.load()
        ├── Restore dataloader, saver, evaluator states
        └── engine.load(weight_format="dcp", with_optim=True)
```
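In pseudocode, the two save paths in the diagram reduce to the following (a hypothetical sketch; everything except the `engine.save(...)` calls and their arguments is an illustrative name):

```python
# Hypothetical sketch of the diagram above; only the engine.save(...) calls
# and their arguments come from the diagram itself.
for step in training_loop():
    run_rollout_and_ppo_update(step)

    if saver_should_trigger(step):       # freq_epochs / freq_steps / freq_secs
        engine.save(weight_format="hf")                      # weights only

    if recover_should_trigger(step):
        engine.save(weight_format="dcp", with_optim=True)    # full state
```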
## Saver: HuggingFace Model Export

The `Saver` periodically exports model weights in HuggingFace format for evaluation or deployment.

### Configuration

Configure via `config.saver`:
| Parameter | Type | Default | Description |
|---|---|---|---|
| `freq_epochs` | int \| None | None | Save every N epochs. `None` disables. |
| `freq_steps` | int \| None | None | Save every N steps. `None` disables. |
| `freq_secs` | int \| None | None | Save every N seconds. `None` disables. |
Example configuration:

```yaml
saver:
  freq_epochs: 1     # Save at end of each epoch
  freq_steps: null   # Disabled
  freq_secs: null    # Disabled
```

Saving is triggered when any of the epoch/step/time conditions is met.
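The trigger semantics can be pictured as an OR over the three frequencies (a minimal sketch; the helper below is hypothetical, not AReaL's implementation):

```python
import time

def should_save(cfg, epochs_done: int, steps_done: int, last_save: float) -> bool:
    """Hypothetical OR-trigger over the three configured frequencies."""
    if cfg.freq_epochs is not None and epochs_done > 0 and epochs_done % cfg.freq_epochs == 0:
        return True
    if cfg.freq_steps is not None and steps_done > 0 and steps_done % cfg.freq_steps == 0:
        return True
    if cfg.freq_secs is not None and time.time() - last_save >= cfg.freq_secs:
        return True
    return False
```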
### Output Location

Checkpoints are saved to:

```
{fileroot}/checkpoints/{user}/{experiment_name}/{trial_name}/default/
└── epoch{E}epochstep{S}globalstep{G}/
    ├── config.json
    ├── model.safetensors (or model-00001-of-00002.safetensors, etc.)
    ├── tokenizer.json
    └── ...
```
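The directory name encodes the epoch, in-epoch step, and global step; for instance, a hypothetical helper reproducing the pattern:

```python
def checkpoint_dirname(epoch: int, epoch_step: int, global_step: int) -> str:
    # Reproduces the epoch{E}epochstep{S}globalstep{G} pattern shown above.
    return f"epoch{epoch}epochstep{epoch_step}globalstep{global_step}"

print(checkpoint_dirname(0, 99, 99))  # epoch0epochstep99globalstep99
```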
### Usage

Load saved checkpoints with standard HuggingFace APIs:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "/path/to/checkpoint/epoch0epochstep99globalstep99"
)
tokenizer = AutoTokenizer.from_pretrained(
    "/path/to/checkpoint/epoch0epochstep99globalstep99"
)
```
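Since the export is a standard HuggingFace directory, it can also be pushed to the Hub with the usual API (the repository id below is a placeholder):

```python
# Requires authentication via `huggingface-cli login`; repo id is a placeholder.
model.push_to_hub("your-username/your-model-name")
tokenizer.push_to_hub("your-username/your-model-name")
```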
## RecoverHandler: Fault Tolerance

The `RecoverHandler` enables resuming training after failures by saving complete training state.

### Configuration

Configure via `config.recover`:
| Parameter | Type | Default | Description |
|---|---|---|---|
| `mode` | str | "disabled" | Recovery mode: "on"/"auto" or "off"/"disabled" |
| `freq_epochs` | int \| None | None | Checkpoint every N epochs |
| `freq_steps` | int \| None | None | Checkpoint every N steps |
| `freq_secs` | int \| None | None | Checkpoint every N seconds |
| `retries` | int | 3 | Number of recovery retries when recovery is enabled |
### Recovery Modes

| Mode | Behavior |
|---|---|
| `on` / `auto` | Automatically resume if a valid checkpoint exists |
| `off` / `disabled` | No checkpointing or recovery |
When recovery is enabled (`on`/`auto`), the system will:

- Periodically save recovery checkpoints (model weights, optimizer state, dataloader position)
- Automatically resume from the last valid checkpoint on restart
- Retry up to `retries` times on failure
Example configuration:

```yaml
recover:
  mode: "on"       # or "auto" for backward compatibility; quoted so YAML
                   # does not parse bare `on` as a boolean
  freq_steps: 100  # Checkpoint every 100 steps
  retries: 3
```
### What Gets Saved

RecoverHandler saves complete training state:

| Component | Contents |
|---|---|
| Model weights | DCP format, sharded across ranks |
| Optimizer state | Momentum, variance (Adam), learning rate scheduler |
| RNG state | Python, NumPy, PyTorch, CUDA random states |
| Dataloader state | Current position in dataset |
| Training progress | Epoch, step, global_step counters |
| Auxiliary states | Saver, Evaluator, StatsLogger states |
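Conceptually, the bundle corresponds to a state dictionary like the following (an illustrative sketch of the table above, not AReaL's actual schema; the function parameters are assumed to exist in the training loop):

```python
import random
import numpy as np
import torch

def build_recover_state(model, optimizer, lr_scheduler, dataloader, step_info):
    """Illustrative only: keys mirror the table above, not AReaL's actual schema."""
    return {
        "model": model.state_dict(),            # written in DCP format, sharded
        "optim": optimizer.state_dict(),        # Adam moments, etc.
        "lr_scheduler": lr_scheduler.state_dict(),
        "rng": {
            "python": random.getstate(),
            "numpy": np.random.get_state(),
            "torch": torch.get_rng_state(),
            "cuda": torch.cuda.get_rng_state_all(),
        },
        "dataloader": dataloader.state_dict(),  # e.g., a stateful dataloader
        "step_info": step_info,                 # epoch, step, global_step
    }
```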
### Output Location

Recovery checkpoints are saved to:

```
{fileroot}/checkpoints/{user}/{experiment_name}/{trial_name}/
├── default/
│   └── recover_checkpoint/        # Model + optimizer (DCP format)
│       ├── __0_0.distcp
│       ├── __1_0.distcp
│       └── ...
├── critic/                        # If using critic
│   └── recover_checkpoint/
└── recover_info/                  # Metadata
    ├── step_info.json
    ├── saver_info.json
    ├── evaluator_info.json
    ├── stats_logger_info.json
    ├── checkpoint_info.json
    └── dataloader_info.pkl
```
### Recovery Process

When training resumes:

1. `RecoverHandler.load()` restores all saved state (if any)
2. Training continues from `last_step_info.next().global_step`
3. Inference engine weights are synchronized to match the recovered state
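In rough pseudocode (a hypothetical sketch; only `RecoverHandler.load()` and the `.next().global_step` accessor come from the steps above, everything else is an illustrative name):

```python
# Hypothetical resume flow; only RecoverHandler.load() and the
# .next().global_step accessor come from the description above.
last_step_info = recover_handler.load()          # restores model, optimizer, aux state
start_step = last_step_info.next().global_step if last_step_info else 0

sync_weights_to_inference_engine()               # illustrative helper name
for global_step in range(start_step, total_steps):
    train_one_step(global_step)
```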
## Best Practices

### Frequency Guidelines

| Scenario | Recommended Setting |
|---|---|
| Long training runs | `freq_epochs` (e.g., once per epoch) |
| Unpredictable time | `freq_secs` (wall-clock interval) |
| Unstable clusters | Small `freq_steps` for frequent checkpoints |
| Limited disk space | Lower frequency, rely on final checkpoint |
| Debugging | Small `freq_steps` to capture state often |
### Disk Space Considerations

- **Saver**: Each save creates a new directory. High frequency consumes significant space.
- **RecoverHandler**: Overwrites the previous checkpoint. Only one copy exists at a time.
### Recovery Tips

- **Verify checkpoint validity**: Check `recover_info/step_info.json` for the last saved step
- **Same config required**: DCP checkpoints require identical parallelism configuration, experiment name, and trial name
- **Clean restart**: Delete the `recover_info/` directory to start fresh
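For example, a quick sanity check of the last saved step (a minimal sketch; the checkpoint root path is a placeholder):

```python
import json
from pathlib import Path

# Placeholder: substitute your {fileroot}/checkpoints/{user}/{experiment_name}/{trial_name}.
trial_dir = Path("/path/to/checkpoints/user/experiment/trial")
step_info = json.loads((trial_dir / "recover_info" / "step_info.json").read_text())
print(step_info)  # last saved epoch / step / global_step
```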