Training Algorithm (Legacy)#

Note: We recommend the AReaL-lite approach for new implementations because of its cleaner separation of concerns and better maintainability.

The legacy approach encapsulates algorithms in a ModelInterface with three core methods:

# From realhf/api/core/model_api.py
class ModelInterface(abc.ABC):
    """Interface for model training, inference, and generation.

    This interface follows the dependency injection pattern, allowing
    algorithms like REINFORCE and PPO to use the same underlying model
    while exhibiting different training behaviors.
    """

    def inference(
        self,
        model: Model,
        data: SequenceSample,
        mb_spec: MicroBatchSpec,
    ) -> SequenceSample | None:
        raise NotImplementedError()

    def generate(
        self,
        model: Model,
        data: SequenceSample,
        mb_spec: MicroBatchSpec,
    ) -> SequenceSample | None:
        raise NotImplementedError()

    def train_step(
        self,
        model: Model,
        data: SequenceSample,
        mb_spec: MicroBatchSpec,
    ) -> Dict | List[Dict]:
        raise NotImplementedError()

When the dataflow is fixed, you typically only need to modify the algorithm interface file.

Note: We recommend using asynchronous RL so you can customize generation behavior by modifying your RL agent instead of the generate method.
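
To orient yourself, here is a minimal sketch of a custom interface. The registration call is an assumption based on how the built-in interfaces (e.g., "ppo_actor") are looked up by name; check realhf/api/core/model_api.py for the exact helper.

# A minimal sketch, not part of the codebase. Only train_step is implemented;
# the other methods keep their default NotImplementedError behavior.
from dataclasses import dataclass

import realhf.api.core.model_api as model_api


@dataclass
class MyCustomInterface(model_api.ModelInterface):
    some_hyperparam: float = 1.0  # hypothetical field, exposed via interface args

    def train_step(self, model, data, mb_spec):
        # Run forward/backward here and return a dict of training statistics.
        return {"some_stat": self.some_hyperparam}


# Assumed registration pattern, mirroring the built-in interfaces
model_api.register_interface("my_custom", MyCustomInterface)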

Example 1: Grouped Advantage Normalization#

Let’s replace PPO’s global advantage normalization with grouped normalization, as used in GRPO.

Understanding Data Organization#

Each batch contains multiple prompts (the batch size), and each prompt may have multiple sampled responses (the group size), so the total number of sequences is batch_size × group_size.

The sequences have different lengths but are packed into a single 1D tensor. As in flash-attention, cu_seqlens (cumulative sequence lengths) marks the sequence boundaries.
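
A standalone toy example of this packing (not AReaL code):

import torch

# Toy example: batch_size = 2 prompts, group_size = 2 responses per prompt,
# so 4 packed sequences with lengths 3, 5, 2, 4.
seqlens = torch.tensor([3, 5, 2, 4])
packed = torch.randn(int(seqlens.sum()))  # 1D packed tensor with 14 tokens

# cu_seqlens marks the boundary offsets, flash-attention style:
cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens, 0), (1, 0))
print(cu_seqlens)  # [0, 3, 8, 10, 14]

# Tokens of sequence i live in packed[cu_seqlens[i]:cu_seqlens[i + 1]]
seq_2 = packed[cu_seqlens[2]:cu_seqlens[3]]  # the 2 tokens of sequence 2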

Implementation#

The standard PPO normalization looks like:

@dataclass
class PPOActorInterface(ModelInterface):
    def train_step(self, model: Model, data: SequenceSample, mb_spec: MicroBatchSpec) -> Dict | List[Dict]:
        # ...
        if self.adv_norm:
            advantages = masked_normalization(advantages, loss_mask)
        # ...
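
Here masked_normalization whitens the advantages to zero mean and unit variance over the positions selected by the mask. A simplified sketch of its behavior (the actual realhf implementation additionally handles distributed reduction and numerical details):

import torch

def masked_normalization_sketch(x, mask, eps=1e-5):
    # Mean and variance are computed only over positions where mask is nonzero
    mask = mask.to(x.dtype)
    count = mask.sum().clamp(min=1.0)
    mean = (x * mask).sum() / count
    var = ((x - mean) ** 2 * mask).sum() / count
    return (x - mean) * torch.rsqrt(var + eps) * mask

In the grouped variant below, all_reduce=False is passed so that these statistics stay local to each group instead of being aggregated across data-parallel ranks.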

For grouped normalization, we partition the advantages by group and normalize each group independently:

@dataclass
class PPOActorInterface(ModelInterface):
    group_adv_norm: bool = False

    def train_step(self, model: Model, data: SequenceSample, mb_spec: MicroBatchSpec) -> Dict | List[Dict]:
        # ...
        if self.adv_norm:
            if not self.group_adv_norm:
                advantages = masked_normalization(advantages, loss_mask)
            else:
                n_samples = data.bs
                adv_list = []
                for i in range(0, n_samples, self.group_size):
                    # Group boundaries in the packed advantage tensor;
                    # short1cu_seqlens holds the cumulative lengths that
                    # index into `advantages`
                    s = short1cu_seqlens[i]
                    e = short1cu_seqlens[i + self.group_size]

                    # Extract advantages for this group
                    adv = advantages[s:e]
                    mask = loss_mask[s:e]

                    # Normalize within group
                    advn = masked_normalization(adv, mask, all_reduce=False)
                    adv_list.append(advn)

                advantages = torch.cat(adv_list, 0)
        # ...

Configuration Changes#

Update the experiment configuration to expose the new parameter:

@dataclasses.dataclass
class PPOMATHConfig(CommonExperimentConfig, PPOMATHExperimentOptions):
    group_adv_norm: bool = False

    @property
    def rpcs(self):
        # ...
        actor_interface = ModelInterfaceAbstraction(
            "ppo_actor",
            args={
                **copy.deepcopy(self.ppo_kwargs),
                "group_adv_norm": self.group_adv_norm,
                # ...
            },
        )
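
The args dict is what the framework ultimately uses to construct the interface, so each key must correspond to a field of the PPOActorInterface dataclass. Roughly (a simplified sketch, not the actual factory code), the resolution amounts to:

# Simplified sketch: the abstraction's name selects the registered interface
# class, and its args become the dataclass fields of the instance.
args = {"group_adv_norm": True}
interface = PPOActorInterface(**args)  # hypothetical direct construction
assert interface.group_adv_norm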

Example 2: Decoupled PPO Loss#

The decoupled PPO loss (introduced in AReaL’s paper) recomputes the token log-probabilities with the current model right before the mini-batch updates and uses them as the proximal policy π_prox:

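In symbols (a reconstruction consistent with the actor_loss_fn implementation shown below; see the paper for the exact formulation), writing π_behav for the policy that generated the data and π_prox for the recomputed proximal policy:

\mathcal{L}_{\text{decoupled}}(\theta)
  = -\,\mathbb{E}_t\!\left[
      \frac{\pi_{\text{prox}}(a_t \mid s_t)}{\pi_{\text{behav}}(a_t \mid s_t)}\,
      \min\!\Big(r_t(\theta)\,\hat{A}_t,\;
                 \operatorname{clip}\big(r_t(\theta),\,1-\varepsilon,\,1+\varepsilon\big)\,\hat{A}_t\Big)
    \right],
\qquad
r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\text{prox}}(a_t \mid s_t)}.

The leading ratio is the behavioral importance weight (behav_imp_weight in the code below), which can optionally be capped via behav_imp_weight_cap.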

Probability Recomputation#

We recompute the log-probabilities with the existing inference method; these become π_prox, while the log-probabilities recorded when the responses were generated remain the behavior policy π_behav:

@dataclass
class PPOActorInterface(ModelInterface):
    use_decoupled_loss: bool = False

    def train_step(self, model: Model, input_: SequenceSample, mb_spec: MicroBatchSpec) -> Dict | List[Dict]:
        if self.use_decoupled_loss:
            # Recompute log-probabilities with the current parameters; these
            # serve as the proximal policy
            s: SequenceSample = self.inference(model, input_, mb_spec)
            prox_logp = s.data["logprobs"]

        # ... (advantages, old_logp, kl_rewards, loss_mask, and input_lens
        # are computed here as in the standard PPO interface) ...

        # Prepare mini-batch data
        flat_data = dict(
            advantages=advantages,
            old_logp=old_logp,
            ppo_loss_mask=loss_mask,
            packed_input_ids=input_.data["packed_input_ids"],
            kl_rewards=kl_rewards,
        )

        if self.use_decoupled_loss:
            flat_data["prox_logp"] = prox_logp.float()

        flat_input = SequenceSample.from_default(
            ids=list(range(input_.bs * self.group_size)),
            data=flat_data,
            seqlens=[int(x) for x in input_lens.cpu().numpy().tolist()],
        )

        # Split into mini-batches and train. `spec`, `module`, and `_loss_fn`
        # come from the surrounding (elided) interface code.
        datas = flat_input.split_with_spec(spec)
        for mb_i, data in enumerate(datas):
            train_stat = module.train_batch(
                input_=data,
                mb_spec=mb_spec,
                version_steps=model.version.global_step,
                loss_fn=_loss_fn,
                loss_weight_fn=lambda x: x.data["ppo_loss_mask"].count_nonzero(),
                token_normalize_scope=self.token_normalize_scope,
            )

Modifying the Loss Function#

Update the loss computation to use the recomputed probabilities:

def _ppo_actor_loss_from_model_outputs(
    logits: torch.FloatTensor,  # [tot_seqlen, vocab_size]
    input_: SequenceSample,
    ...
) -> torch.Tensor:
    # ...
    # prox_logp is None when use_decoupled_loss is disabled
    prox_logp = input_.data.get("prox_logp")

    loss, ppo_stat = ppo_functional.actor_loss_fn(
        logprobs=logprobs,
        old_logprobs=old_logp,
        advantages=advantages,
        eps_clip=eps_clip,
        loss_mask=ppo_loss_mask,
        c_clip=c_clip,
        proximal_logprobs=prox_logp,
        behav_imp_weight_cap=behav_imp_weight_cap,
    )

And in the core loss function:

def actor_loss_fn(
    logprobs: torch.FloatTensor,
    old_logprobs: torch.FloatTensor,
    advantages: torch.FloatTensor,
    eps_clip: float,
    loss_mask: Optional[torch.BoolTensor] = None,
    c_clip: Optional[float] = None,
    proximal_logprobs: Optional[torch.FloatTensor] = None,
    behav_imp_weight_cap: Optional[torch.FloatTensor] = None,
) -> Tuple[torch.Tensor, Dict]:
    # Use proximal probabilities if available, otherwise use old probabilities
    denorm_logprobs = proximal_logprobs if proximal_logprobs is not None else old_logprobs

    loss_mask_count = loss_mask.count_nonzero() or 1

    # Importance ratio w.r.t. the denominator policy (pi_prox if provided,
    # otherwise the old/behavior policy)
    ratio = torch.where(loss_mask, torch.exp(logprobs - denorm_logprobs), 0)

    # Standard clipped surrogate objective (dual clipping via c_clip omitted here)
    clipped_ratio = torch.clamp(ratio, 1.0 - eps_clip, 1.0 + eps_clip)
    pg_loss1 = -advantages * ratio
    pg_loss2 = -advantages * clipped_ratio
    pg_loss = torch.max(pg_loss1, pg_loss2)

    # Apply the behavioral importance weight pi_prox / pi_behav for the
    # decoupled loss (optional capping via behav_imp_weight_cap is omitted here)
    if proximal_logprobs is not None:
        behav_kl = proximal_logprobs - old_logprobs
        behav_imp_weight = behav_kl.exp()
        behav_kl = torch.where(loss_mask, behav_kl, 0.0)
        behav_imp_weight = torch.where(loss_mask, behav_imp_weight, 0.0)
        pg_loss = pg_loss * behav_imp_weight

    # ...
    return pg_loss, stat
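
When proximal_logprobs coincides with old_logprobs (fully synchronous, on-policy training), behav_kl is zero and behav_imp_weight is one, so the decoupled loss reduces to the standard PPO loss.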

Configuration Update#

@dataclasses.dataclass
class PPOMATHConfig(CommonExperimentConfig, PPOMATHExperimentOptions):
    use_decoupled_loss: bool = False

    @property
    def rpcs(self):
        # ...
        actor_interface = ModelInterfaceAbstraction(
            "ppo_actor",
            args={
                **copy.deepcopy(self.ppo_kwargs),
                "use_decoupled_loss": self.use_decoupled_loss,
                # ...
            },
        )