Skip to content

Commit 4460376

Browse files
zju-stu-lizheng and 瑀澈 authored
fix(config): Support setting Mamba state dtype via config file (sgl-project#18532)
Co-authored-by: 瑀澈 <yuche.lz@alibaba-inc.com>
1 parent d0d387d commit 4460376

File tree

8 files changed

+103
-19
lines changed

8 files changed

+103
-19
lines changed

python/sglang/srt/configs/falcon_h1.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,11 @@
1818
from transformers.configuration_utils import PretrainedConfig
1919
from transformers.utils import logging
2020

21-
from sglang.srt.configs.mamba_utils import Mamba2CacheParams, Mamba2StateShape
21+
from sglang.srt.configs.mamba_utils import (
22+
Mamba2CacheParams,
23+
Mamba2StateShape,
24+
mamba2_state_dtype,
25+
)
2226

2327
logger = logging.get_logger(__name__)
2428

@@ -307,4 +311,6 @@ def mamba2_cache_params(self):
307311
state_size=self.mamba_d_state,
308312
conv_kernel=self.mamba_d_conv,
309313
)
310-
return Mamba2CacheParams(shape=shape, layers=self.linear_layer_ids)
314+
return Mamba2CacheParams(
315+
shape=shape, layers=self.linear_layer_ids, dtype=mamba2_state_dtype(self)
316+
)

python/sglang/srt/configs/jet_nemotron.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,11 @@
33

44
from transformers.configuration_utils import PretrainedConfig
55

6-
from sglang.srt.configs.mamba_utils import Mamba2CacheParams, Mamba2StateShape
6+
from sglang.srt.configs.mamba_utils import (
7+
Mamba2CacheParams,
8+
Mamba2StateShape,
9+
mamba2_state_dtype,
10+
)
711

812

913
@dataclass
@@ -71,4 +75,6 @@ def mamba2_cache_params(self) -> Mamba2CacheParams:
7175
conv_kernel=jet_block_config.conv_size,
7276
)
7377

74-
return Mamba2CacheParams(shape=shape, layers=self.linear_layer_ids)
78+
return Mamba2CacheParams(
79+
shape=shape, layers=self.linear_layer_ids, dtype=mamba2_state_dtype(self)
80+
)

python/sglang/srt/configs/lfm2.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,11 @@
2020
from transformers import Lfm2Config as HFLfm2Config
2121
from transformers.utils import logging
2222

23-
from sglang.srt.configs.mamba_utils import Mamba2CacheParams, Mamba2StateShape
23+
from sglang.srt.configs.mamba_utils import (
24+
Mamba2CacheParams,
25+
Mamba2StateShape,
26+
mamba2_state_dtype,
27+
)
2428

2529
logger = logging.get_logger(__name__)
2630

@@ -87,11 +91,10 @@ def mamba2_cache_params(self) -> Optional[Mamba2CacheParams]:
8791
conv_kernel=conv_kernel,
8892
)
8993

90-
# Uses default mamba2_state_dtype() which reads SGLANG_MAMBA_CONV_DTYPE env var
91-
# (defaults to bfloat16). Set SGLANG_MAMBA_CONV_DTYPE=float16 for fp16 inference.
9294
return Mamba2CacheParams(
9395
shape=shape,
9496
layers=conv_layer_ids,
97+
dtype=mamba2_state_dtype(self),
9598
)
9699

97100

python/sglang/srt/configs/mamba_utils.py

Lines changed: 58 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# limitations under the License.
1313
"""Common config utils for mamba2 - NemotronH, FalconH1, Qwen3Next, LFM2, etc."""
1414

15+
import logging
1516
from abc import ABC
1617
from dataclasses import dataclass, field
1718
from typing import List, Optional
@@ -22,6 +23,8 @@
2223
from sglang.srt.distributed.utils import divide
2324
from sglang.srt.environ import envs
2425

26+
logger = logging.getLogger(__name__)
27+
2528

