-
Notifications
You must be signed in to change notification settings - Fork 719
[PyTorch] Guard/document single parameter feature for grouped linear #2955
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
7650120
ba5bddc
554abb8
16bd7be
0d20387
cbace12
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -7,6 +7,7 @@ | |||||||||||
| import functools | ||||||||||||
| import math | ||||||||||||
| import os | ||||||||||||
| import warnings | ||||||||||||
| from typing import Any, Callable, List, Optional, Sequence, Tuple, Union | ||||||||||||
| from contextlib import nullcontext | ||||||||||||
| import numpy as np | ||||||||||||
|
|
@@ -81,6 +82,36 @@ def get_device_compute_capability() -> Tuple[int, int]: | |||||||||||
| return _get_device_compute_capability(torch.cuda.current_device()) | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
def resolve_grouped_linear_single_param_flags(
    single_grouped_weight: bool,
    single_grouped_bias: bool,
) -> Tuple[bool, bool]:
    """Gate the single-parameter GroupedLinear flags on an environment variable.

    The requested ``single_grouped_weight`` / ``single_grouped_bias`` flags are
    honored only when ``NVTE_GROUPED_LINEAR_SINGLE_PARAM`` enables the feature.

    Behavior:
      * Neither flag requested: the flags are returned unchanged, the
        environment variable is not consulted, and no warning is issued.
      * A flag is requested but the variable does not enable the feature:
        a ``UserWarning`` is issued and ``(False, False)`` is returned.
      * A flag is requested and the variable enables the feature: a
        ``UserWarning`` flags the feature as experimental and the requested
        flags are returned unchanged.

    The variable is considered enabling when it parses as an integer greater
    than zero, or when it is one of ``true`` / ``yes`` / ``on`` (case-insensitive,
    surrounding whitespace ignored). Any other value -- including an empty or
    otherwise unparseable string -- is treated as disabled rather than raising
    ``ValueError``.

    NOTE(review): because the keyword arguments only take effect when the
    environment variable is set, the variable is effectively the real switch;
    the kwargs are kept in preparation for a stable API for this feature.

    Args:
        single_grouped_weight: Whether a single grouped weight was requested.
        single_grouped_bias: Whether a single grouped bias was requested.

    Returns:
        The effective ``(single_grouped_weight, single_grouped_bias)`` pair.
    """
    if not (single_grouped_weight or single_grouped_bias):
        return single_grouped_weight, single_grouped_bias

    raw_value = os.environ.get("NVTE_GROUPED_LINEAR_SINGLE_PARAM", "0").strip()
    try:
        env_enabled = int(raw_value) > 0
    except ValueError:
        # A non-integer value must not crash construction; accept the common
        # truthy spellings and treat everything else as "disabled".
        env_enabled = raw_value.lower() in ("true", "yes", "on")

    # stacklevel=3 attributes the warnings two frames above this function --
    # presumably the code that constructs GroupedLinear; confirm against caller.
    if not env_enabled:
        warnings.warn(
            f"GroupedLinear was constructed with single_grouped_weight={single_grouped_weight} "
            f"and single_grouped_bias={single_grouped_bias}, but the "
            "NVTE_GROUPED_LINEAR_SINGLE_PARAM environment variable is not set. "
            "Disabling single grouped weight/bias and falling back to per-expert parameters.",
            UserWarning,
            stacklevel=3,
        )
        return False, False

    warnings.warn(
        "GroupedLinear is using single_grouped_weight/single_grouped_bias. "
        "This feature is experimental, may change in future "
        "releases, and is known to be non-deterministic in certain cases.",
        UserWarning,
        stacklevel=3,
    )
    return single_grouped_weight, single_grouped_bias
|
|
||||||||||||
|
|
||||||||||||
| def attention_mask_func( | ||||||||||||
| attention_scores: torch.Tensor, attention_mask: torch.Tensor | ||||||||||||
| ) -> torch.Tensor: | ||||||||||||
|
|
||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
While we are breaking backward compatibility, we might consider consolidating these options together. Do we really want to take on the burden of supporting the case with a single grouped weight and discrete bias, or discrete weights and single grouped bias?