Configure ce target gb #365
Conversation
Summary of Changes (Gemini Code Assist): This pull request enhances the flexibility and consistency of the chunking mechanisms for the cross-entropy loss and the tiled MLP.
Code Review
This pull request introduces environment variables to configure chunking in the cross-entropy loss function and updates the chunking logic for the tiled MLP in arctic models. The changes for the cross-entropy loss look good, though there's a small opportunity for code clarification. However, the updated chunking logic in tiled_mlp.py appears to have a bug where it doesn't account for the batch size, which could lead to unbalanced processing chunks. I've provided suggestions to address these points.
```diff
  B, S, H = x.shape
- n_shards = int(max(1, min(S, math.ceil(S / max(1, H)))))
+ chunk_size = max(1, H)
+ n_shards, remainder = divmod(S, chunk_size)
```
The calculation of n_shards here only considers the sequence length S and not the batch size B. Since TiledMLP chunks the flattened tensor of size B*S, this will result in unbalanced chunks when B > 1. The last chunk will process a much larger amount of data than the preceding ones.
To ensure chunks are balanced, n_shards should be calculated based on the total number of elements B*S.
```diff
- n_shards, remainder = divmod(S, chunk_size)
+ n_shards, remainder = divmod(B * S, chunk_size)
```
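For illustration, here is a minimal standalone sketch (hypothetical shapes, not the repo's actual TiledMLP code) showing that deriving n_shards from B * S keeps every chunk bounded by chunk_size:

```python
# Standalone sketch: derive n_shards from the flattened length B * S,
# so every chunk is at most chunk_size rows.
B, S, H = 8, 4096, 4096
chunk_size = max(1, H)

n_shards, remainder = divmod(B * S, chunk_size)   # 8 full shards, remainder 0
sizes = [chunk_size] * n_shards + ([remainder] if remainder else [])
assert all(size <= chunk_size for size in sizes)
print(n_shards, sizes)   # 8 [4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096]
```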
```python
if hasattr(scaling, "get_scale"): scaling = scaling.get_scale()
if target_gb is None and 'n_chunks' not in kwargs:
    target_gb = os.environ.get("UNSLOTH_CE_LOSS_TARGET_GB", None)
    n_chunks = kwargs.get("n_chunks", os.environ.get("UNSLOTH_CE_LOSS_N_CHUNKS", None))
```
This line is a bit confusing. Since this code block is only executed when 'n_chunks' not in kwargs, the kwargs.get("n_chunks", ...) will always fall back to its default value. You can simplify this by directly getting the value from the environment variable.
```diff
- n_chunks = kwargs.get("n_chunks", os.environ.get("UNSLOTH_CE_LOSS_N_CHUNKS", None))
+ n_chunks = os.environ.get("UNSLOTH_CE_LOSS_N_CHUNKS", None)
```
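For context, a hedged sketch of how these variables could be set by a user; the specific values and the int cast downstream are assumptions, not behaviour shown in this diff:

```python
import os

# Hypothetical usage: either cap the memory used per cross-entropy chunk...
os.environ["UNSLOTH_CE_LOSS_TARGET_GB"] = "4"
# ...or pin the number of chunks directly.
os.environ["UNSLOTH_CE_LOSS_N_CHUNKS"] = "8"

# Environment variables come back as strings, so downstream code would
# still need to cast before use (assumed, not shown in the diff).
n_chunks = os.environ.get("UNSLOTH_CE_LOSS_N_CHUNKS", None)
if n_chunks is not None:
    n_chunks = int(n_chunks)
```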
💡 Codex Review
Here are some automated review suggestions for this pull request.
```diff
  B, S, H = x.shape
- n_shards = int(max(1, min(S, math.ceil(S / max(1, H)))))
+ chunk_size = max(1, H)
+ n_shards, remainder = divmod(S, chunk_size)
+ n_shards = max(1, n_shards)
```
Arctic tiling ignores batch size when setting shard counts
When target_arctic is used, tiled_forward_arctic_size derives n_shards solely from the sequence length (S // H) but then passes max_flat_qlen=chunk_size into TiledMLP.apply. TiledMLP.forward uses max_flat_qlen to split the flattened tensor of length B*S, so with batch size > 1 the computed n_shards is too small and the remainder chunk contains almost all of the additional batches (e.g., B=8, S=4096, H=4096 produces shard sizes [4096, 28672]). Chunk sizes therefore still blow past the intended chunk_size, and the VRAM reduction becomes ineffective for multi-batch inputs, contradicting the intended “consistent” tiling.
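The reported shard sizes can be reproduced in a few standalone lines (a sketch mirroring the split logic described above under the stated shapes, not the actual TiledMLP internals):

```python
# Standalone sketch: n_shards is derived from S alone, then applied to the
# flattened length B * S, leaving a huge remainder chunk.
B, S, H = 8, 4096, 4096
chunk_size = max(1, H)

n_shards, _ = divmod(S, chunk_size)   # 1 -- batch size B is ignored
n_shards = max(1, n_shards)

flat_len = B * S                      # 32768 rows after flattening
sizes = [chunk_size] * n_shards
sizes.append(flat_len - sum(sizes))   # remainder swallows the other 7 batches
print(sizes)                          # [4096, 28672]
```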
Add environment variables to target the number of chunks used in the cross-entropy loss.
Also make chunk sizes consistent for the tiled MLP in arctic models.