Training Algorithm

An algorithm is encapsulated in a ModelInterface, which primarily defines three methods:

# in realhf/api/core/model_api.py
class ModelInterface(abc.ABC):
    """An interface for model training, inference, and generation.

    This interface is designed to follow the dependency injection pattern.
    We pass the model to the interface and call its methods, ensuring that model APIs
    and algorithms are fully decoupled. For example, REINFORCE and PPO can exhibit
    different behaviors during training. Separate interfaces can be written for these
    algorithms while using the same model that provides basic forward-backward-update
    functionality (i.e., :class:`PipelinableEngine`).

    During runtime, the master worker requests model workers to execute a specific
    interface type (e.g., generate) on a specific model. The model worker locates
    the corresponding model, passes it into the requested interface, performs the
    computation, and returns the result.
    """

    def inference(
        self,
        model: Model,
        data: SequenceSample,
        mb_spec: MicroBatchSpec,
    ) -> SequenceSample | None:
        raise NotImplementedError()

    def generate(
        self,
        model: Model,
        data: SequenceSample,
        mb_spec: MicroBatchSpec,
    ) -> SequenceSample | None:
        raise NotImplementedError()

    def train_step(
        self,
        model: Model,
        data: SequenceSample,
        mb_spec: MicroBatchSpec,
    ) -> Dict | List[Dict]:
        raise NotImplementedError()

When the dataflow is fixed, it is usually sufficient to modify the file that defines the algorithm interface, or to add a new one.
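
To give a sense of what such a file contains, here is a purely illustrative sketch of a custom interface; the class name, the reward_scaling field, and the exact import paths are assumptions for illustration, not realhf code:

# A hypothetical custom algorithm interface (illustrative sketch only).
import dataclasses
from typing import Dict, List

from realhf.api.core.model_api import Model, ModelInterface
# SequenceSample and MicroBatchSpec live in realhf's core API modules;
# the import path below is indicative.
from realhf.api.core.data_api import MicroBatchSpec, SequenceSample


@dataclasses.dataclass
class MyREINFORCEInterface(ModelInterface):
    # Algorithm hyperparameters become dataclass fields, just like the
    # group_adv_norm flag added later in this section.
    reward_scaling: float = 1.0

    def train_step(
        self,
        model: Model,
        data: SequenceSample,
        mb_spec: MicroBatchSpec,
    ) -> Dict | List[Dict]:
        # Run forward/backward/update with the injected model and
        # return a dict of training statistics.
        ...
        return {"actor_loss": 0.0}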

We provide two examples: (1) changing PPO’s global advantage normalization to the grouped normalization used in GRPO, and (2) replacing the original PPO loss with the decoupled PPO loss from AReaL’s paper.

Note: We recommend using asynchronous RL, so that you can customize the generation behavior by modifying your RL agent and do not need to modify the generate method of model interfaces.

Grouped Advantage Normalization

The PPO algorithm is implemented in a single file, ppo_interface.py. The method we are going to modify is train_step in PPOActorInterface. PPO’s global advantage normalization looks like this:

@dataclass
class PPOActorInterface(ModelInterface):
    def train_step(
        self,
        model: Model,
        data: SequenceSample,
        mb_spec: MicroBatchSpec,
    ) -> Dict | List[Dict]:
        ...
        if self.adv_norm:
            advantages = masked_normalization(advantages, loss_mask)
        ...

An Additional Note on Data Management

We need to explain how data in each batch is organized.

Usually, each data batch (i.e., the data variable) contains multiple prompts; the number of prompts is called the “batch size”. Additionally, each prompt may have multiple corresponding answers, and the number of answers per prompt is called the “group size”. Therefore, there are batch_size × group_size sequences in each batch.

These sequences have different lengths, but they are concatenated (or packed) together into a single 1D tensor. Sequences that share the same prompt are stored contiguously (the “group”), and the groups of different prompts are laid out one after another. As in flash-attention, we use cu_seqlens to mark the sequence boundaries: cu_seqlens is the cumulative sum of sequence lengths across the batch, starting from zero.

Each token in a sequence has a corresponding reward and advantage, so advantages is also a packed 1D tensor, just like the tokens (i.e., packed_input_ids). However, each “sequence” of advantages is one step shorter than its token sequence due to the auto-regressive nature of LLMs: the loss can only be computed on every token except the first one in each sequence.
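
As a concrete, purely illustrative example of this layout (the numbers are made up; short1cu_seqlens mirrors the variable name used in the implementation below):

import torch

# 2 prompts (batch_size = 2), 2 answers each (group_size = 2) -> 4 packed sequences.
seqlens = torch.tensor([5, 7, 4, 6])  # token counts, grouped by prompt
cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens, 0), (1, 0))
# cu_seqlens == tensor([0, 5, 12, 16, 22]); the tokens of sequence i are
# packed_input_ids[cu_seqlens[i]: cu_seqlens[i + 1]]

# Advantages are one step shorter per sequence, so they use their own offsets.
short1cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens - 1, 0), (1, 0))
# short1cu_seqlens == tensor([0, 4, 10, 13, 18])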

Implementation

For grouped advantage normalization, we need to partition the advantages into groups and run normalization within the tensor chunk of each group:

@dataclass
class PPOActorInterface(ModelInterface):
+   group_adv_norm: bool = False

    def train_step(
        self,
        model: Model,
        data: SequenceSample,
        mb_spec: MicroBatchSpec,
    ) -> Dict | List[Dict]:
        ...
        if self.adv_norm:
-           advantages = masked_normalization(advantages, loss_mask)
+           if not self.group_adv_norm:
+               advantages = masked_normalization(advantages, loss_mask)
+           else:
+               n_samples = data.bs
+               adv_list = []
+               for i in range(0, n_samples, self.group_size):
+                   # Start and end of the chunk
+                   s = short1cu_seqlens[i]
+                   e = short1cu_seqlens[i + self.group_size]
+                   # Get advantages within each group of the same prompt
+                   adv = advantages[s: e]
+                   mask = loss_mask[s: e]
+                   # Run normalization
+                   advn = masked_normalization(adv, mask, all_reduce=False)
+                   adv_list.append(advn)
+               advantages = torch.cat(adv_list, 0)
        ...
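
To see what this changes, here is a self-contained toy comparison of global and grouped normalization written with plain PyTorch; it mimics, but does not call, realhf’s masked_normalization helper, and reuses the packed layout from the example above:

import torch

def toy_masked_normalization(x, mask, eps=1e-5):
    # Zero-mean / unit-std normalization over the unmasked positions only.
    mask = mask.float()
    mean = (x * mask).sum() / mask.sum()
    var = ((x - mean) ** 2 * mask).sum() / mask.sum()
    return torch.where(mask.bool(), (x - mean) * torch.rsqrt(var + eps), x)

advantages = torch.randn(18)                    # packed advantages of 4 sequences
loss_mask = torch.ones(18, dtype=torch.bool)
short1cu_seqlens = torch.tensor([0, 4, 10, 13, 18])
group_size = 2

# Global normalization: one mean/std over the whole batch.
global_adv = toy_masked_normalization(advantages, loss_mask)

# Grouped normalization: one mean/std per group of answers to the same prompt.
chunks = []
for i in range(0, 4, group_size):
    s, e = int(short1cu_seqlens[i]), int(short1cu_seqlens[i + group_size])
    chunks.append(toy_masked_normalization(advantages[s:e], loss_mask[s:e]))
grouped_adv = torch.cat(chunks, 0)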

Modify Your Experiment Configuration

To make our new group_adv_norm argument configurable from the CLI, we make the following changes to PPOMATHConfig in realhf/experiments/common/ppo_math_exp.py:

@dataclasses.dataclass
class PPOMATHConfig(CommonExperimentConfig, PPOMATHExperimentOptions):
+   group_adv_norm: bool = False

    @property
    def rpcs(self):
        ...
        # interfaces
        actor_interface = ModelInterfaceAbstraction(
            "ppo_actor",
            args={
                **copy.deepcopy(self.ppo_kwargs),
+               "group_adv_norm": self.group_adv_norm,
                ...
            },
        )
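
The args dict of ModelInterfaceAbstraction is used to construct the interface, so each key is forwarded to the corresponding dataclass field of PPOActorInterface; this is how the group_adv_norm field defined above receives its value at runtime.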

The Decoupled PPO Loss

The decoupled PPO loss separates the behavior policy π_behav, which generated the rollout data, from the proximal policy π_prox, which serves as the reference for clipping:

J(θ) = E[ (π_prox(a|s) / π_behav(a|s)) · min( r(θ) Â, clip(r(θ), 1 − ε, 1 + ε) Â ) ],  where r(θ) = π_θ(a|s) / π_prox(a|s).

As described in AReaL’s paper, we implement this loss by recomputing the token probabilities right before the mini-batched updates and using them as π_prox in the loss above.
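
The toy snippet below (with made-up numbers, not realhf code) spells out how the three sets of log-probabilities enter this loss; the variable names mirror those used in the implementation that follows:

import torch

logprobs = torch.tensor([-1.0, -1.1, -1.5])           # log pi_theta (policy being trained)
old_logprobs = torch.tensor([-1.2, -0.8, -2.0])       # log pi_behav (generated the rollout)
proximal_logprobs = torch.tensor([-1.1, -0.9, -1.8])  # log pi_prox (recomputed before updates)
advantages = torch.tensor([0.5, -0.3, 1.0])
eps_clip = 0.2

# Clipping is performed against pi_prox instead of pi_behav ...
ratio = torch.exp(logprobs - proximal_logprobs)
clipped = torch.clamp(ratio, 1 - eps_clip, 1 + eps_clip)
pg_loss = -torch.minimum(ratio * advantages, clipped * advantages)

# ... and the importance weight pi_prox / pi_behav reweights each token.
behav_imp_weight = torch.exp(proximal_logprobs - old_logprobs)
loss = (pg_loss * behav_imp_weight).mean()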

Probability Recomputation

Recomputation involves a single forward pass, which has already been implemented by PPOActorInterface.inference. We need to call this method in the train_step method:

@dataclass
class PPOActorInterface(ModelInterface):
+   use_decoupled_loss: bool = False

    def train_step(
        self,
        model: Model,
        data: SequenceSample,
        mb_spec: MicroBatchSpec,
    ) -> Dict | List[Dict]:
+       if self.use_decoupled_loss:
+           s: SequenceSample = self.inference(model, data, mb_spec)
+           prox_logp = s.data["logprobs"]
        ...

Next, we need to pass prox_logp to the loss computation:

@dataclass
class PPOActorInterface(ModelInterface):
    ...

    def train_step(
        self,
        model: Model,
        data: SequenceSample,
        mb_spec: MicroBatchSpec,
    ) -> Dict | List[Dict]:
        # Prepare data to be split into mini-batches.
        flat_data = dict(
            advantages=advantages,
            old_logp=old_logp,
            ppo_loss_mask=loss_mask,
            packed_input_ids=input_.data["packed_input_ids"],
            kl_rewards=kl_rewards,
        )
+       if self.use_decoupled_loss:
+           flat_data["prox_logp"] = prox_logp.float()

        flat_input = SequenceSample.from_default(
            ids=list(range(input_.bs * self.group_size)),
            data=flat_data,
            seqlens=[int(x) for x in input_lens.cpu().numpy().tolist()],
        )
        ...
        datas = flat_input.split_with_spec(spec)
        ...
        for mb_i, data in enumerate(datas):
            train_stat = module.train_batch(
                input_=data,
                mb_spec=mb_spec,
                version_steps=model.version.global_step,
                loss_fn=_loss_fn,
                loss_weight_fn=lambda x: x.data[
                    "ppo_loss_mask"
                ].count_nonzero(),
                token_normalize_scope=self.token_normalize_scope,
            )

The flat_input variable is divided into mini-batches, and each mini-batch is passed into the train_batch method to run distributed training. All data stored in this SequenceSample object is passed into _loss_fn. In this case, _loss_fn is a wrapper over _ppo_actor_loss_from_model_outputs:

def _ppo_actor_loss_from_model_outputs(
    logits: torch.FloatTensor,  # [tot_seqlen, vocab_size]
    input_: SequenceSample,
    ...
) -> torch.Tensor:
    ...
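
For orientation, such a wrapper is typically just a partial application that binds the hyperparameters, so that train_batch only has to supply logits and input_. A minimal sketch with a hypothetical stand-in function (the actual wrapping code in ppo_interface.py may differ):

import functools

def example_loss_from_model_outputs(logits, input_, *, eps_clip, c_clip=None):
    # Stand-in for _ppo_actor_loss_from_model_outputs: hyperparameters are bound
    # via functools.partial, while logits and input_ come from train_batch.
    ...

_loss_fn = functools.partial(example_loss_from_model_outputs, eps_clip=0.2, c_clip=None)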

logits is the output of the model’s forward pass, and input_ is exactly the input_ we passed into train_batch, so we can now retrieve prox_logp via:

def _ppo_actor_loss_from_model_outputs(
    logits: torch.FloatTensor,  # [tot_seqlen, vocab_size]
    input_: SequenceSample,
    ...
) -> torch.Tensor:
    ...
+   prox_logp = input_.data["prox_logp"]
    loss, ppo_stat = ppo_functional.actor_loss_fn(
        logprobs=logprobs,
        old_logprobs=old_logp,
        advantages=advantages,
        eps_clip=eps_clip,
        loss_mask=ppo_loss_mask,
        c_clip=c_clip,
+       proximal_logprobs=prox_logp,
        behav_imp_weight_cap=behav_imp_weight_cap,
    )

We have now recomputed the probabilities and passed them into the loss function. Next, we need to revise the loss computation code.

Modifying the PPO Loss

def actor_loss_fn(
    logprobs: torch.FloatTensor,
    old_logprobs: torch.FloatTensor,
    advantages: torch.FloatTensor,
    eps_clip: float,
    loss_mask: Optional[torch.BoolTensor] = None,
    c_clip: Optional[float] = None,
+    proximal_logprobs: Optional[torch.FloatTensor] = None,
    behav_imp_weight_cap: Optional[torch.FloatTensor] = None,
) -> Tuple[torch.Tensor, Dict]:
    ...
+    if proximal_logprobs is not None:
+        denorm_logprobs = proximal_logprobs
+    else:
+        denorm_logprobs = old_logprobs
    ...
    loss_mask_count = loss_mask.count_nonzero() or 1
    # For numerical stability.
-    ratio = torch.where(loss_mask, torch.exp(logprobs - old_logprobs), 0)
+    ratio = torch.where(loss_mask, torch.exp(logprobs - denorm_logprobs), 0)
    ...
+    if proximal_logprobs is not None:
+        behav_kl = proximal_logprobs - old_logprobs
+        behav_imp_weight = behav_kl.exp()
+        behav_kl = torch.where(loss_mask, behav_kl, 0.0)
+        behav_imp_weight = torch.where(loss_mask, behav_imp_weight, 0.0)
+        pg_loss = pg_loss * behav_imp_weight
    ...
    return pg_loss, stat
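
With these changes, clipping is applied to the ratio π_θ/π_prox, while the extra factor exp(proximal_logprobs − old_logprobs) = π_prox/π_behav reweights tokens generated by the (possibly stale) behavior policy, matching the decoupled objective shown above.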

Modify the Experiment Configuration

@dataclasses.dataclass
class PPOMATHConfig(CommonExperimentConfig, PPOMATHExperimentOptions):
+   use_decoupled_loss: bool = False

    @property
    def rpcs(self):
        ...
        # interfaces
        actor_interface = ModelInterfaceAbstraction(
            "ppo_actor",
            args={
                **copy.deepcopy(self.ppo_kwargs),
+               "use_decoupled_loss": self.use_decoupled_loss,
                ...
            },
        )