Skip to content
33 changes: 33 additions & 0 deletions docs/en/get_started/quick_start.md
Original file line number Diff line number Diff line change
Expand Up @@ -580,3 +580,36 @@ miles has been deeply optimized for distributed training of large-scale Mixture
- [Example: 64xH100 Training GLM-4.5](../examples/glm4.5-355B-A32B.md)
- [Example: 128xH100 Training DeepSeek-R1](../examples/deepseek-r1.md)
- The scripts such as `scripts/run_qwen3_30b_a3b.py`, `scripts/run_glm45_355b_a32b.py` also support multi-node training, though there is currently little documentation about it.

## Verification of Dataset Initialization

Once your environment is set up and weights are converted, you can verify that the system is correctly using **Lazy Data Loading** to handle large datasets efficiently.

Run the following command to start a short verification task:

```bash
python train.py \
${MODEL_ARGS[@]} \
--hf-checkpoint /root/GLM-Z1-9B-0414 \
--load /root/GLM-Z1-9B-0414_torch_dist \
--prompt-data /root/dapo-math-17k/dapo-math-17k.jsonl \
--input-key prompt \
--label-key label \
--apply-chat-template \
--rm-type deepscaler \
--use-miles-router \
--num-proc 4 \
--rollout-batch-size 16 \
--global-batch-size 16 \
--n-samples-per-prompt 1 \
--num-rollout 1 \
--colocate
```

### What to Observe

1. **Fast Startup**: Even with a large dataset, the training process should start almost immediately. This confirms the data is being memory-mapped (Lazy Loading) instead of read entirely into RAM.
2. **Filtering Progress**: You will see a progress bar titled `Filtering invalid samples during init`. This confirms that the filtering logic is running in parallel and providing visual feedback.
3. **Configurable Parallelism**: By specifying `--num-proc 4`, we override the default value. You can verify that the system spawns exactly 4 worker processes for the data preparation phase.
4. **RAM Stability**: Monitor your system RAM during startup. It should remain stable because only the dataset indices are stored in memory, not the raw text.

1 change: 1 addition & 0 deletions miles/ray/rollout_data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def __init__(self, args):
apply_chat_template=args.apply_chat_template,
apply_chat_template_kwargs=args.apply_chat_template_kwargs,
seed=args.rollout_seed,
num_proc=args.num_proc,
)
if self.args.rollout_shuffle:
self.dataset.shuffle(self.epoch_id)
Expand Down
1 change: 1 addition & 0 deletions miles/rollout/data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def __init__(self, args):
apply_chat_template=args.apply_chat_template,
apply_chat_template_kwargs=args.apply_chat_template_kwargs,
seed=args.rollout_seed,
num_proc=args.num_proc,
)
if self.args.rollout_shuffle:
self.dataset.shuffle(self.epoch_id)
Expand Down
1 change: 1 addition & 0 deletions miles/rollout/sglang_rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,7 @@ async def eval_rollout_single_dataset(
tool_key=dataset_cfg.tool_key,
apply_chat_template=args.apply_chat_template,
apply_chat_template_kwargs=args.apply_chat_template_kwargs,
num_proc=args.num_proc,
)
dataset = EVAL_PROMPT_DATASET[cache_key]

Expand Down
6 changes: 6 additions & 0 deletions miles/utils/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,6 +577,12 @@ def add_data_arguments(parser):
"and should be set to a larger value than `max_tokens_per_gpu` if you want better performance. "
),
)
parser.add_argument(
"--num-proc",
type=int,
default=8,
help="Number of processes for dataset initialization and filtering.",
)
return parser

def add_eval_arguments(parser):
Expand Down
129 changes: 92 additions & 37 deletions miles/utils/data.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import json
import logging
import os
import random
import re

import datasets
import numpy as np
import pandas as pd
import ray
Expand All @@ -16,6 +16,11 @@

logger = logging.getLogger(__name__)

# Maps a supported dataset file extension to the builder name expected by
# `datasets.load_dataset` for memory-mapped loading.
_FILE_TYPE_MAP = {
    ".jsonl": "json",
    ".parquet": "parquet",
}


# TODO: don't read the whole file into memory.
def read_file(path):
Expand Down Expand Up @@ -124,53 +129,103 @@ def __init__(
seed=42,
apply_chat_template=False,
apply_chat_template_kwargs=None,
num_proc=8,
):
self.origin_samples = []
for data in read_file(path):
prompt = _build_messages(data, prompt_key, multimodal_keys)

metadata = data.get(metadata_key) or {}
if tool_key is not None and tool_key in data:
tools = data[tool_key]
if isinstance(tools, str):
tools = json.loads(tools)
elif isinstance(tools, np.ndarray):
tools = tools.tolist()
assert isinstance(tools, list), f"tools must be a list, got {type(tools)} instead"
metadata["tools"] = tools

# TODO: this is slow.
if _should_skip_prompt(prompt, tokenizer, processor, max_length, apply_chat_template_kwargs):
continue
# 1. Store basic config
self.tokenizer = tokenizer
self.processor = processor
self.max_length = max_length
self.prompt_key = prompt_key
self.multimodal_keys = multimodal_keys
self.label_key = label_key
self.tool_key = tool_key
self.metadata_key = metadata_key
self.apply_chat_template_kwargs = apply_chat_template_kwargs or {}
self.seed = seed
self.epoch_id = -1

