
feat: Implement lazy data loading for Dataset#246

Merged
zhaochenyang20 merged 14 commits into radixark:main from Ratish1:data-loading
Dec 31, 2025

Conversation

@Ratish1
Contributor

@Ratish1 Ratish1 commented Nov 24, 2025

This PR improves the `Dataset` class in `miles/utils/data.py` to support lazy data loading, filtering, and indexed access for large datasets. It addresses memory-consumption issues in SFT workloads by avoiding in-memory materialization of entire datasets. @fzyzcjy #226

  • Modified `read_file` to use row-by-row iteration for JSONL and batched reading for Parquet.
  • Refactored `Dataset` to build a lightweight index of valid sample locations on initialization, applying `max_length` filtering during this phase.
  • Implemented `__getitem__` to read and process only the requested sample on demand.

@fzyzcjy
Collaborator

fzyzcjy commented Nov 24, 2025

Too busy now, I will try to squeeze some time and have a review later :)

@Ratish1
Contributor Author

Ratish1 commented Nov 24, 2025

No problem

@Ratish1
Contributor Author

Ratish1 commented Nov 25, 2025

Hello @fzyzcjy, I wanted to follow up on this PR, but if you are busy it's fine, you can look at it later. Thanks.

@fzyzcjy
Collaborator

fzyzcjy commented Nov 25, 2025

After an internal discussion, we think the most ideal code may be somewhat different. Therefore, we may need to wait a little bit.

@Ratish1
Contributor Author

Ratish1 commented Dec 7, 2025

Hello @fzyzcjy , I wanted to follow up on this. Thanks.

@lhoestq

lhoestq commented Dec 8, 2025

Hi ! Quentin from HF Datasets here :)

I was checking whether I could help with anything, in particular if you are interested in loading larger-than-RAM datasets. For example, the `datasets` library loads many formats like Parquet and JSON, makes tokenization easy, and works with larger-than-RAM datasets since the data is loaded from memory-mapped Arrow files on disk. It also has streaming features to load datasets bigger than disk.

@fzyzcjy
Collaborator

fzyzcjy commented Dec 9, 2025

@Ratish1 I am really too busy recently to check more details about this :/ @zhaochenyang20 do you want to find someone to review these?

@lhoestq looks great! Yes, it would be great to handle huge datasets here while keeping the code clean, because the framework supports SFT and can have big datasets.

@Ratish1
Contributor Author

Ratish1 commented Dec 9, 2025

No problem @fzyzcjy, maybe I can work with @lhoestq and use the HF `datasets` library here.

@zhaochenyang20
Collaborator

@Ratish1 There are merge conflicts right now.

Contributor

@PopSoda2002 PopSoda2002 left a comment


Hi, thanks for your effort and contribution! Nice work! Can I ask why you did not choose memmap in Python, and can you share some comparison of before and after?

@PopSoda2002
Contributor

Hi @Ratish1, are you working with `datasets` from HF? If so, that would be nice.

@Ratish1
Contributor Author

Ratish1 commented Dec 11, 2025

> Hi @Ratish1, are you working with `datasets` from HF? If so, that would be nice.

No I'm not. I will refactor this code with the datasets library instead now. Thanks.

@Ratish1 Ratish1 requested a review from PopSoda2002 December 11, 2025 10:40
Contributor

@PopSoda2002 PopSoda2002 left a comment


Hi, thanks for the effort, can you please show some results or comparison?

@Ratish1
Contributor Author

Ratish1 commented Dec 12, 2025

> Hi, thanks for the effort, can you please show some results or comparison?

Hi, I haven't run a comparison, since I think `datasets` would be much faster for large datasets, but I ran a memory benchmark using a 500 MB dummy JSONL file (5 million rows) to verify the lazy loading of the current implementation. This is what I got:

BENCHMARK RESULTS (Lazy Loading)
Initial Memory:   371.73 MiB
Peak Memory:      469.66 MiB
Memory Increment: 97.93 MiB
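The benchmark script itself is not included in the thread; a rough stdlib sketch of how the peak memory of a lazy row-by-row pass could be measured (sizes here are scaled way down from the 500 MB / 5-million-row file):

```python
import json
import os
import tempfile
import tracemalloc

# Build a small dummy JSONL file.
fd, path = tempfile.mkstemp(suffix=".jsonl")
os.close(fd)
with open(path, "w") as f:
    for i in range(10_000):
        f.write(json.dumps({"text": f"sample {i}"}) + "\n")


def iterate_lazily(p):
    # One row at a time: only the current line is ever held in memory.
    with open(p, "r") as f:
        for line in f:
            _ = json.loads(line)


tracemalloc.start()
iterate_lazily(path)
_, peak = tracemalloc.get_traced_memory()
tracemalloc.stop()
os.remove(path)
print(f"Peak traced memory during lazy pass: {peak / 1024:.1f} KiB")
```

With lazy iteration the peak stays roughly constant regardless of file size, which is the behavior the numbers above are meant to demonstrate.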

@Ratish1 Ratish1 requested a review from PopSoda2002 December 12, 2025 07:45
Contributor

@PopSoda2002 PopSoda2002 left a comment


It looks good to me, but I think it needs a more careful look!


@lhoestq lhoestq left a comment


I had a quick look and it looks good to me, I just had one small nit:

Comment on lines +190 to +193
if self.tool_key is not None and self.tool_key in data:
    tools = data[self.tool_key]
    if isinstance(tools, str):
        tools = json.loads(tools)
Collaborator


This is just a comment; it may not need to be done this time:

You are parsing JSON and building messages dynamically in every `__getitem__` call. While this saves RAM, it adds significant CPU overhead during the training loop.

If the JSON parsing is heavy, we might need to use hf_dataset.map() during init to pre-process these fields into a more efficient format (Arrow-native), rather than parsing raw strings on the fly.

Contributor Author

@Ratish1 Ratish1 Dec 30, 2025


Yes, this is correct. But I have not made this change yet: since our primary goal for this PR was resolving the RAM spike during initialization, I believe this lazy approach is the safest first step. If we find that JSON parsing becomes a bottleneck for GPU throughput in future benchmarks, we can definitely change it to a `.map()`-based pre-processing step to offload that work to the initialization phase.

@zhaochenyang20
Collaborator

Reproduce/Verification commands:

docker pull radixark/miles:latest

docker run --rm --gpus all --ipc=host --shm-size=16g \
  --ulimit memlock=-1 --ulimit stack=67108864 \
  -it radixark/miles:latest /bin/bash

rm -rf /root/miles

git clone -b data-loading https://github.com/Ratish1/miles.git

cd miles

pip install -e .

hf download zai-org/GLM-Z1-9B-0414 --local-dir /data/ratish/GLM-Z1-9B-0414

hf download --repo-type dataset zhuzilin/dapo-math-17k --local-dir /data/ratish/dapo-math-17k

source scripts/models/glm4-9B.sh

PYTHONPATH=/root/Megatron-LM python tools/convert_hf_to_torch_dist.py ${MODEL_ARGS[@]} --hf-checkpoint /data/ratish/GLM-Z1-9B-0414 --save /data/ratish/GLM-Z1-9B-0414_torch_dist

export CUDA_DEVICE_MAX_CONNECTIONS=1
export CUDA_VISIBLE_DEVICES=5,7

PYTHONPATH=/root/Megatron-LM python train.py \
    ${MODEL_ARGS[@]} \
    --hf-checkpoint /data/ratish/GLM-Z1-9B-0414 \
    --load /data/ratish/GLM-Z1-9B-0414_torch_dist \
    --prompt-data /data/ratish/dapo-math-17k/dapo-math-17k.jsonl \
    --input-key prompt \
    --label-key label \
    --apply-chat-template \
    --use-miles-router \
    --num-proc 4 \
    --rollout-batch-size 16 \
    --global-batch-size 16 \
    --n-samples-per-prompt 1 \
    --num-rollout 1 \
    --colocate \
    --sglang-mem-fraction-static 0.8

Contributor

@PopSoda2002 PopSoda2002 left a comment


LGTM! And I think adding a short doc explaining this a little would be better!

@zhaochenyang20 zhaochenyang20 merged commit 9a3b297 into radixark:main Dec 31, 2025
3 checks passed
@nanjiangwill

nanjiangwill commented Dec 31, 2025

@Ratish1 hi, I am wondering if lazy loading is still working during filtering: `ds = ds.filter(partial(_filter_func, **filter_kwargs), num_proc=num_proc, desc="Filtering invalid samples")` (i.e. no peak memory consumption)?

@zhaochenyang20
Collaborator

> @Ratish1 hi, I am wondering if lazy loading is still working during filtering: `ds = ds.filter(partial(_filter_func, **filter_kwargs), num_proc=num_proc, desc="Filtering invalid samples")` (i.e. no peak memory consumption)?

Hi, I've double-checked the underlying mechanism. ds.filter does indeed maintain the lazy loading behavior.

Under the hood, datasets processes the data in batches by reading them from the memory-mapped (mmap) storage into RAM. Once the filter function is executed, the Python objects are immediately released. It does not materialize the entire dataset in memory at once. Therefore, even during the filtering phase, memory consumption stays at a small constant level (determined by the batch_size and num_proc), avoiding any memory peaks that scale with the dataset size. This is exactly why integrating the datasets library is so beneficial for our use case.

zhaochenyang20 added a commit that referenced this pull request Dec 31, 2025
@zhaochenyang20
Collaborator

#372 (comment)

Sorry, my bad. This has been reverted. We need more unit tests for dataset consistency. 😭
