Support updating grain data mixture during training #2697
📋 Review Summary
This Pull Request introduces a new feature allowing dynamic updates to Grain data mixtures during training, which is well-documented and implemented. The associated changes involve refactoring data input validation and modifying checkpointing logic to support this new functionality.
🔍 General Feedback
- The refactoring of data input validation into `types.py` is a good improvement for code organization and maintainability.
- The new functions for merging iterator states in `checkpointing.py` correctly handle the complexity of updating data sources while resuming training.
- The documentation for the new feature in `data_input_grain.md` is clear and provides good examples.
- Consider refactoring duplicated code for applying dataset transformations in `_grain_data_processing.py` into a helper function to improve maintainability (as noted in an inline comment).
@@ -64,15 +65,21 @@ def get_datasets(
    dataset_list = [
🟡 The logic for applying shuffle, repeat, and sharding is duplicated for both mixed and single datasets. Consider extracting this into a helper function to improve maintainability and reduce redundancy.
Suggested change:

    def _apply_dataset_transforms(
        dataset, shuffle, shuffle_seed, num_epoch, dataloading_host_index, dataloading_host_count
    ):
      if shuffle:
        dataset = dataset.shuffle(seed=shuffle_seed)
      dataset = dataset.repeat(num_epoch)
      dataset = dataset[dataloading_host_index::dataloading_host_count]
      return dataset.to_iter_dataset()

    # In get_datasets, for mixed datasets (body of the existing branch):
      for d, _ in enumerate(dataset_list):
        dataset_list[d] = _apply_dataset_transforms(
            dataset_list[d],
            shuffle,
            shuffle_seed,
            num_epoch,
            dataloading_host_index,
            dataloading_host_count,
        )
      dataset = grain.IterDataset.mix(dataset_list, weights)
    else:
      data_files = find_data_files(data_file_pattern)
      dataset = grain.MapDataset.source(grain.ArrayRecordDataSource(data_files))
      dataset = _apply_dataset_transforms(
          dataset,
          shuffle,
          shuffle_seed,
          num_epoch,
          dataloading_host_index,
          dataloading_host_count,
      )
When you are ready to introduce the new dataset, create a JSON configuration file (e.g., `grain_mixture.json`). This file defines both the original mixture and the new one, along with the desired weights for blending them.

**Important:** The `old_dataset` section in this file must exactly match the sources and weights you used in `grain_train_files` for the initial run.
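To make the "must exactly match" requirement concrete, here is a minimal sketch of a pre-flight check. Only the key names `old_dataset`, `old_dataset_weight`, and `new_dataset_weight` appear in this PR; the `new_dataset` key, the `{pattern: weight}` layout, the file paths, and the helper `old_dataset_mismatches` are assumptions made for illustration and may not match the actual schema described in `data_input_grain.md`.

```python
import json

# Hypothetical contents of grain_mixture.json. The {pattern: weight} layout,
# the "new_dataset" key and all paths are assumptions, not the PR's schema.
EXAMPLE_MIXTURE_JSON = """
{
  "old_dataset": {"/data/a/*.array_record": 0.3, "/data/b/*.array_record": 0.7},
  "new_dataset": {"/data/c/*.array_record": 1.0},
  "old_dataset_weight": 0.6,
  "new_dataset_weight": 0.4
}
"""


def old_dataset_mismatches(mixture_config, initial_sources, tol=1e-9):
    """Return human-readable differences between old_dataset and the initial run.

    initial_sources maps each source pattern from the first run to its weight.
    An empty list means the two agree on both sources and weights.
    """
    old = mixture_config["old_dataset"]
    problems = []
    for pattern in sorted(set(old) | set(initial_sources)):
        if pattern not in old:
            problems.append(f"missing from old_dataset: {pattern}")
        elif pattern not in initial_sources:
            problems.append(f"not used in the initial run: {pattern}")
        elif abs(old[pattern] - initial_sources[pattern]) > tol:
            problems.append(
                f"weight differs for {pattern}: "
                f"{old[pattern]} vs {initial_sources[pattern]}"
            )
    return problems


mixture_config = json.loads(EXAMPLE_MIXTURE_JSON)
initial_sources = {"/data/a/*.array_record": 0.3, "/data/b/*.array_record": 0.7}
print(old_dataset_mismatches(mixture_config, initial_sources))  # prints []
```

This sketch compares sources and weights only. If the order of sources also matters when restoring the iterator state, an ordered comparison would be needed instead of a dict-based one.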
This will not allow removing datasets, right? Is that OK?
    )
    old_weight = mixture_config["old_dataset_weight"]
    new_weight = mixture_config["new_dataset_weight"]
    train_ds = grain.IterDataset.mix([old_dataset, new_dataset], weights=[old_weight, new_weight])
This will work too, but now that you're using `IterDataset.mix`, the checkpoint will have a checkpoint per component, so I think we could do some surgery there and recover each component separately. That would allow changing the weights completely (e.g. changing or removing old weights as well). Is that something you'd be interested in? If so, we can discuss further.
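The limitation raised in this thread can be seen with a little arithmetic. The sketch below assumes that mix weights are treated as relative proportions (normalized to sum to 1) at each level, and uses a hypothetical helper `effective_weights` with made-up source names. It computes the share each source gets when an old mixture is nested inside a two-way old/new mix.

```python
def _normalize(weights):
    """Scale a {name: weight} mapping so the values sum to 1."""
    total = sum(weights.values())
    return {name: value / total for name, value in weights.items()}


def effective_weights(old_sources, new_sources, old_weight, new_weight):
    """Per-source sampling share for mix([mix(old), mix(new)], [old_weight, new_weight]).

    Assumes each level of mixing treats its weights as relative proportions.
    """
    outer = _normalize({"old": old_weight, "new": new_weight})
    result = {}
    for name, share in _normalize(old_sources).items():
        result[name] = outer["old"] * share
    for name, share in _normalize(new_sources).items():
        # A source may appear in both mixtures; its shares add up.
        result[name] = result.get(name, 0.0) + outer["new"] * share
    return result


weights = effective_weights(
    old_sources={"a": 0.3, "b": 0.7},
    new_sources={"c": 1.0},
    old_weight=0.6,
    new_weight=0.4,
)
# a gets about 0.18, b about 0.42, c about 0.40 of the samples.
print(weights)
```

Whatever values are chosen for `old_weight` and `new_weight`, the ratio between `a` and `b` stays at 3:7, and neither can drop to zero unless the whole old mixture does. Changing or removing individual old sources would need the per-component recovery discussed above.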
Description
See the added "9." in data_input_grain.md about this new feature.
FIXES: b/454051801
Tests
Checklist
Before submitting this PR, please make sure (put X in square brackets):
`gemini-review` label.