Rollout and Agentic RL#

This guide provides an example of modifying the rollout behavior for PPO training.

In particular, we implement a multi-turn math agent trained with end-to-end RL. The agent repeatedly reasons about and attempts to solve a math problem until it reaches the correct answer or runs out of turns.

Define Your Agent#

Create a new file under realhf/impl/agent/, for example math_multi_turn_agent.py. Your agent must implement the interface defined in realhf/api/core/agent.py, which consists of a single method: collect_trajectory.

class MathMultiTurnAgent(Agent):
    
    async def collect_trajectory(
        self,
        prompt: SequenceSample,
        env: EnvironmentService,
        obs_queue: asyncio.Queue,
        act_queue: asyncio.Queue,
    ):
        ...
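Besides collect_trajectory, the agent's __init__ method receives whatever arguments you later pass through the experiment configuration (the args=dict(...) shown in the last section). The sketch below is illustrative only: the parameter names (gconfig, tokenizer_path, num_turns) are assumptions chosen to match the attributes used later in this guide, not the exact signature in the repository.

from transformers import AutoTokenizer

class MathMultiTurnAgent(Agent):

    def __init__(self, gconfig, tokenizer_path: str, num_turns: int = 5):
        # Generation config forwarded to the inference engine with each observation.
        self.gconfig = gconfig
        # Tokenizer used to encode feedback messages between turns.
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        # Maximum number of reasoning rounds per problem.
        self.num_turns = num_turns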

Implement the collect_trajectory Logic#

The collect_trajectory function takes a task prompt, an environment, and two queues as input, and produces one or more trajectories for the RL trainer. Within this function, you can implement arbitrary data-processing logic: put inputs for the inference engine into obs_queue, and extract actions from the generated tokens returned through act_queue.

In this example, the initial observation is the math problem itself. We put the token IDs and generation config into obs_queue and wait for the action produced by the inference engine from act_queue. After the inference engine returns, we extract the generated answers and send them to the environment.

for turn in range(self.num_turns):
    # Send the current token IDs and generation config to the inference engine.
    await obs_queue.put((qid, token_ids, self.gconfig))
    # Wait for the generated tokens (the action) produced by the inference engine.
    act: BundledGenerationOutputs = await act_queue.get()
    # Decode the generated answers from `act` (see the sketch below),
    # then let the environment judge them.
    _, success, *_ = await env.step((qid, answers))
    ...
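For reference, the elided extraction step could look like the sketch below. The field that holds the generated token IDs (output_ids here) is an assumption about BundledGenerationOutputs; adapt it to the actual attribute name.

# Illustrative only: decode the newly generated tokens into answer strings.
# `act.output_ids` is an assumed field name on BundledGenerationOutputs.
answers = [
    self.tokenizer.decode(output, skip_special_tokens=True)
    for output in act.output_ids
]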

The environment resembles a gym environment and defines two methods: reset and step. To maintain efficiency, however, we use an asynchronous implementation so that different environment instances do not block one another.
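If you need a custom environment, subclass EnvironmentService. The skeleton below is a minimal sketch; the method signatures are inferred from the gym-style description above and may not match the base class exactly.

class MyCustomEnv(EnvironmentService):

    async def reset(self):
        # Re-initialize any per-episode state before a new trajectory starts.
        return None

    async def step(self, action):
        # Return the gym-style 5-tuple: (obs, reward, terminated, truncated, info).
        return None, 0.0, True, False, {}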

The math environment is stateless and essentially serves as a wrapper around the reward function:

class MathCodeSingleStepEnv(EnvironmentService):

    async def step(self, action: Tuple[str, List[str]]):
        qid, answers = action
        ...
        # Run the blocking `math_verify_call` in a worker thread so that it
        # does not block the event loop shared by other environment instances.
        format_rewards = await asyncio.to_thread(
            math_verify_call,
            answers,
            ...
        )
        # (obs, reward, terminated, truncated, info), following the gym convention.
        return None, format_rewards, True, False, {}

After env.step returns the reward for the current step, we can check whether the answer is correct. If not, we can append a user prompt and send it to obs_queue again to enter the next round.

for turn in range(self.num_turns):
    ...
    # Build a natural-language feedback message based on the environment's verdict.
    if success[0]:
        feedback = "Congratulations! You are correct!"
    else:
        feedback = "Unfortunately your answer is wrong. Let's try again."

    # Wrap the feedback as a user turn, tokenize it, and append it to the
    # context so the next generation round can condition on it.
    feedback = "\n" + self.tokenizer.apply_chat_template(
        [dict(content=feedback, role="user")],
        add_generation_prompt=True,
        tokenize=False,
    )
    feedback = self.tokenizer(feedback)["input_ids"]
    token_ids.extend(feedback)
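Whether to keep prompting after a correct answer is a design choice. If you want the episode to end as soon as the environment reports success, a minimal sketch of that variant (simplified from the loop above) is:

for turn in range(self.num_turns):
    ...
    _, success, *_ = await env.step((qid, answers))
    if success[0]:
        # The answer is already correct; stop generating further turns.
        break
    # Otherwise append the feedback tokens (as above) and continue.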

Modify the Configuration#

You’re now close to running the end-to-end RL loop. The final step is to register and import your implementation, then modify the experiment configuration.

# in realhf/impl/agent/math_multi_turn_agent.py
register_agent("math-multi-turn", MathMultiTurnAgent)
# in realhf/impl/agent/__init__.py
import realhf.impl.agent.math_multi_turn_agent

In realhf/experiments/async_exp/async_ppo_math_exp.py:

@dataclasses.dataclass
class AsyncPPOMATHConfig(AsyncRLExperimentConfig, PPOMATHConfig):
+   # New CLI arguments are defined here
+   my_param: float = 1.0

    # in realhf/experiments/async_exp/async_ppo_math_exp.py
    @property
    def agent(self) -> AgentAbstraction:
        return AgentAbstraction(
-           "math-single-step",
+           "math-multi-turn",  # Your registered name
            args=dict(
-                ...
+                # Any configurations for your __init__ method
+                my_param=self.my_param,
            ),
        )

    @property
    def env(self) -> EnvServiceAbstraction:
-        return EnvServiceAbstraction(
-            "math-code-single-step", args=dict(dataset_path=self.dataset.path)
-        )
+        # Change to your customized environment if necessary
+        return EnvServiceAbstraction(
+            "my-env", args=dict(...)
+        )

Run Training#

Please follow the quickstart guide. Generally, start your experiments by running:

python3 training/main_async_ppo.py my_param=5.0  # and any additional CLI arguments

The training reward of our trial is shown below:

Happy coding!