1212# limitations under the License.
1313"""Common config utils for mamba2 - NemotronH, FalconH1, Qwen3Next, LFM2, etc."""
1414
15+ import logging
1516from abc import ABC
1617from dataclasses import dataclass , field
1718from typing import List , Optional
2223from sglang .srt .distributed .utils import divide
2324from sglang .srt .environ import envs
2425
26+ logger = logging .getLogger (__name__ )
27+
2528
2629def extra_groups_for_head_shards (ngroups : int , tp_size : int ):
2730 """Compute the increase in group numbers to account for
@@ -41,20 +44,72 @@ class Mamba2StateDType:
4144 temporal : torch .dtype
4245
4346
44- def mamba2_state_dtype () -> Mamba2StateDType :
47+ def mamba2_state_dtype (config = None ) -> Mamba2StateDType :
48+ """
49+ Get mamba2 state dtype from config or environment variable.
50+
51+ Priority (from highest to lowest):
52+ 1. Environment variable SGLANG_MAMBA_SSM_DTYPE
53+ 2. Config file (config.mamba_ssm_dtype or config.text_config.mamba_ssm_dtype)
54+ 3. Default "float32"
55+
56+ Args:
57+ config: Optional config object (PretrainedConfig). If provided, will read
58+ mamba_ssm_dtype from it. For VL models, reads from text_config.
59+
60+ Returns:
61+ Mamba2StateDType with conv and temporal dtypes
62+ """
4563 dtype_map = {
4664 "float32" : torch .float32 ,
4765 "bfloat16" : torch .bfloat16 ,
4866 "float16" : torch .float16 ,
4967 }
5068 conv_dtype = dtype_map .get (envs .SGLANG_MAMBA_CONV_DTYPE .get (), torch .bfloat16 )
51- ssm_dtype = dtype_map .get (envs .SGLANG_MAMBA_SSM_DTYPE .get (), torch .float32 )
69+
70+ # Get SSM dtype: default -> config -> env var
71+ ssm_dtype = torch .float32 # Step 1: Default value
72+
73+ # Step 2: Try to read from config
74+ if config is not None :
75+ config_dtype = None
76+ if hasattr (config , "text_config" ) and hasattr (
77+ config .text_config , "mamba_ssm_dtype"
78+ ):
79+ # VL model: read from text_config
80+ config_dtype = config .text_config .mamba_ssm_dtype
81+ elif hasattr (config , "mamba_ssm_dtype" ):
82+ # Text model: read from root config
83+ config_dtype = config .mamba_ssm_dtype
84+
85+ if config_dtype is not None :
86+ if config_dtype not in dtype_map :
87+ logger .warning (
88+ f"Invalid mamba_ssm_dtype '{ config_dtype } ' in config. "
89+ f"Must be one of { list (dtype_map .keys ())} . Using default 'float32'."
90+ )
91+ else :
92+ ssm_dtype = dtype_map [config_dtype ]
93+
94+ # Step 3: Check environment variable, if not None, override
95+ env_ssm_dtype = envs .SGLANG_MAMBA_SSM_DTYPE .get ()
96+ if env_ssm_dtype is not None :
97+ if env_ssm_dtype not in dtype_map :
98+ logger .warning (
99+ f"Invalid mamba_ssm_dtype '{ env_ssm_dtype } ' from environment variable. "
100+ f"Must be one of { list (dtype_map .keys ())} . Using default 'float32'."
101+ )
102+ else :
103+ ssm_dtype = dtype_map [env_ssm_dtype ]
104+
105+ logger .info (f"Mamba2 state dtype: conv_dtype={ conv_dtype } , ssm_dtype={ ssm_dtype } " )
106+
52107 return Mamba2StateDType (conv = conv_dtype , temporal = ssm_dtype )
53108
54109
55110@dataclass (kw_only = True , frozen = True )
56111class BaseLinearStateParams (ABC ):
57- dtype : Mamba2StateDType = field (default_factory = mamba2_state_dtype )
112+ dtype : Mamba2StateDType = field (default_factory = lambda : mamba2_state_dtype ( None ) )
58113 layers : list [int ]
59114
60115 @property
0 commit comments