-
Notifications
You must be signed in to change notification settings - Fork 123
feat: Implement lazy data loading for Dataset #246
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
Changes from 11 commits
21fa4d0
2266c4b
d5add95
65c58c4
1b803c2
f7a013f
5ebfce4
82065f2
eaecb71
0ee4037
9b3c607
e5872be
bd80ae9
89c0b4e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,9 +1,9 @@ | ||
| import json | ||
| import logging | ||
| import os | ||
| import random | ||
| import re | ||
|
|
||
| import datasets | ||
| import numpy as np | ||
| import pandas as pd | ||
| import ray | ||
|
|
@@ -16,6 +16,11 @@ | |
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
| _FILE_TYPE_MAP = { | ||
| ".jsonl": "json", | ||
| ".parquet": "parquet", | ||
| } | ||
|
|
||
|
|
||
| # TODO: don't read the whole file into memory. | ||
| def read_file(path): | ||
|
|
@@ -124,53 +129,103 @@ def __init__( | |
| seed=42, | ||
| apply_chat_template=False, | ||
| apply_chat_template_kwargs=None, | ||
| num_proc=8, | ||
| ): | ||
| self.origin_samples = [] | ||
| for data in read_file(path): | ||
| prompt = _build_messages(data, prompt_key, multimodal_keys) | ||
|
|
||
| metadata = data.get(metadata_key) or {} | ||
| if tool_key is not None and tool_key in data: | ||
| tools = data[tool_key] | ||
| if isinstance(tools, str): | ||
| tools = json.loads(tools) | ||
| elif isinstance(tools, np.ndarray): | ||
| tools = tools.tolist() | ||
| assert isinstance(tools, list), f"tools must be a list, got {type(tools)} instead" | ||
| metadata["tools"] = tools | ||
|
|
||
| # TODO: this is slow. | ||
| if _should_skip_prompt(prompt, tokenizer, processor, max_length, apply_chat_template_kwargs): | ||
| continue | ||
| # 1. Store basic config | ||
| self.tokenizer = tokenizer | ||
| self.processor = processor | ||
| self.max_length = max_length | ||
| self.prompt_key = prompt_key | ||
| self.multimodal_keys = multimodal_keys | ||
| self.label_key = label_key | ||
| self.tool_key = tool_key | ||
| self.metadata_key = metadata_key | ||
| self.apply_chat_template_kwargs = apply_chat_template_kwargs or {} | ||
| self.seed = seed | ||
| self.epoch_id = -1 | ||
|
|
||
| self.origin_samples.append( | ||
| Sample( | ||
| prompt=prompt, | ||
| label=data[label_key] if label_key is not None else None, | ||
| metadata=metadata, | ||
| ) | ||
| # 2. Load and process dataset | ||
| self.hf_dataset = self._load_and_filter_dataset(path, num_proc) | ||
| self.origin_hf_dataset = self.hf_dataset | ||
|
|
||
| def _get_file_type(self, path: str) -> str: | ||
| _, ext = os.path.splitext(path) | ||
|
|
||
| if ext not in _FILE_TYPE_MAP: | ||
| supported = list(_FILE_TYPE_MAP.keys()) | ||
| raise ValueError(f"Unsupported file format: {ext}. Supported: {supported}") | ||
|
|
||
| return _FILE_TYPE_MAP[ext] | ||
|
|
||
| def _load_and_filter_dataset(self, path, num_proc): | ||
| raw_file_path, row_slice = _parse_generalized_path(path) | ||
|
|
||
| if not os.path.exists(raw_file_path): | ||
| raise FileNotFoundError(f"Prompt dataset path '{raw_file_path}' does not exist.") | ||
|
|
||
| logger.info(f"Loading dataset from {raw_file_path} using Hugging Face datasets.") | ||
|
|
||
| # Determine file type and load using datasets library for memory-mapped access | ||
| file_type = self._get_file_type(raw_file_path) | ||
| ds = datasets.load_dataset(file_type, data_files=raw_file_path, split="train") | ||
|
|
||
| # Apply row slicing if specified | ||
| if row_slice: | ||
| num_rows = len(ds) | ||
| indices = range(num_rows)[row_slice] | ||
| ds = ds.select(indices) | ||
| logger.info(f"Applied slice {row_slice}, dataset size: {len(ds)}") | ||
|
|
||
| # Apply filtering using the existing helper functions | ||
| def filter_func(example): | ||
| prompt = _build_messages(example, self.prompt_key, self.multimodal_keys) | ||
| return not _should_skip_prompt( | ||
| prompt, self.tokenizer, self.processor, self.max_length, self.apply_chat_template_kwargs | ||
| ) | ||
|
|
||
| self.epoch_id = -1 | ||
| self.seed = seed | ||
| self.samples = self.origin_samples | ||
| original_size = len(ds) | ||
| ds = ds.filter(filter_func, num_proc=num_proc, desc="Filtering invalid samples during init") | ||
| new_size = len(ds) | ||
| logger.info(f"Filtered dataset from {original_size} to {new_size} samples.") | ||
|
|
||
| return ds | ||
|
|
||
| def __len__(self): | ||
| return len(self.hf_dataset) | ||
|
|
||
| def __getitem__(self, idx): | ||
| # The underlying HF dataset handles lazy fetching | ||
| data = self.hf_dataset[idx] | ||
|
|
||
| # Process the data using existing logic | ||
| prompt = _build_messages(data, self.prompt_key, self.multimodal_keys) | ||
|
|
||
| metadata = data.get(self.metadata_key) or {} | ||
| if self.tool_key is not None and self.tool_key in data: | ||
| tools = data[self.tool_key] | ||
| if isinstance(tools, str): | ||
| tools = json.loads(tools) | ||
|
Comment on lines
+213
to
+216
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is a comment and may not need to be addressed in this PR: you are parsing JSON and building messages dynamically in every __getitem__ call. While this saves RAM, it adds significant CPU overhead during the training loop. If the JSON parsing is heavy, we might need to use hf_dataset.map() during init to pre-process these fields into a more efficient format (Arrow-native), rather than parsing raw strings on the fly.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yes, this is correct, but I have not made that change yet. Since our primary goal for this PR was resolving the RAM spike during initialization, I believe this lazy approach is the safest first step. If we find that JSON parsing becomes a bottleneck for GPU throughput in future benchmarks, we can definitely change it to a .map()-based pre-processing step to offload that work to the initialization phase. |
||
| elif isinstance(tools, np.ndarray): | ||
| tools = tools.tolist() | ||
| assert isinstance(tools, list), f"tools must be a list, got {type(tools)} instead" | ||
| metadata["tools"] = tools | ||
|
|
||
| sample = Sample( | ||
| prompt=prompt, | ||
| label=data.get(self.label_key) if self.label_key is not None else None, | ||
| metadata=metadata, | ||
| ) | ||
|
|
||
| return sample | ||
|
|
||
| def shuffle(self, new_epoch_id): | ||
| if self.epoch_id == new_epoch_id: | ||
| return | ||
|
|
||
| random.seed(self.seed + new_epoch_id) | ||
| permutation = list(range(len(self.samples))) | ||
| random.shuffle(permutation) | ||
| self.samples = [self.origin_samples[i] for i in permutation] | ||
| logger.info(f"Shuffling dataset for epoch {new_epoch_id} with seed {self.seed + new_epoch_id}") | ||
| self.hf_dataset = self.origin_hf_dataset.shuffle(seed=self.seed + new_epoch_id) | ||
| self.epoch_id = new_epoch_id | ||
|
|
||
| def __getitem__(self, idx): | ||
| return self.samples[idx] | ||
|
|
||
| def __len__(self): | ||
| return len(self.samples) | ||
|
|
||
|
|
||
| def get_minimum_num_micro_batch_size(total_lengths, max_tokens_per_gpu): | ||
| # use first fit to get the number of micro batches | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.