Skip to content
1 change: 1 addition & 0 deletions miles/ray/rollout_data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def __init__(self, args):
apply_chat_template=args.apply_chat_template,
apply_chat_template_kwargs=args.apply_chat_template_kwargs,
seed=args.rollout_seed,
num_proc=args.num_proc,
)
if self.args.rollout_shuffle:
self.dataset.shuffle(self.epoch_id)
Expand Down
1 change: 1 addition & 0 deletions miles/rollout/data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def __init__(self, args):
apply_chat_template=args.apply_chat_template,
apply_chat_template_kwargs=args.apply_chat_template_kwargs,
seed=args.rollout_seed,
num_proc=args.num_proc,
)
if self.args.rollout_shuffle:
self.dataset.shuffle(self.epoch_id)
Expand Down
1 change: 1 addition & 0 deletions miles/rollout/sglang_rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,7 @@ async def eval_rollout_single_dataset(
tool_key=dataset_cfg.tool_key,
apply_chat_template=args.apply_chat_template,
apply_chat_template_kwargs=args.apply_chat_template_kwargs,
num_proc=args.num_proc,
)
dataset = EVAL_PROMPT_DATASET[cache_key]

Expand Down
6 changes: 6 additions & 0 deletions miles/utils/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,6 +577,12 @@ def add_data_arguments(parser):
"and should be set to a larger value than `max_tokens_per_gpu` if you want better performance. "
),
)
parser.add_argument(
"--num-proc",
type=int,
default=8,
help="Number of processes for dataset initialization and filtering.",
)
return parser

def add_eval_arguments(parser):
Expand Down
144 changes: 106 additions & 38 deletions miles/utils/data.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import json
import logging
import os
import random
import re
from functools import partial

import datasets
import numpy as np
import pandas as pd
import ray
Expand All @@ -16,6 +17,16 @@

logger = logging.getLogger(__name__)

# Maps a dataset file extension (including the leading dot, as produced by
# os.path.splitext) to the format name that is passed as the first argument
# to datasets.load_dataset(). Extensions not listed here are rejected with a
# ValueError by the lookup helper that reads this table.
_FILE_TYPE_MAP = {
".jsonl": "json",
".parquet": "parquet",
}


def _filter_func(example, tokenizer, processor, max_length, prompt_key, multimodal_keys, apply_chat_template_kwargs):
    """Row predicate for dataset filtering: return True when ``example`` is kept.

    The prompt is built from ``example`` with ``_build_messages`` and then
    checked with ``_should_skip_prompt``; a row is kept exactly when that
    check says it should NOT be skipped.

    Every argument other than ``example`` is forwarded unchanged to those two
    helpers, so callers are expected to pre-bind them (e.g. with
    ``functools.partial``) before handing the function to ``Dataset.filter``.
    """
    messages = _build_messages(example, prompt_key, multimodal_keys)
    should_skip = _should_skip_prompt(messages, tokenizer, processor, max_length, apply_chat_template_kwargs)
    return not should_skip


# TODO: don't read the whole file into memory.
def read_file(path):
Expand Down Expand Up @@ -124,53 +135,110 @@ def __init__(
seed=42,
apply_chat_template=False,
apply_chat_template_kwargs=None,
num_proc=8,
):
self.origin_samples = []
for data in read_file(path):
prompt = _build_messages(data, prompt_key, multimodal_keys)

metadata = data.get(metadata_key) or {}
if tool_key is not None and tool_key in data:
tools = data[tool_key]
if isinstance(tools, str):
tools = json.loads(tools)
elif isinstance(tools, np.ndarray):
tools = tools.tolist()
assert isinstance(tools, list), f"tools must be a list, got {type(tools)} instead"
metadata["tools"] = tools

# TODO: this is slow.
if _should_skip_prompt(prompt, tokenizer, processor, max_length, apply_chat_template_kwargs):
continue
# 1. Store basic config
self.tokenizer = tokenizer
self.processor = processor
self.max_length = max_length
self.prompt_key = prompt_key
self.multimodal_keys = multimodal_keys
self.label_key = label_key
self.tool_key = tool_key
self.metadata_key = metadata_key
self.apply_chat_template_kwargs = apply_chat_template_kwargs or {}
self.seed = seed
self.epoch_id = -1

self.origin_samples.append(
Sample(
prompt=prompt,
label=data[label_key] if label_key is not None else None,
metadata=metadata,
)
)
# 2. Load and process dataset
self.hf_dataset = self._load_and_filter_dataset(path, num_proc)
self.origin_hf_dataset = self.hf_dataset

self.epoch_id = -1
self.seed = seed
self.samples = self.origin_samples
def _get_file_type(self, path: str) -> str:
    """Return the ``datasets.load_dataset`` format name for ``path``.

    The file extension is looked up in ``_FILE_TYPE_MAP``. The lookup is
    case-insensitive so that e.g. ``train.PARQUET`` or ``data.JSONL`` are
    accepted the same way as their lowercase spellings.

    Args:
        path: Path to the dataset file; only its extension is inspected.

    Returns:
        The format name registered for the extension (e.g. ``"json"``).

    Raises:
        ValueError: If the extension is missing or not in ``_FILE_TYPE_MAP``.
    """
    _, ext = os.path.splitext(path)

    try:
        # Normalize only for the lookup; the error below reports the
        # extension exactly as the caller spelled it.
        return _FILE_TYPE_MAP[ext.lower()]
    except KeyError:
        # `from None`: the KeyError is an implementation detail of the lookup.
        raise ValueError(f"Unsupported format: {ext}. Supported: {list(_FILE_TYPE_MAP)}") from None