2629
def extra_groups_for_head_shards(ngroups: int, tp_size: int):
2730
"""Compute the increase in group numbers to account for
@@ -41,20 +44,72 @@ class Mamba2StateDType:
4144
temporal: torch.dtype
4245

4346

44-
def mamba2_state_dtype() -> Mamba2StateDType:
47+
def mamba2_state_dtype(config=None) -> Mamba2StateDType:
48+
"""
49+
Get mamba2 state dtype from config or environment variable.
50+
51+
Priority (from highest to lowest):
52+
1. Environment variable SGLANG_MAMBA_SSM_DTYPE
53+
2. Config file (config.mamba_ssm_dtype or config.text_config.mamba_ssm_dtype)
54+
3. Default "float32"
55+
56+
Args:
57+
config: Optional config object (PretrainedConfig). If provided, will read
58+
mamba_ssm_dtype from it. For VL models, reads from text_config.
59+
60+
Returns:
61+
Mamba2StateDType with conv and temporal dtypes
62+
"""
4563
dtype_map = {
4664
"float32": torch.float32,
4765
"bfloat16": torch.bfloat16,
4866
"float16": torch.float16,
4967
}
5068
conv_dtype = dtype_map.get(envs.SGLANG_MAMBA_CONV_DTYPE.get(), torch.bfloat16)
51-
ssm_dtype = dtype_map.get(envs.SGLANG_MAMBA_SSM_DTYPE.get(), torch.float32)
69+
70+
# Get SSM dtype: default -> config -> env var
71+
ssm_dtype = torch.float32 # Step 1: Default value
72+
73+
# Step 2: Try to read from config
74+
if config is not None:
75+
config_dtype = None
76+
if hasattr(config, "text_config") and hasattr(
77+
config.text_config, "mamba_ssm_dtype"
78+
):
79+
# VL model: read from text_config
80+
config_dtype = config.text_config.mamba_ssm_dtype
81+
elif hasattr(config, "mamba_ssm_dtype"):
82+
# Text model: read from root config
83+
config_dtype = config.mamba_ssm_dtype
84+
85+
if config_dtype is not None:
86+
if config_dtype not in dtype_map:
87+
logger.warning(
88+
f"Invalid mamba_ssm_dtype '{config_dtype}' in config. "
89+
f"Must be one of {list(dtype_map.keys())}. Using default 'float32'."
90+
)
91+
else:
92+
ssm_dtype = dtype_map[config_dtype]
93+
94+
# Step 3: Check environment variable, if not None, override
95+
env_ssm_dtype = envs.SGLANG_MAMBA_SSM_DTYPE.get()
96+
if env_ssm_dtype is not None:
97+
if env_ssm_dtype not in dtype_map:
98+
logger.warning(
99+
f"Invalid mamba_ssm_dtype '{env_ssm_dtype}' from environment variable. "
100+
f"Must be one of {list(dtype_map.keys())}. Using default 'float32'."
101+
)
102+
else:
103+
ssm_dtype = dtype_map[env_ssm_dtype]
104+
105+
logger.info(f"Mamba2 state dtype: conv_dtype={conv_dtype}, ssm_dtype={ssm_dtype}")
106+
52107
return Mamba2StateDType(conv=conv_dtype, temporal=ssm_dtype)
53108

54109

55110
@dataclass(kw_only=True, frozen=True)
56111
class BaseLinearStateParams(ABC):
57-
dtype: Mamba2StateDType = field(default_factory=mamba2_state_dtype)
112+
dtype: Mamba2StateDType = field(default_factory=lambda: mamba2_state_dtype(None))
58113
layers: list[int]
59114

60115
@property

python/sglang/srt/configs/nemotron_h.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,11 @@
1919
from transformers.configuration_utils import PretrainedConfig
2020
from transformers.utils import logging
2121

22-
from sglang.srt.configs.mamba_utils import Mamba2CacheParams, Mamba2StateShape
22+
from sglang.srt.configs.mamba_utils import (
23+
Mamba2CacheParams,
24+
Mamba2StateShape,
25+
mamba2_state_dtype,
26+
)
2327

2428
logger = logging.get_logger(__name__)
2529

@@ -305,4 +309,6 @@ def mamba2_cache_params(self) -> Mamba2CacheParams:
305309
conv_kernel=self.conv_kernel,
306310
)
307311

308-
return Mamba2CacheParams(shape=shape, layers=self.mamba_layer_ids)
312+
return Mamba2CacheParams(
313+
shape=shape, layers=self.mamba_layer_ids, dtype=mamba2_state_dtype(self)
314+
)

python/sglang/srt/configs/qwen3_next.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,11 @@
1919
from transformers.configuration_utils import PretrainedConfig
2020
from transformers.utils import logging
2121

22-
from sglang.srt.configs.mamba_utils import Mamba2CacheParams, Mamba2StateShape
22+
from sglang.srt.configs.mamba_utils import (
23+
Mamba2CacheParams,
24+
Mamba2StateShape,
25+
mamba2_state_dtype,
26+
)
2327
from sglang.srt.configs.update_config import adjust_tp_num_heads_if_necessary
2428
from sglang.srt.utils import is_cpu
2529

@@ -293,4 +297,6 @@ def mamba2_cache_params(self) -> Mamba2CacheParams:
293297
conv_kernel=self.linear_conv_kernel_dim,
294298
)
295299

296-
return Mamba2CacheParams(shape=shape, layers=self.linear_layer_ids)
300+
return Mamba2CacheParams(
301+
shape=shape, layers=self.linear_layer_ids, dtype=mamba2_state_dtype(self)
302+
)

python/sglang/srt/environ.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -418,7 +418,7 @@ class Envs:
418418

419419
# Mamba
420420
SGLANG_MAMBA_CONV_DTYPE = EnvStr("bfloat16")
421-
SGLANG_MAMBA_SSM_DTYPE = EnvStr("float32")
421+
SGLANG_MAMBA_SSM_DTYPE = EnvStr(None)
422422

423423
# Release & Resume Memory
424424
SGLANG_MEMORY_SAVER_CUDA_GRAPH = EnvBool(False)

python/sglang/srt/server_args.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -513,7 +513,7 @@ class ServerArgs:
513513

514514
# Mamba cache
515515
max_mamba_cache_size: Optional[int] = None
516-
mamba_ssm_dtype: str = "float32"
516+
mamba_ssm_dtype: Optional[str] = None
517517
mamba_full_memory_ratio: float = 0.9
518518
mamba_scheduler_strategy: str = "auto"
519519
mamba_track_interval: int = 256
@@ -2600,7 +2600,8 @@ def _handle_tokenizer_batching(self):
26002600

26012601
def _handle_environment_variables(self):
26022602
envs.SGLANG_ENABLE_TORCH_COMPILE.set("1" if self.enable_torch_compile else "0")
2603-
envs.SGLANG_MAMBA_SSM_DTYPE.set(self.mamba_ssm_dtype)
2603+
if self.mamba_ssm_dtype is not None:
2604+
envs.SGLANG_MAMBA_SSM_DTYPE.set(self.mamba_ssm_dtype)
26042605
envs.SGLANG_DISABLE_OUTLINES_DISK_CACHE.set(
26052606
"1" if self.disable_outlines_disk_cache else "0"
26062607
)
@@ -4130,9 +4131,10 @@ def add_cli_args(parser: argparse.ArgumentParser):
41304131
parser.add_argument(
41314132
"--mamba-ssm-dtype",
41324133
type=str,
4133-
default=ServerArgs.mamba_ssm_dtype,
4134+
default=None,
41344135
choices=MAMBA_SSM_DTYPE_CHOICES,
4135-
help="The data type of the SSM states in mamba cache.",
4136+
help="The data type of the SSM states in mamba cache. "
4137+
"If not set, will be read from model config (mamba_ssm_dtype).",
41364138
)
41374139
parser.add_argument(
41384140
"--mamba-full-memory-ratio",

0 commit comments

Comments (0)