Skip to content

Support updating grain data mixture during training#2697

Closed
aireenmei wants to merge 6 commits into
mainfrom
aireen/grain_mix
Closed

Support updating grain data mixture during training#2697
aireenmei wants to merge 6 commits into
mainfrom
aireen/grain_mix

Conversation

@aireenmei

@aireenmei aireenmei commented Nov 15, 2025

Copy link
Copy Markdown
Collaborator

Description

See the added "9." in data_input_grain.md about this new feature.

FIXES: b/454051801

Tests

  1. config the new mixture in grain_mixture.json
  2. test script test_grain_mix.sh, training log
  3. Inspect the checkpoints under gs://aireenmei-multipod/test/grain_ckpt/gemini-test/grain-mixture-test-2025-11-15-06-44-13/checkpoints

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@github-actions

Copy link
Copy Markdown

🤖 Hi @aireenmei, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

@github-actions github-actions Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

📋 Review Summary

This Pull Request introduces a new feature allowing dynamic updates to Grain data mixtures during training, which is well-documented and implemented. The associated changes involve refactoring data input validation and modifying checkpointing logic to support this new functionality.

🔍 General Feedback

  • The refactoring of data input validation into types.py is a good improvement for code organization and maintainability.
  • The new functions for merging iterator states in checkpointing.py correctly handle the complexity of updating data sources while resuming training.
  • The documentation for the new feature in data_input_grain.md is clear and provides good examples.
  • Consider refactoring duplicated code for applying dataset transformations in _grain_data_processing.py into a helper function to improve maintainability (as noted in an inline comment).

@@ -64,15 +65,21 @@ def get_datasets(
dataset_list = [

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 The logic for applying shuffle, repeat, and sharding is duplicated for both mixed and single datasets. Consider extracting this into a helper function to improve maintainability and reduce redundancy.

Suggested change
dataset_list = [
def _apply_dataset_transforms(dataset, shuffle, shuffle_seed, num_epoch, dataloading_host_index, dataloading_host_count):
if shuffle:
dataset = dataset.shuffle(seed=shuffle_seed)
dataset = dataset.repeat(num_epoch)
dataset = dataset[dataloading_host_index::dataloading_host_count]
return dataset.to_iter_dataset()
# In get_datasets, for mixed datasets:
for d, _ in enumerate(dataset_list):
dataset_list[d] = _apply_dataset_transforms(
dataset_list[d],
shuffle,
shuffle_seed,
num_epoch,
dataloading_host_index,
dataloading_host_count,
)
dataset = grain.IterDataset.mix(dataset_list, weights)
else:
data_files = find_data_files(data_file_pattern)
dataset = grain.MapDataset.source(grain.ArrayRecordDataSource(data_files))
dataset = _apply_dataset_transforms(
dataset,
shuffle,
shuffle_seed,
num_epoch,
dataloading_host_index,
dataloading_host_count,
)


When you are ready to introduce the new dataset, create a JSON configuration file (e.g., `grain_mixture.json`). This file defines both the original mixture and the new one, along with the desired weights for blending them.

**Important:** The `old_dataset` section in this file must exactly match the sources and weights you used in `grain_train_files` for the initial run.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this will not allow removing datasets, right? is that ok?

)
old_weight = mixture_config["old_dataset_weight"]
new_weight = mixture_config["new_dataset_weight"]
train_ds = grain.IterDataset.mix([old_dataset, new_dataset], weights=[old_weight, new_weight])

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this will work too, but now that you're using IterDataset.mix, checkpoint will have per-component checkpoint, so I think we could do some surgery there and recover each component separately. That would allow completely changing weights (e.g. changing and removing old weights as well). Is that something you'd be interested in? if so we can discuss further

@aireenmei aireenmei closed this Nov 27, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants