Merged

Changes from 2 commits
4 changes: 2 additions & 2 deletions qa/L0_pytorch_unittest/test.sh

@@ -24,7 +24,7 @@ mkdir -p "$XML_LOG_DIR"

pip3 install pytest==8.2.1 || error_exit "Failed to install pytest"

-python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/test_sanity.py || test_fail "test_sanity.py"
+NVTE_GROUPED_LINEAR_SINGLE_PARAM=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/test_sanity.py || test_fail "test_sanity.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_recipe.xml $TE_PATH/tests/pytorch/test_recipe.py || test_fail "test_recipe.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_deferred_init.xml $TE_PATH/tests/pytorch/test_deferred_init.py || test_fail "test_deferred_init.py"
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/test_numerics.py || test_fail "test_numerics.py"
@@ -41,7 +41,7 @@ python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/test_grouped_tensor.xml $TE_
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_gqa.xml $TE_PATH/tests/pytorch/test_gqa.py || test_fail "test_gqa.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_optimizer.xml $TE_PATH/tests/pytorch/test_fused_optimizer.py || test_fail "test_fused_optimizer.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_multi_tensor.xml $TE_PATH/tests/pytorch/test_multi_tensor.py || test_fail "test_multi_tensor.py"
-NVTE_CUTEDSL_FUSED_GROUPED_MLP=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/test_fusible_ops.py || test_fail "test_fusible_ops.py"
+NVTE_GROUPED_LINEAR_SINGLE_PARAM=1 NVTE_CUTEDSL_FUSED_GROUPED_MLP=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/test_fusible_ops.py || test_fail "test_fusible_ops.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_backward_override.xml $TE_PATH/tests/pytorch/test_backward_override.py || test_fail "test_backward_override.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_permutation.xml $TE_PATH/tests/pytorch/test_permutation.py || test_fail "test_permutation.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_entropy.xml $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py"
14 changes: 12 additions & 2 deletions transformer_engine/pytorch/module/grouped_linear.py
@@ -31,6 +31,7 @@
clear_tensor_data,
init_method_constant,
requires_grad,
+resolve_grouped_linear_single_param_flags,
get_nvtx_range_context,
)
from ..distributed import (
@@ -673,11 +674,17 @@ class GroupedLinear(TransformerEngineBaseModule):
single_grouped_weight : bool, default = False
If set to ``True``, grouped weights are stored as a single grouped parameter
instead of one parameter per GEMM.
-EXPERIMENTAL and subject to change.
+EXPERIMENTAL and subject to change. Gated by the
+``NVTE_GROUPED_LINEAR_SINGLE_PARAM`` environment variable: if the env var
+is not set, this argument is forced to ``False`` with a warning.
+When enabled, this path is known to be non-deterministic in certain cases.
single_grouped_bias : bool, default = False
If set to ``True``, grouped biases are stored as a single grouped bias
instead of one bias per GEMM.
-EXPERIMENTAL and subject to change.
+EXPERIMENTAL and subject to change. Gated by the
+``NVTE_GROUPED_LINEAR_SINGLE_PARAM`` environment variable: if the env var
+is not set, this argument is forced to ``False`` with a warning.
+When enabled, this path is known to be non-deterministic in certain cases.

Notes
-----
@@ -726,6 +733,9 @@ def __init__(
self.ub_overlap_ag = ub_overlap_ag
self.ub_name = ub_name
self.save_original_input = save_original_input
+single_grouped_weight, single_grouped_bias = resolve_grouped_linear_single_param_flags(
+    single_grouped_weight, single_grouped_bias
+)
self.single_grouped_weight = single_grouped_weight
self.single_grouped_bias = single_grouped_bias
if ub_overlap_rs or ub_overlap_ag:
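To make the gating concrete, a minimal construction sketch. The positional arguments (num_gemms, in_features, out_features) are assumed from the module's existing signature and are not part of this diff; only the two keyword flags below are introduced here:

import os

# Opt in before construction; if the env var is unset, both flags are
# forced back to False and a UserWarning is emitted (see docstring above).
os.environ["NVTE_GROUPED_LINEAR_SINGLE_PARAM"] = "1"

import transformer_engine.pytorch as te

layer = te.GroupedLinear(
    4, 128, 256,  # assumed: num_gemms, in_features, out_features
    bias=True,
    single_grouped_weight=True,  # experimental single grouped weight parameter
    single_grouped_bias=True,    # experimental single grouped bias parameter
)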
12 changes: 12 additions & 0 deletions transformer_engine/pytorch/ops/basic/grouped_linear.py
@@ -30,6 +30,7 @@
canonicalize_dtype,
clear_tensor_data,
devices_match,
+resolve_grouped_linear_single_param_flags,
round_up_to_nearest_multiple,
)
from .._common import is_quantized_tensor, maybe_dequantize
@@ -78,11 +79,19 @@ class GroupedLinear(BasicOperation):
``main_grad`` instead of accumulating.
single_grouped_weight : bool, default = ``False``
Store all expert weights as one ``GroupedTensor`` parameter ``weight``.
+EXPERIMENTAL and subject to change. Gated by the
+``NVTE_GROUPED_LINEAR_SINGLE_PARAM`` environment variable: if the env var
+is not set, this argument is forced to ``False`` with a warning.
+When enabled, this path is known to be non-deterministic in certain cases.
delay_wgrad_compute : bool, default = ``False``
Whether to delay weight gradient computation
single_grouped_bias : bool, default = ``False``
If ``True`` (and ``bias=True``), store all expert biases as one ``GroupedTensor``
parameter named ``bias`` instead of ``bias0``..``bias{N-1}``.
+EXPERIMENTAL and subject to change. Gated by the
+``NVTE_GROUPED_LINEAR_SINGLE_PARAM`` environment variable: if the env var
+is not set, this argument is forced to ``False`` with a warning.
+When enabled, this path is known to be non-deterministic in certain cases.
scale_bias : bool, default = ``False``
If ``True`` (and ``bias=True``), expects a probability tensor as an
additional extra input and adds ``bias * scales`` instead of ``bias``
@@ -123,6 +132,9 @@ def __init__(
self.num_groups: int = num_groups
self.in_features: int = in_features
self.out_features: int = out_features
+single_grouped_weight, single_grouped_bias = resolve_grouped_linear_single_param_flags(
+    single_grouped_weight, single_grouped_bias
+)
self.single_grouped_weight: bool = single_grouped_weight
self.single_grouped_bias: bool = single_grouped_bias
self.use_bias: bool = bias
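For the fusible-ops variant, a sketch of the parameter-naming difference the docstring describes. The transformer_engine.pytorch.ops import path and the keyword construction below are assumptions for illustration; the documented behavior is that the grouped path exposes single parameters named "weight"/"bias" instead of weight0..weight{N-1}/bias0..bias{N-1}:

import os
os.environ["NVTE_GROUPED_LINEAR_SINGLE_PARAM"] = "1"  # opt in before construction

import transformer_engine.pytorch.ops as te_ops  # assumed import path

# Default: one parameter per expert.
per_expert = te_ops.GroupedLinear(num_groups=4, in_features=128, out_features=256, bias=True)
# Experimental: one GroupedTensor parameter each.
grouped = te_ops.GroupedLinear(
    num_groups=4, in_features=128, out_features=256, bias=True,
    single_grouped_weight=True, single_grouped_bias=True,
)
print(sorted(name for name, _ in per_expert.named_parameters()))  # bias0.., weight0..
print(sorted(name for name, _ in grouped.named_parameters()))     # ["bias", "weight"]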
31 changes: 31 additions & 0 deletions transformer_engine/pytorch/utils.py
@@ -7,6 +7,7 @@
import functools
import math
import os
+import warnings
from typing import Any, Callable, List, Optional, Sequence, Tuple, Union
from contextlib import nullcontext
import numpy as np
@@ -81,6 +82,36 @@ def get_device_compute_capability() -> Tuple[int, int]:
return _get_device_compute_capability(torch.cuda.current_device())


+def resolve_grouped_linear_single_param_flags(
+    single_grouped_weight: bool,
+    single_grouped_bias: bool,
Collaborator comment on lines +86 to +87:
While we are breaking backward compatibility, we might consider consolidating these options together. Do we really want to take on the burden of supporting the case with a single grouped weight and discrete bias, or discrete weights and single grouped bias?

+) -> Tuple[bool, bool]:
+    """Gate ``single_grouped_weight`` / ``single_grouped_bias`` on ``NVTE_GROUPED_LINEAR_SINGLE_PARAM``."""
+    if not (single_grouped_weight or single_grouped_bias):
+        return single_grouped_weight, single_grouped_bias
+
+    env_enabled = int(os.environ.get("NVTE_GROUPED_LINEAR_SINGLE_PARAM", "0")) > 0
Contributor comment:
P2: A non-integer env var value will raise ValueError. int(os.environ.get("NVTE_GROUPED_LINEAR_SINGLE_PARAM", "0")) will throw an uncaught ValueError if the variable is set to a non-numeric string (e.g. "true", "yes"). Wrapping it in a try/except would give a cleaner error message. This is consistent with other similar env-var checks in the file, so it is a pre-existing pattern rather than a new regression; flagging for awareness.

Suggested change:
-    env_enabled = int(os.environ.get("NVTE_GROUPED_LINEAR_SINGLE_PARAM", "0")) > 0
+    try:
+        env_enabled = int(os.environ.get("NVTE_GROUPED_LINEAR_SINGLE_PARAM", "0")) > 0
+    except ValueError:
+        env_enabled = False
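For illustration, a standalone sketch of the failure mode this comment describes, using only the standard library:

import os

os.environ["NVTE_GROUPED_LINEAR_SINGLE_PARAM"] = "true"  # plausible but non-numeric
try:
    enabled = int(os.environ.get("NVTE_GROUPED_LINEAR_SINGLE_PARAM", "0")) > 0
except ValueError as exc:
    # int("true") raises: invalid literal for int() with base 10: 'true'
    print(f"unparseable value, treating as disabled: {exc}")
    enabled = False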

Collaborator comment:
If we only respect the kwargs when the env var is set, it doesn't really make sense to keep the kwargs rather than just checking the env var. I guess we're half-heartedly maintaining/preparing the stable API for this feature.

+    if not env_enabled:
+        warnings.warn(
+            f"GroupedLinear was constructed with single_grouped_weight={single_grouped_weight} "
+            f"and single_grouped_bias={single_grouped_bias}, but the "
+            "NVTE_GROUPED_LINEAR_SINGLE_PARAM environment variable is not set. "
+            "Disabling single grouped weight/bias and falling back to per-expert parameters.",
+            UserWarning,
+            stacklevel=3,
+        )
+        return False, False

+    warnings.warn(
+        "GroupedLinear is using single_grouped_weight/single_grouped_bias. "
+        "This feature is experimental, may change in future "
+        "releases, and is known to be non-deterministic in certain cases.",
+        UserWarning,
+        stacklevel=3,
+    )
+    return single_grouped_weight, single_grouped_bias


def attention_mask_func(
attention_scores: torch.Tensor, attention_mask: torch.Tensor
) -> torch.Tensor:
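For reference, a small sketch of how the new helper behaves in isolation, assuming transformer_engine.pytorch.utils imports cleanly in your environment:

import os
import warnings

from transformer_engine.pytorch.utils import resolve_grouped_linear_single_param_flags

# Env var unset: requested flags are forced back to (False, False) with a warning.
os.environ.pop("NVTE_GROUPED_LINEAR_SINGLE_PARAM", None)
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    print(resolve_grouped_linear_single_param_flags(True, True))  # (False, False)
    print(caught[0].category.__name__)  # UserWarning

# Env var set: flags pass through, with an "experimental" warning.
os.environ["NVTE_GROUPED_LINEAR_SINGLE_PARAM"] = "1"
print(resolve_grouped_linear_single_param_flags(True, False))  # (True, False)

# Both flags False: returned unchanged; no warning, no env lookup.
print(resolve_grouped_linear_single_param_flags(False, False))  # (False, False)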