def _load_and_filter_dataset(self, path, num_proc):
    """Load the prompt dataset at ``path`` and drop rows whose prompt must be skipped.

    ``path`` is first split by ``_parse_generalized_path`` into a real file
    path and an optional row slice. The file is loaded through
    ``datasets.load_dataset`` (format chosen by ``_get_file_type``), the row
    slice is applied when present, and the result is filtered with
    ``_filter_func`` using the tokenizer/processor settings stored on
    ``self``.

    Args:
        path: Dataset path, in whatever form ``_parse_generalized_path`` accepts.
        num_proc: Forwarded to ``Dataset.filter`` as its ``num_proc`` argument.

    Returns:
        The filtered Hugging Face dataset.

    Raises:
        FileNotFoundError: If the parsed file path does not exist.
        ValueError: Propagated from ``_get_file_type`` for unsupported extensions.
    """
    raw_file_path, row_slice = _parse_generalized_path(path)
    if not os.path.exists(raw_file_path):
        raise FileNotFoundError(f"Prompt dataset path '{raw_file_path}' does not exist.")

    logger.info(f"Loading dataset from {raw_file_path} using Hugging Face datasets.")

    # Load through the datasets library rather than reading the file by hand.
    dataset = datasets.load_dataset(
        self._get_file_type(raw_file_path),
        data_files=raw_file_path,
        split="train",
    )

    # Restrict to the requested rows before the (expensive) filtering pass.
    if row_slice:
        selected_rows = range(len(dataset))[row_slice]
        dataset = dataset.select(selected_rows)
        logger.info(f"Applied slice {row_slice}, dataset size: {len(dataset)}")

    # Bind everything except the row itself so the predicate takes one argument.
    keep_sample = partial(
        _filter_func,
        tokenizer=self.tokenizer,
        processor=self.processor,
        max_length=self.max_length,
        prompt_key=self.prompt_key,
        multimodal_keys=self.multimodal_keys,
        apply_chat_template_kwargs=self.apply_chat_template_kwargs,
    )

    size_before = len(dataset)
    dataset = dataset.filter(keep_sample, num_proc=num_proc, desc="Filtering invalid samples")
    size_after = len(dataset)
    logger.info(f"Filtered dataset from {size_before} to {size_after} samples.")

    return dataset

def __len__(self):
    """Return the number of rows in the currently active ``hf_dataset``."""
    active_dataset = self.hf_dataset
    return len(active_dataset)

def __getitem__(self, idx):
# The underlying HF dataset handles lazy fetching
data = self.hf_dataset[idx]

# Process the data using existing logic
prompt = _build_messages(data, self.prompt_key, self.multimodal_keys)

metadata = data.get(self.metadata_key) or {}
if self.tool_key is not None and self.tool_key in data:
tools = data[self.tool_key]
if isinstance(tools, str):
tools = json.loads(tools)
Comment on lines +213 to +216
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a non-blocking comment; it may not need to be addressed in this PR:

You are parsing JSON and building messages dynamically in every `__getitem__` call. While this saves RAM, it adds significant CPU overhead during the training loop.

If the JSON parsing is heavy, we might need to use hf_dataset.map() during init to pre-process these fields into a more efficient format (Arrow-native), rather than parsing raw strings on the fly.

Copy link
Contributor Author

@Ratish1 Ratish1 Dec 30, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this is correct. I have not made this change yet because the primary goal of this PR was to resolve the RAM spike during initialization, and I believe this lazy approach is the safest first step. If we find that JSON parsing becomes a bottleneck for GPU throughput in future benchmarks, we can definitely switch to a `.map()`-based pre-processing step to offload that work to the initialization phase.

# TODO (chenyang): If the JSON parsing is heavy, we might need
# to use hf_dataset.map() during init to pre-process these
# fields into a more efficient format (Arrow-native), rather
# than parsing raw strings on the fly.
elif isinstance(tools, np.ndarray):
tools = tools.tolist()
assert isinstance(tools, list), f"tools must be a list, got {type(tools)} instead"
metadata["tools"] = tools

sample = Sample(
prompt=prompt,
label=data.get(self.label_key) if self.label_key is not None else None,
metadata=metadata,
)

return sample

def shuffle(self, new_epoch_id):
    """Reshuffle the dataset for epoch ``new_epoch_id``.

    Does nothing when the dataset is already shuffled for that epoch.
    Otherwise ``hf_dataset`` is replaced by a shuffled view of
    ``origin_hf_dataset`` seeded with ``self.seed + new_epoch_id``, and
    ``epoch_id`` is updated.

    The former list-based shuffling (reseeding the global ``random`` module
    and permuting ``self.samples``) is intentionally gone: the data now lives
    in the Hugging Face dataset, so that path only produced a stale copy and
    mutated global RNG state as a side effect.

    Args:
        new_epoch_id: Epoch identifier; combined with ``self.seed`` to derive
            the shuffle seed.
    """
    if self.epoch_id == new_epoch_id:
        return

    epoch_seed = self.seed + new_epoch_id
    logger.info(f"Shuffling dataset for epoch {new_epoch_id} with seed {epoch_seed}")
    # Always shuffle from the original dataset, not from the previous epoch's
    # order, so the result depends only on (seed, epoch).
    self.hf_dataset = self.origin_hf_dataset.shuffle(seed=epoch_seed)
    self.epoch_id = new_epoch_id

def __getitem__(self, idx):
return self.samples[idx]

def __len__(self):
return len(self.samples)


def get_minimum_num_micro_batch_size(total_lengths, max_tokens_per_gpu):
# use first fit to get the number of micro batches
Expand Down