Dataset#
AReaL directly integrates with the Dataset class from the HuggingFace datasets
package. This gives you full flexibility to load, process, and filter your data before
training.
The required column names (keys) and data format depend on the specific implementation
of the agent workflow (for online reinforcement learning) or the training engines (for
offline training, such as LMEngine for Supervised Fine-Tuning (SFT)).
Here are two concrete examples from the existing implementation:
SFT (Offline Training)#
In the SFT example, we see that the loaded data is directly passed to the train_lm
method:
# areal/trainer/sft_trainer.py
for global_step in range(start_step, max_steps):
batch = self._load_bcast_from(data_generator)
self.actor.train_lm(batch)
In this case, the train_lm method requires the keys “input_ids”, “attention_mask”, and
“loss_mask” to function. We first tokenize the dataset to extract the “input_ids” and
“loss_mask”. Then, the pad_sequences_to_tensors method is used to batch multiple
sequences and append the “attention_mask”:
# areal/dataset/gsm8k.py
def get_gsm8k_sft_dataset(
path: str,
split: str,
tokenizer,
max_length: int | None = None,
):
dataset = load_dataset(path=path, name="main", split=split)
def process(sample):
seq_token = tokenizer.encode(
sample["question"] + sample["answer"] + tokenizer.eos_token
)
prompt_token = tokenizer.encode(sample["question"])
loss_mask = [0] * len(prompt_token) + [1] * (len(seq_token) - len(prompt_token))
return {"input_ids": seq_token, "loss_mask": loss_mask}
dataset = dataset.map(process).remove_columns(["question", "answer"])
if max_length is not None:
# Filter out sequences longer than max_length
dataset = dataset.filter(lambda x: len(x["input_ids"]) <= max_length)
return dataset
GRPO (Online Training)#
In the GRPO example, the loaded data is first used for inference rather than training:
# areal/trainer/rl_trainer.py
self.train_dataloader = self._create_dataloader(
train_dataset,
dataset_config=self.config.train_dataset,
rank=self.actor.data_parallel_rank,
world_size=self.actor.data_parallel_world_size,
)
for global_step in range(start_step, max_steps):
rollout_batch = self.actor.prepare_batch(
self.train_dataloader,
workflow=workflow,
workflow_kwargs=workflow_kwargs,
should_accept_fn=dynamic_filter_fn,
group_size=config.gconfig.n_samples,
dynamic_bs=self.config.dynamic_bs,
)
Note that the collate_fn here is an identity function, meaning it simply returns the
list of individual data items as a batch. In prepare_batch, the data is then
dispatched to multiple concurrent executions of workflows, where each dispatched data
corresponds to a single episode.
In the following sections, we take
RLVRWorkflow
as an example. Agent workflows have the same pattern of using input data. You are free
to modify the customized dataset to include any keys as long as they accord with your
workflow implementation.
The RLVRWorkflow implementation extracts the “messages” field from the data dictionary
as the prompt for generating a response. Additionally, this data is passed to the
reward_fn as keyword arguments, which allows the reward function to make use of other
dataset fields, like “answers”. Here’s an example:
# areal/workflow/rlvr.py
class RLVRWorkflow(RolloutWorkflow):
async def arun_episode(self, engine: InferenceEngine, data):
input_ids = self.tokenizer.apply_chat_template(
data["messages"],
tokenize=True,
add_generation_prompt=True,
enable_thinking=self.enable_thinking,
)
req = ModelRequest(
input_ids=input_ids,
...
)
...
reward = self.reward_fn(
prompt=prompt_str,
completions=completions_str,
prompt_ids=resp.input_tokens,
completion_ids=resp.output_tokens,
**data,
)
Thus, the “messages” field must be constructed when loading the dataset, and the reward function should be defined to handle the dataset’s specific fields. Here’s how you can process the dataset for this example:
from datasets import load_dataset
def get_gsm8k_rl_dataset(
path: str,
split: str,
tokenizer,
max_length: int | None = None,
):
dataset = load_dataset(path=path, name="main", split=split)
def process(sample):
messages = [
{
"role": "user",
"content": sample["question"]
+ "\nPlease put your final answer within \\boxed{}.",
}
]
return {"messages": messages}
dataset = dataset.map(process).remove_columns(["question"])
# Filter out sequences longer than max_length if tokenizer and max_length are provided
if max_length is not None:
def filter_length(sample):
# Tokenize the user content to check length
content = sample["messages"][0]["content"]
tokens = tokenizer.encode(content)
return len(tokens) <= max_length
dataset = dataset.filter(filter_length)
return dataset