self.origin_samples.append(
Sample(
prompt=prompt,
label=data[label_key] if label_key is not None else None,
metadata=metadata,
)
# 2. Load and process dataset
self.hf_dataset = self._load_and_filter_dataset(path, num_proc)
self.origin_hf_dataset = self.hf_dataset

def _get_file_type(self, path: str) -> str:
    """Return the `datasets.load_dataset` builder name for *path*.

    Args:
        path: dataset file path; only the extension is inspected.

    Returns:
        The builder name registered in `_FILE_TYPE_MAP` (e.g. "json").

    Raises:
        ValueError: if the extension is not a supported dataset format.
    """
    extension = os.path.splitext(path)[1]
    if extension in _FILE_TYPE_MAP:
        return _FILE_TYPE_MAP[extension]

    supported = list(_FILE_TYPE_MAP.keys())
    raise ValueError(f"Unsupported file format: {extension}. Supported: {supported}")

def _load_and_filter_dataset(self, path, num_proc):
    """Memory-map the prompt dataset and drop samples with invalid prompts.

    Loads the file through Hugging Face `datasets` so only Arrow indices live
    in RAM (lazy loading), applies the optional row slice encoded in *path*,
    then runs prompt validation in parallel.

    Args:
        path: dataset path, optionally carrying a row-slice suffix that
            `_parse_generalized_path` splits off.
        num_proc: number of worker processes for the parallel filter pass.

    Returns:
        The filtered `datasets.Dataset`.

    Raises:
        FileNotFoundError: if the underlying file does not exist.
        ValueError: if the file extension is unsupported.
    """
    file_path, row_slice = _parse_generalized_path(path)

    if not os.path.exists(file_path):
        raise FileNotFoundError(f"Prompt dataset path '{file_path}' does not exist.")

    logger.info(f"Loading dataset from {file_path} using Hugging Face datasets.")

    # `datasets` memory-maps the file instead of reading it all into RAM.
    dataset = datasets.load_dataset(
        self._get_file_type(file_path), data_files=file_path, split="train"
    )

    if row_slice:
        selected = range(len(dataset))[row_slice]
        dataset = dataset.select(selected)
        logger.info(f"Applied slice {row_slice}, dataset size: {len(dataset)}")

    def keep_sample(example):
        # Keep a row only if its rendered prompt passes validation.
        messages = _build_messages(example, self.prompt_key, self.multimodal_keys)
        return not _should_skip_prompt(
            messages, self.tokenizer, self.processor, self.max_length, self.apply_chat_template_kwargs
        )

    size_before = len(dataset)
    dataset = dataset.filter(keep_sample, num_proc=num_proc, desc="Filtering invalid samples during init")
    logger.info(f"Filtered dataset from {size_before} to {len(dataset)} samples.")

    return dataset

def __len__(self):
    """Return the number of samples remaining after init-time filtering."""
    return len(self.hf_dataset)

def __getitem__(self, idx):
    """Build a `Sample` for row *idx*, fetched lazily from the HF dataset.

    The raw row is decoded on every access (message building, JSON tool
    parsing), trading per-item CPU for a flat RAM footprint during init.
    NOTE(review): if this decoding shows up in profiles, pre-process the
    fields with `hf_dataset.map()` at init instead.

    Returns:
        A `Sample` with the built prompt, the label (when `label_key` is
        set), and metadata optionally carrying a normalized "tools" list.
    """
    # The underlying HF dataset handles lazy fetching
    data = self.hf_dataset[idx]

    # Process the data using existing logic
    prompt = _build_messages(data, self.prompt_key, self.multimodal_keys)

    metadata = data.get(self.metadata_key) or {}
    if self.tool_key is not None and self.tool_key in data:
        tools = data[self.tool_key]
        # Normalize tools to a plain list: JSON string or numpy array inputs
        # both appear depending on the source file format.
        if isinstance(tools, str):
            tools = json.loads(tools)
        elif isinstance(tools, np.ndarray):
            tools = tools.tolist()
        assert isinstance(tools, list), f"tools must be a list, got {type(tools)} instead"
        metadata["tools"] = tools

    sample = Sample(
        prompt=prompt,
        label=data.get(self.label_key) if self.label_key is not None else None,
        metadata=metadata,
    )

    return sample

def shuffle(self, new_epoch_id):
    """Reshuffle the dataset for a new epoch; no-op if already on that epoch.

    Always shuffles the pristine `origin_hf_dataset` (not the previously
    shuffled view) with a seed derived from the base seed plus the epoch id,
    so the permutation is reproducible and identical across workers.
    """
    if self.epoch_id == new_epoch_id:
        return

    logger.info(f"Shuffling dataset for epoch {new_epoch_id} with seed {self.seed + new_epoch_id}")
    self.hf_dataset = self.origin_hf_dataset.shuffle(seed=self.seed + new_epoch_id)
    self.epoch_id = new_epoch_id


def get_minimum_num_micro_batch_size(total_lengths, max_tokens_per_gpu):
# use first fit to get the number of micro batches
Expand Down