
[PyTorch] Guard/document single parameter feature for grouped linear#2955

Open
ksivaman wants to merge 2 commits into NVIDIA:main from ksivaman:single_param_better_doc

Conversation

@ksivaman
Member

@ksivaman ksivaman commented May 1, 2026

Description

Guard/document single parameter feature for grouped linear.

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Guard/document single parameter feature for grouped linear

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
@ksivaman ksivaman requested review from ptrendx and timmoon10 May 1, 2026 22:28
@ksivaman ksivaman added the 2.15.0 label May 1, 2026
@ksivaman
Member Author

ksivaman commented May 1, 2026

/te-ci pytorch L0

@greptile-apps
Contributor

greptile-apps Bot commented May 1, 2026

Greptile Summary

This PR guards the experimental single_grouped_weight / single_grouped_bias feature for both GroupedLinear implementations behind the NVTE_GROUPED_LINEAR_SINGLE_PARAM environment variable and documents the behaviour in both docstrings. The logic is extracted into a shared resolve_grouped_linear_single_param_flags utility that emits a UserWarning both when the feature is silently disabled (env var absent) and when it is actively enabled (env var set), and the CI test runner is updated to set the env var for the suites that exercise this path.
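A minimal sketch of the gating utility described above, assembled from the summary, the flowchart, and the quoted snippet below. The env var name and stacklevel=3 come from the PR; the exact warning messages are assumptions, not the PR's actual text.

```python
import os
import warnings


def resolve_grouped_linear_single_param_flags(
    single_grouped_weight: bool,
    single_grouped_bias: bool,
) -> tuple[bool, bool]:
    """Gate the experimental single-param feature behind an env var (sketch)."""
    # Neither flag requested: nothing to gate, return unchanged.
    if not (single_grouped_weight or single_grouped_bias):
        return single_grouped_weight, single_grouped_bias

    env_enabled = int(os.environ.get("NVTE_GROUPED_LINEAR_SINGLE_PARAM", "0")) > 0
    if not env_enabled:
        # Feature requested but not enabled: warn and fall back to
        # per-expert parameters.
        warnings.warn(
            "single_grouped_weight/single_grouped_bias were requested but "
            "NVTE_GROUPED_LINEAR_SINGLE_PARAM is not set; falling back to "
            "per-expert parameters.",  # assumed wording
            UserWarning,
            stacklevel=3,
        )
        return False, False

    # Feature enabled: warn that it is experimental and may be
    # non-deterministic, then honor the caller's flags.
    warnings.warn(
        "Single grouped parameters are an experimental feature and may be "
        "non-deterministic.",  # assumed wording
        UserWarning,
        stacklevel=3,
    )
    return single_grouped_weight, single_grouped_bias
```

The no-op fast path means modules that never request the feature see no warnings at all.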

Confidence Score: 4/5

Safe to merge — changes are limited to documentation, an env-var guard, and a shared utility; no logic regressions identified.

Only P2 (style) findings present; the implementation is correct and consistent with existing patterns in the codebase.

transformer_engine/pytorch/utils.py — env var parsing without error handling on line 93.

Important Files Changed

Filename Overview
transformer_engine/pytorch/utils.py Adds resolve_grouped_linear_single_param_flags to gate the experimental single-param feature on NVTE_GROUPED_LINEAR_SINGLE_PARAM; issues UserWarning in both the disabled and enabled cases using stacklevel=3.
transformer_engine/pytorch/module/grouped_linear.py Imports and calls resolve_grouped_linear_single_param_flags in __init__ before assigning flags; updates docstrings to document env-var gate and non-determinism caveat.
transformer_engine/pytorch/ops/basic/grouped_linear.py Same gating call as the module-level class; docstrings updated consistently.
qa/L0_pytorch_unittest/test.sh Sets NVTE_GROUPED_LINEAR_SINGLE_PARAM=1 for test_sanity.py and test_fusible_ops.py so those suites continue to exercise the single-param code path now that the feature is gated by default.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["User constructs GroupedLinear\n(single_grouped_weight=True\nor single_grouped_bias=True)"] --> B["resolve_grouped_linear_single_param_flags()"]
    B --> C{Either flag true?}
    C -- No --> D["Return flags unchanged (no-op)"]
    C -- Yes --> E{NVTE_GROUPED_LINEAR_SINGLE_PARAM set?}
    E -- No --> F["Emit UserWarning: env var not set, feature disabled"]
    F --> G["Return False, False (fall back to per-expert params)"]
    E -- Yes --> H["Emit UserWarning: experimental / non-deterministic"]
    H --> I["Return original flags (feature enabled)"]


if not (single_grouped_weight or single_grouped_bias):
    return single_grouped_weight, single_grouped_bias

env_enabled = int(os.environ.get("NVTE_GROUPED_LINEAR_SINGLE_PARAM", "0")) > 0
Contributor


P2 Non-integer env var value will raise ValueError

int(os.environ.get("NVTE_GROUPED_LINEAR_SINGLE_PARAM", "0")) will throw an uncaught ValueError if the variable is set to a non-numeric string (e.g. "true", "yes"). Wrapping in a try/except would give a cleaner error message. This is consistent with other similar env-var checks in the file, so this is a pre-existing pattern rather than a new regression — flagging for awareness.

Suggested change
env_enabled = int(os.environ.get("NVTE_GROUPED_LINEAR_SINGLE_PARAM", "0")) > 0
try:
    env_enabled = int(os.environ.get("NVTE_GROUPED_LINEAR_SINGLE_PARAM", "0")) > 0
except ValueError:
    env_enabled = False

Collaborator

@timmoon10 timmoon10 left a comment


This logic breaks backward compatibility since we can no longer rely on the module kwargs to configure single grouped params. I guess we've already been treating single grouped params as an experimental feature. We should keep this instability in mind whenever we use this feature externally, e.g. in Mcore.

Comment on lines +86 to +87
single_grouped_weight: bool,
single_grouped_bias: bool,
Collaborator


While we are breaking backward compatibility, we might consider consolidating these options together. Do we really want to take on the burden of supporting the case with a single grouped weight and discrete bias, or discrete weights and single grouped bias?
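One way to read the consolidation suggestion above is a single kwarg that covers both parameters, so the weight/bias combinations the reviewer questions simply cannot be expressed. This is a hypothetical sketch of that API shape, not TE code; all names here are illustrative.

```python
def resolve_consolidated_flag(
    single_grouped_params: bool,  # hypothetical replacement for the two kwargs
    has_bias: bool,
) -> tuple[bool, bool]:
    """Map one consolidated flag onto the two internal flags.

    With one flag, weight and bias are always grouped (or not) together,
    eliminating the mixed single-weight/discrete-bias configurations.
    """
    single_grouped_weight = single_grouped_params
    single_grouped_bias = single_grouped_params and has_bias
    return single_grouped_weight, single_grouped_bias
```

The internal pair could be kept for the implementation while the public constructor exposes only the consolidated flag, which would narrow the supported surface without rewriting the kernels.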

if not (single_grouped_weight or single_grouped_bias):
    return single_grouped_weight, single_grouped_bias

env_enabled = int(os.environ.get("NVTE_GROUPED_LINEAR_SINGLE_PARAM", "0")) > 0
Collaborator


If we only respect the kwargs when the envvar is set, it doesn't really make sense to keep the kwargs rather than just checking the envvar. I guess we're half-heartedly maintaining/preparing the stable API for this feature.
