feat: Implement lazy data loading for Dataset #246
zhaochenyang20 merged 14 commits into radixark:main
Conversation
Too busy now, I will try to squeeze some time and have a review later :)
No problem
Hello @fzyzcjy, I wanted to follow up on this PR, but if you are busy that's fine, you can look at it later. Thanks.
After an internal discussion, we think the most ideal code may be somewhat different. Therefore, we may need to wait a little bit.
Hello @fzyzcjy, I wanted to follow up on this. Thanks.
Hi! Quentin from HF Datasets here :) I was checking if I could help with anything, in particular if you are interested in loading larger-than-RAM datasets. E.g. the
@Ratish1 I am really too busy recently to check more details about this :/ @zhaochenyang20 do you want to find someone to review these? @lhoestq looks great! Yes, it would be great to handle huge datasets here, and at the same time keep the code clean, because the framework supports SFT and can have big datasets.
@Ratish1 There are conflicts right now.
PopSoda2002 left a comment
Hi, thanks for your effort and contribution! Nice work! Can I ask why you did not choose memmap in Python, and can you share some comparison of before and after?
Hi @Ratish1, are you working with
No, I'm not. I will refactor this code with the
037abea to d5add95
PopSoda2002 left a comment
Hi, thanks for the effort, can you please show some results or a comparison?
Hi, I haven't compared since I think
PopSoda2002 left a comment
It looks good to me, but I think it needs a more careful look!
lhoestq left a comment
I had a quick look and it looks good to me, I just had one small nit:
```python
if self.tool_key is not None and self.tool_key in data:
    tools = data[self.tool_key]
    if isinstance(tools, str):
        tools = json.loads(tools)
```
This is a comment; it may not need to be done this time:
You are parsing JSON and building messages dynamically in every `__getitem__` call. While this saves RAM, it adds significant CPU overhead during the training loop.
If the JSON parsing is heavy, we might need to use `hf_dataset.map()` during init to pre-process these fields into a more efficient (Arrow-native) format, rather than parsing raw strings on the fly.
Yes, this is correct. But I have not made this change yet: since our primary goal for this PR was resolving the RAM spike during initialization, I believe this lazy approach is the safest first step. If we find that JSON parsing becomes a bottleneck for GPU throughput in future benchmarks, we can definitely change it to a `.map()`-based pre-processing step to offload that work to the initialization phase.
f2169a7 to 82065f2
Reproduce/verification commands:

```bash
docker pull radixark/miles:latest
docker run --rm --gpus all --ipc=host --shm-size=16g \
  --ulimit memlock=-1 --ulimit stack=67108864 \
  -it radixark/miles:latest /bin/bash
rm -rf /root/miles
git clone -b data-loading https://github.com/Ratish1/miles.git
cd miles
pip install -e .
hf download zai-org/GLM-Z1-9B-0414 --local-dir /data/ratish/GLM-Z1-9B-0414
hf download --repo-type dataset zhuzilin/dapo-math-17k --local-dir /data/ratish/dapo-math-17k
source scripts/models/glm4-9B.sh
PYTHONPATH=/root/Megatron-LM python tools/convert_hf_to_torch_dist.py ${MODEL_ARGS[@]} --hf-checkpoint /data/ratish/GLM-Z1-9B-0414 --save /data/ratish/GLM-Z1-9B-0414_torch_dist
export CUDA_DEVICE_MAX_CONNECTIONS=1
export CUDA_VISIBLE_DEVICES=5,7
PYTHONPATH=/root/Megatron-LM python train.py \
  ${MODEL_ARGS[@]} \
  --hf-checkpoint /data/ratish/GLM-Z1-9B-0414 \
  --load /data/ratish/GLM-Z1-9B-0414_torch_dist \
  --prompt-data /data/ratish/dapo-math-17k/dapo-math-17k.jsonl \
  --input-key prompt \
  --label-key label \
  --apply-chat-template \
  --use-miles-router \
  --num-proc 4 \
  --rollout-batch-size 16 \
  --global-batch-size 16 \
  --n-samples-per-prompt 1 \
  --num-rollout 1 \
  --colocate \
  --sglang-mem-fraction-static 0.8
```
PopSoda2002 left a comment
LGTM! And I think adding this to a doc with a short explanation would be even better!
@Ratish1 Hi, I am wondering if lazy loading still works during filtering.
Hi, I've double-checked the underlying mechanism. `ds.filter` does indeed maintain the lazy loading behavior. Under the hood, `datasets` processes the data in batches, reading them from memory-mapped (mmap) storage into RAM. Once the filter function has executed, the Python objects are immediately released; it does not materialize the entire dataset in memory at once. Therefore, even during the filtering phase, memory consumption stays at a small constant level (determined by `batch_size` and `num_proc`), avoiding any memory peaks that scale with the dataset size. This is exactly why integrating the `datasets` library is so beneficial for our use case.
This reverts commit 9a3b297.
Sorry, my bad. This has been reverted. We need more unit tests for dataset consistency. 😭
This PR improves the `Dataset` class in `miles/utils/data.py` to support lazy data loading, filtering, and indexed access for large datasets. This addresses memory consumption issues in SFT workloads by avoiding the in-memory materialization of entire datasets. @fzyzcjy #226
- `read_file`: use row-by-row iteration for JSONL and batched reading for Parquet.
- `Dataset`: build a lightweight index of valid sample locations on initialization, applying `max_length` filtering during this phase.
- `__getitem__`: read and process only the requested sample on demand.
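A minimal sketch of the index-then-seek pattern described above (hypothetical names; the real implementation lives in `miles/utils/data.py`, and its `max_length` filter operates on tokenized lengths rather than the character-length check used here for simplicity):

```python
# Sketch: build a lightweight index of byte offsets for valid JSONL rows at
# init time, then seek and parse a single row on demand in __getitem__.
# Class name and the "prompt" field are assumptions for illustration.
import json

class LazyJsonlDataset:
    def __init__(self, path, max_length=None):
        self.path = path
        self.offsets = []  # byte offset of each valid sample's line
        with open(path, "rb") as f:
            offset = 0
            for line in f:  # row-by-row: only one line in memory at a time
                row = json.loads(line)
                # Illustrative filter: character length, not token length.
                if max_length is None or len(row.get("prompt", "")) <= max_length:
                    self.offsets.append(offset)
                offset += len(line)

    def __len__(self):
        return len(self.offsets)

    def __getitem__(self, idx):
        # Read and parse only the requested sample.
        with open(self.path, "rb") as f:
            f.seek(self.offsets[idx])
            return json.loads(f.readline())
```

Peak memory during init is one row plus the offset list, and each access touches only one line of the file.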