Conversation

@mmathew23 (Collaborator)

Add env variables to target the number of chunks used in cross entropy loss.

Also make chunk sizes consistent for tiled mlp arctic.

@gemini-code-assist (Contributor)

Summary of Changes

Hello @mmathew23, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request enhances the flexibility and consistency of chunking mechanisms within the unsloth_zoo library. It allows users to control cross-entropy loss chunking through environment variables, providing a dynamic way to optimize memory usage or performance. Additionally, it refines the chunk size determination logic for tiled Multi-Layer Perceptrons (MLPs) in the 'arctic' configuration, aiming for more predictable and consistent behavior across different input sizes.

Highlights

  • Cross-Entropy Loss Configuration: Introduced the ability to configure the number of chunks (n_chunks) or target gigabytes (target_gb) for cross-entropy loss calculations using environment variables (UNSLOTH_CE_LOSS_N_CHUNKS and UNSLOTH_CE_LOSS_TARGET_GB); see the sketch after this list.
  • Tiled MLP Chunk Size Consistency: Revised the calculation of n_shards and chunk_size within the tiled_forward_arctic_size method for tiled_mlp to ensure more consistent chunk sizing, passing the determined chunk_size to the underlying TiledMLP.apply function.
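
To illustrate the first highlight, here is a minimal sketch of how the new environment variables could be set. The exact call site inside unsloth_zoo that reads them is not shown in this thread, so the values below are placeholders only.

    import os

    # Read by the chunked cross-entropy loss when n_chunks / target_gb are not
    # passed explicitly; set before the loss is first called.
    os.environ["UNSLOTH_CE_LOSS_N_CHUNKS"] = "4"     # placeholder: fixed chunk count
    # or, alternatively, target an approximate memory budget instead:
    # os.environ["UNSLOTH_CE_LOSS_TARGET_GB"] = "2"  # placeholder value

As the highlight describes, n_chunks fixes the chunk count directly, while target_gb expresses a memory target from which the loss derives a chunk count.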
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

  • Code Review (/gemini review): Performs a code review for the current pull request in its current state.
  • Pull Request Summary (/gemini summary): Provides a summary of the current pull request in its current state.
  • Comment (@gemini-code-assist): Responds in comments when explicitly tagged, both in pull request comments and review comments.
  • Help (/gemini help): Displays a list of available commands.

Customization

To customize your Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@gemini-code-assist bot left a comment


Code Review

This pull request introduces environment variables to configure chunking in the cross-entropy loss function and updates the chunking logic for the tiled MLP in arctic models. The changes for the cross-entropy loss look good, though there's a small opportunity for code clarification. However, the updated chunking logic in tiled_mlp.py appears to have a bug where it doesn't account for the batch size, which could lead to unbalanced processing chunks. I've provided suggestions to address these points.

  B, S, H = x.shape
- n_shards = int(max(1, min(S, math.ceil(S / max(1, H)))))
+ chunk_size = max(1, H)
+ n_shards, remainder = divmod(S, chunk_size)

Severity: high

The calculation of n_shards here only considers the sequence length S and not the batch size B. Since TiledMLP chunks the flattened tensor of size B*S, this will result in unbalanced chunks when B > 1. The last chunk will process a much larger amount of data than the preceding ones.

To ensure chunks are balanced, n_shards should be calculated based on the total number of elements B*S.

Suggested change:
- n_shards, remainder = divmod(S, chunk_size)
+ n_shards, remainder = divmod(B * S, chunk_size)
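
As a quick numeric check of this suggestion (illustrative arithmetic only, using the example shapes quoted later in the Codex comment, not library code):

    B, S, H = 8, 4096, 4096
    chunk_size = max(1, H)

    # PR as written: shard count derived from the sequence length alone.
    n_shards_seq, _ = divmod(S, chunk_size)        # 1 shard
    # Suggested fix: shard count derived from the flattened B*S rows.
    n_shards_flat, _ = divmod(B * S, chunk_size)   # 8 shards

    print(n_shards_seq, n_shards_flat)             # -> 1 8

Assuming TiledMLP splits the flattened tensor into chunks of roughly chunk_size rows, the fix yields eight equal 4096-row chunks instead of one chunk plus a large remainder.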

if hasattr(scaling, "get_scale"): scaling = scaling.get_scale()
if target_gb is None and 'n_chunks' not in kwargs:
    target_gb = os.environ.get("UNSLOTH_CE_LOSS_TARGET_GB", None)
    n_chunks = kwargs.get("n_chunks", os.environ.get("UNSLOTH_CE_LOSS_N_CHUNKS", None))

Severity: medium

This line is a bit confusing. Since this code block is only executed when 'n_chunks' not in kwargs, the kwargs.get("n_chunks", ...) will always fall back to its default value. You can simplify this by directly getting the value from the environment variable.

Suggested change:
- n_chunks = kwargs.get("n_chunks", os.environ.get("UNSLOTH_CE_LOSS_N_CHUNKS", None))
+ n_chunks = os.environ.get("UNSLOTH_CE_LOSS_N_CHUNKS", None)
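
One detail the suggestion leaves implicit: os.environ.get returns a string (or None), so downstream code presumably converts the value before chunking. A hypothetical sketch of that conversion; the actual handling in unsloth_zoo may differ:

    import os

    raw = os.environ.get("UNSLOTH_CE_LOSS_N_CHUNKS", None)
    n_chunks = int(raw) if raw is not None else None   # env values arrive as strings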


@chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Comment on lines 247 to +250
  B, S, H = x.shape
- n_shards = int(max(1, min(S, math.ceil(S / max(1, H)))))
+ chunk_size = max(1, H)
+ n_shards, remainder = divmod(S, chunk_size)
+ n_shards = max(1, n_shards)


P1: Arctic tiling ignores batch size when setting shard counts

When target_arctic is used, tiled_forward_arctic_size derives n_shards solely from the sequence length (S // H) but then passes max_flat_qlen=chunk_size into TiledMLP.apply. TiledMLP.forward uses max_flat_qlen to split the flattened tensor of length B*S, so with batch size > 1 the computed n_shards is too small and the remainder chunk contains almost all additional batches (e.g., B=8,S=4096,H=4096 produces shard sizes [4096, 28672]). Chunk sizes therefore still blow past the intended chunk_size and VRAM reduction becomes ineffective for multi-batch inputs, contradicting the intended “consistent” tiling.

Useful? React with 👍 / 👎.
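
To make the P1 finding concrete, here is illustrative arithmetic only, assuming (as the comment above describes) that TiledMLP.forward splits the flattened B*S rows into n_shards chunks of max_flat_qlen rows plus one remainder chunk; this is not the actual TiledMLP code.

    B, S, H = 8, 4096, 4096
    chunk_size = max(1, H)                       # passed as max_flat_qlen
    n_shards, _ = divmod(S, chunk_size)          # 1 -- batch dimension ignored
    flat_len = B * S                             # 32768 rows after flattening

    shard_sizes = [chunk_size] * n_shards
    shard_sizes.append(flat_len - chunk_size * n_shards)   # remainder chunk
    print(shard_sizes)                           # [4096, 28672], matching the review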

@danielhanchen merged commit 7209a76 into unslothai:main Nov 30, 2025