Commit 8d789b5

[diffusion] feat: support nunchaku for Z-Image-Turbo and flux.1 (int4) (#18959)
1 parent 7d95344 commit 8d789b5

File tree

9 files changed: +661 -169 lines changed

python/sglang/multimodal_gen/configs/models/dits/flux.py

Lines changed: 28 additions & 1 deletion
@@ -23,9 +23,36 @@ class FluxArchConfig(DiTArchConfig):
 
     stacked_params_mapping: list[tuple[str, str, str]] = field(default_factory=list)
 
+    # nunchaku checkpoint uses different weight names; map to sglang flux layout
     param_names_mapping: dict = field(
         default_factory=lambda: {
-            r"transformer\.(\w*)\.(.*)$": r"\1.\2",
+            # HF diffusers format
+            r"^transformer\.(\w*)\.(.*)$": r"\1.\2",
+            # transformer_blocks nunchaku format (raw export - before internal conversion)
+            r"^transformer_blocks\.(\d+)\.mlp_fc1\.(.*)$": r"transformer_blocks.\1.ff.net.0.proj.\2",
+            r"^transformer_blocks\.(\d+)\.mlp_fc2\.(.*)$": r"transformer_blocks.\1.ff.net.2.\2",
+            r"^transformer_blocks\.(\d+)\.mlp_context_fc1\.(.*)$": r"transformer_blocks.\1.ff_context.net.0.proj.\2",
+            r"^transformer_blocks\.(\d+)\.mlp_context_fc2\.(.*)$": r"transformer_blocks.\1.ff_context.net.2.\2",
+            r"^transformer_blocks\.(\d+)\.qkv_proj\.(.*)$": r"transformer_blocks.\1.attn.to_qkv.\2",
+            r"^transformer_blocks\.(\d+)\.qkv_proj_context\.(.*)$": r"transformer_blocks.\1.attn.to_added_qkv.\2",
+            r"^transformer_blocks\.(\d+)\.out_proj\.(.*)$": r"transformer_blocks.\1.attn.to_out.0.\2",
+            r"^transformer_blocks\.(\d+)\.out_proj_context\.(.*)$": r"transformer_blocks.\1.attn.to_add_out.\2",
+            r"^transformer_blocks\.(\d+)\.norm_q\.(.*)$": r"transformer_blocks.\1.attn.norm_q.\2",
+            r"^transformer_blocks\.(\d+)\.norm_k\.(.*)$": r"transformer_blocks.\1.attn.norm_k.\2",
+            r"^transformer_blocks\.(\d+)\.norm_added_q\.(.*)$": r"transformer_blocks.\1.attn.norm_added_q.\2",
+            r"^transformer_blocks\.(\d+)\.norm_added_k\.(.*)$": r"transformer_blocks.\1.attn.norm_added_k.\2",
+            # transformer_blocks nunchaku format (already converted with convert_flux_state_dict)
+            r"^transformer_blocks\.(\d+)\.attn\.add_qkv_proj\.(.*)$": r"transformer_blocks.\1.attn.to_added_qkv.\2",
+            # single_transformer_blocks nunchaku format (raw export - before internal conversion)
+            r"^single_transformer_blocks\.(\d+)\.qkv_proj\.(.*)$": r"single_transformer_blocks.\1.attn.to_qkv.\2",
+            r"^single_transformer_blocks\.(\d+)\.out_proj\.(.*)$": r"single_transformer_blocks.\1.attn.to_out.0.\2",
+            r"^single_transformer_blocks\.(\d+)\.norm_q\.(.*)$": r"single_transformer_blocks.\1.attn.norm_q.\2",
+            r"^single_transformer_blocks\.(\d+)\.norm_k\.(.*)$": r"single_transformer_blocks.\1.attn.norm_k.\2",
+            # nunchaku quantization parameter name conversions (apply to all blocks)
+            r"^(.*)\.smooth_orig$": r"\1.smooth_factor_orig",
+            r"^(.*)\.smooth$": r"\1.smooth_factor",
+            r"^(.*)\.lora_down$": r"\1.proj_down",
+            r"^(.*)\.lora_up$": r"\1.proj_up",
         }
     )
 

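The mapping is applied with plain `re` substitutions at load time. A minimal sketch of how one of the rules above rewrites a key; the regex pair is copied from the diff, while the concrete checkpoint key is an illustrative example:

```python
import re

# Rule copied from the mapping above: nunchaku's fused QKV projection is
# renamed into the sglang flux attention layout. The key below is an
# illustrative example, not taken from a real checkpoint.
pattern = r"^transformer_blocks\.(\d+)\.qkv_proj\.(.*)$"
replacement = r"transformer_blocks.\1.attn.to_qkv.\2"

key = "transformer_blocks.0.qkv_proj.weight"
print(re.sub(pattern, replacement, key))
# transformer_blocks.0.attn.to_qkv.weight
```
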
python/sglang/multimodal_gen/runtime/layers/quantization/configs/nunchaku_config.py

Lines changed: 26 additions & 33 deletions
@@ -16,27 +16,6 @@
 
 logger = init_logger(__name__)
 
-SVDQ_W4A4_LAYER_PATTERNS = [
-    "attn.to_qkv",
-    "attn.to_out",
-    "attn.add_qkv_proj",
-    "attn.to_add_out",
-    "img_mlp",
-    "txt_mlp",
-]
-
-AWQ_W4A16_LAYER_PATTERNS = [
-    "img_mod",
-    "txt_mod",
-]
-
-SKIP_QUANTIZATION_PATTERNS = [
-    "norm",
-    "embed",
-    "rotary",
-    "pos_embed",
-]
-
 
 @lru_cache(maxsize=1)
 def is_nunchaku_available() -> bool:
@@ -61,13 +40,15 @@ class NunchakuConfig(QuantizationConfig):
         group_size: Quantization group size (automatically set based on precision)
         act_unsigned: Use unsigned activation quantization
         quantized_model_path: Path to pre-quantized model weights (.safetensors)
+        model_cls: DiT model class that provides quantization rules via get_nunchaku_quant_rules()
     """
 
-    precision: str = "int4"  # "int4" or "nvfp4"
+    precision: str = "int4"
     rank: int = 32
     group_size: Optional[int] = None
    act_unsigned: bool = False
     quantized_model_path: Optional[str] = None
+    model_cls: Optional[type] = None
 
     @classmethod
     def get_name(cls) -> str:
@@ -99,15 +80,27 @@ def from_config(cls, config: dict[str, Any]) -> "NunchakuConfig":
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional[QuantizeMethodBase]:
-
         if not isinstance(layer, LinearBase):
             return None
 
-        for pattern in SKIP_QUANTIZATION_PATTERNS:
+        # get quantization rules from model class
+        quant_rules = self._get_quant_rules()
+
+        # priority: skip > awq_w4a16 > svdq_w4a4 > default
+        skip_patterns = quant_rules.get("skip", [])
+        for pattern in skip_patterns:
             if pattern in prefix.lower():
                 return None
 
-        for pattern in SVDQ_W4A4_LAYER_PATTERNS:
+        awq_patterns = quant_rules.get("awq_w4a16", [])
+        for pattern in awq_patterns:
+            if pattern in prefix:
+                from ..nunchaku_linear import NunchakuAWQLinearMethod
+
+                return NunchakuAWQLinearMethod(group_size=64)
+
+        svdq_patterns = quant_rules.get("svdq_w4a4", [])
+        for pattern in svdq_patterns:
             if pattern in prefix:
                 from ..nunchaku_linear import NunchakuSVDQLinearMethod
 
@@ -117,14 +110,7 @@ def get_quant_method(
                     act_unsigned=self.act_unsigned,
                 )
 
-        for pattern in AWQ_W4A16_LAYER_PATTERNS:
-            if pattern in prefix:
-                from ..nunchaku_linear import NunchakuAWQLinearMethod
-
-                return NunchakuAWQLinearMethod(
-                    group_size=64,
-                )
-
+        # default: apply svdq_w4a4 to all remaining linear layers
         from ..nunchaku_linear import NunchakuSVDQLinearMethod
 
         return NunchakuSVDQLinearMethod(
@@ -133,6 +119,13 @@ def get_quant_method(
             act_unsigned=self.act_unsigned,
         )
 
+    def _get_quant_rules(self) -> dict[str, list[str]]:
+        if self.model_cls is not None and hasattr(
+            self.model_cls, "get_nunchaku_quant_rules"
+        ):
+            return self.model_cls.get_nunchaku_quant_rules()
+        return {}
+
     def __post_init__(self):
         if self.group_size is None:
             if self.precision == "nvfp4":

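With the hard-coded pattern lists gone, `get_quant_method` now resolves a method per layer from model-provided rules in the order skip > awq_w4a16 > svdq_w4a4 > default. A standalone sketch of that dispatch; the rules dict mirrors the constants removed above, the layer prefixes are hypothetical examples, and the real `NunchakuAWQLinearMethod`/`NunchakuSVDQLinearMethod` classes are not constructed here:

```python
# Sketch of the dispatch order used in get_quant_method:
# skip > awq_w4a16 > svdq_w4a4 > default (SVDQ W4A4 on remaining linears).
def pick_method(prefix: str, rules: dict[str, list[str]]) -> str | None:
    if any(p in prefix.lower() for p in rules.get("skip", [])):
        return None  # leave the layer unquantized
    if any(p in prefix for p in rules.get("awq_w4a16", [])):
        return "awq_w4a16"
    if any(p in prefix for p in rules.get("svdq_w4a4", [])):
        return "svdq_w4a4"
    return "svdq_w4a4"  # default for all remaining linear layers


rules = {
    "skip": ["norm", "embed", "rotary", "pos_embed"],
    "awq_w4a16": ["img_mod", "txt_mod"],
    "svdq_w4a4": ["attn.to_qkv", "attn.to_out", "img_mlp", "txt_mlp"],
}
print(pick_method("transformer_blocks.0.attn.to_qkv", rules))   # svdq_w4a4
print(pick_method("transformer_blocks.0.img_mod.lin", rules))   # awq_w4a16
print(pick_method("time_text_embed.timestep_embedder", rules))  # None
```
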
python/sglang/multimodal_gen/runtime/loader/component_loaders/transformer_loader.py

Lines changed: 3 additions & 2 deletions
@@ -94,8 +94,9 @@ def load_customized(
         model_cls, _ = ModelRegistry.resolve_model_cls(cls_name)
 
         nunchaku_config = server_args.nunchaku_config
-
         if nunchaku_config is not None:
+            nunchaku_config.model_cls = model_cls
+
             # respect dtype from checkpoint
             # TODO: improve the condition
             param_dtype = None
@@ -158,7 +159,7 @@ def load_customized(
         logger.info("Loaded model with %.2fB parameters", total_params / 1e9)
 
         # considering the existent of mixed-precision models (e.g., nunchaku)
-        if next(model.parameters()).dtype != param_dtype:
+        if next(model.parameters()).dtype != param_dtype and param_dtype:
             logger.warning(
                 f"Model dtype does not match expected param dtype, {next(model.parameters()).dtype} vs {param_dtype}"
             )

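Two behavioural notes on this hunk: attaching `model_cls` lets `NunchakuConfig._get_quant_rules()` query the model's `get_nunchaku_quant_rules()`, and the added `and param_dtype` keeps the dtype warning quiet when no expected dtype was set (the nunchaku branch above leaves `param_dtype = None`). A quick sketch of the second point, with an assumed model dtype:

```python
import torch

# With mixed-precision (nunchaku) checkpoints param_dtype can stay None, and
# "dtype != None" is always True, so the old condition warned on every load.
# The truthiness check suppresses that while still catching real mismatches.
model_dtype = torch.bfloat16  # assumed example dtype
for param_dtype in (None, torch.float16):
    old_condition = model_dtype != param_dtype
    new_condition = bool(model_dtype != param_dtype and param_dtype)
    print(param_dtype, old_condition, new_condition)
# None          -> old: True (spurious warning), new: False
# torch.float16 -> old: True,                    new: True
```
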
python/sglang/multimodal_gen/runtime/loader/fsdp_load.py

Lines changed: 3 additions & 2 deletions
@@ -238,7 +238,8 @@ def load_model_from_full_model_state_dict(
     """
     meta_sd = model.state_dict()
     param_dict = dict(model.named_parameters())
-    sharded_sd = {}
+
+    # map names from checkpoint to customized names
     custom_param_sd, reverse_param_names_mapping = hf_to_custom_state_dict(
         full_sd_iterator, param_names_mapping
     )  # type: ignore
@@ -250,7 +251,7 @@
     # sort parameter names to ensure all ranks process parameters in the same order
     sorted_param_names = sorted(custom_param_sd.keys())
 
-    requires_grad = False
+    sharded_sd = {}
 
     # shard from loaded state_dict, custom_param_sd -> sharded_sd
     for target_param_name in sorted_param_names:

python/sglang/multimodal_gen/runtime/loader/utils.py

Lines changed: 42 additions & 33 deletions
@@ -31,7 +31,7 @@ def set_default_torch_dtype(dtype: torch.dtype):
 
 
 def get_param_names_mapping(
-    mapping_dict: dict[str, str],
+    mapping_dict: dict[str, str | tuple[str, int, int]],
 ) -> Callable[[str], tuple[str, Any, Any]]:
     """
     Creates a mapping function that transforms parameter names using regex patterns.
@@ -44,21 +44,50 @@
     """
 
     def mapping_fn(name: str) -> tuple[str, Any, Any]:
-        # Try to match and transform the name using the regex patterns in mapping_dict
-        for pattern, replacement in mapping_dict.items():
-            match = re.match(pattern, name)
-            if match:
-                merge_index = None
-                total_split_params = None
+        # support chained conversions, e.g.:
+        # transformer.xxx.lora_down -> xxx.lora_down -> xxx.proj_down
+        merge_index = None
+        total_split_params = None
+        max_steps = max(8, len(mapping_dict) * 2)
+        applied_patterns: set[str] = set()
+        visited_names: set[str] = {name}
+
+        for _ in range(max_steps):
+            transformed = False
+            for pattern, replacement in mapping_dict.items():
+                # avoid re-applying the same rule on its own output
+                if pattern in applied_patterns:
+                    continue
+                if re.match(pattern, name) is None:
+                    continue
+
+                curr_merge_index = None
+                curr_total_split_params = None
                 if isinstance(replacement, tuple):
-                    merge_index = replacement[1]
-                    total_split_params = replacement[2]
+                    curr_merge_index = replacement[1]
+                    curr_total_split_params = replacement[2]
                     replacement = replacement[0]
-                name = re.sub(pattern, replacement, name)
-                return name, merge_index, total_split_params
 
-        # If no pattern matches, return the original name
-        return name, None, None
+                new_name = re.sub(pattern, replacement, name)
+
+                if new_name != name:
+                    if curr_merge_index is not None:
+                        merge_index = curr_merge_index
+                        total_split_params = curr_total_split_params
+
+                    name = new_name
+                    applied_patterns.add(pattern)
+                    if name in visited_names:
+                        transformed = False
+                        break
+                    visited_names.add(name)
+                    transformed = True
+                    break
+
+            if not transformed:
+                break
+
+        return name, merge_index, total_split_params
 
     return mapping_fn
 
@@ -150,25 +179,5 @@ def _list_safetensors_files(model_path: str) -> list[str]:
 
 BYTES_PER_GB = 1024**3
 
-
-def get_memory_usage_of_component(module) -> float | None:
-    """
-    returned value is in GB, rounded to 2 decimal digits
-    """
-    if not isinstance(module, nn.Module):
-        return None
-    if hasattr(module, "get_memory_footprint"):
-        usage = module.get_memory_footprint() / BYTES_PER_GB
-    else:
-        # manually
-        param_size = sum(p.numel() * p.element_size() for p in module.parameters())
-        buffer_size = sum(b.numel() * b.element_size() for b in module.buffers())
-
-        total_size_bytes = param_size + buffer_size
-        usage = total_size_bytes / (1024**3)
-
-    return round(usage, 2)
-
-
 # component name -> ComponentLoader class
 component_name_to_loader_cls: Dict[str, Type[Any]] = {}

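The rewritten `mapping_fn` keeps applying rules until no rule changes the name, bounded by `max_steps`, with each rule applied at most once and cycle detection via `visited_names`. A self-contained sketch of that chaining, using the example from the diff comment and two rules from the flux mapping above (merge-index handling is omitted):

```python
import re

# Two rules from the flux mapping above, applied in a chain the way the new
# mapping_fn does: strip the "transformer." prefix, then rename lora_down.
rules = {
    r"^transformer\.(\w*)\.(.*)$": r"\1.\2",
    r"^(.*)\.lora_down$": r"\1.proj_down",
}

name = "transformer.xxx.lora_down"
applied: set[str] = set()
while True:
    for pattern, repl in rules.items():
        if pattern not in applied and re.match(pattern, name):
            name = re.sub(pattern, repl, name)
            applied.add(pattern)
            break
    else:  # no rule fired; the name has reached a fixed point
        break
print(name)
# transformer.xxx.lora_down -> xxx.lora_down -> xxx.proj_down
```
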
python/sglang/multimodal_gen/runtime/models/dits/base.py

Lines changed: 14 additions & 0 deletions
@@ -107,3 +107,17 @@ class CachableDiT(TeaCacheMixin, BaseDiT):
     def __init__(self, config: DiTConfig, **kwargs) -> None:
         super().__init__(config, **kwargs)
         self._init_teacache_state()
+
+    @classmethod
+    def get_nunchaku_quant_rules(cls) -> dict[str, dict[str, Any]]:
+        """
+        Get quantization rules for Nunchaku quantization.
+
+        Returns a dict mapping layer name patterns to quantization configs:
+            {
+                "skip": [list of patterns to skip quantization],
+                "svdq_w4a4": [list of patterns for SVDQ W4A4],
+                "awq_w4a16": [list of patterns for AWQ W4A16],
+            }
+        """
+        return {}

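A DiT model opts in by overriding this hook. A hypothetical subclass for illustration, assuming the module path matches the file path shown above and returning rules equivalent to the pattern lists removed from nunchaku_config.py (the actual flux/Z-Image rules in this commit may differ):

```python
from sglang.multimodal_gen.runtime.models.dits.base import CachableDiT


class MyNunchakuDiT(CachableDiT):  # hypothetical subclass for illustration
    @classmethod
    def get_nunchaku_quant_rules(cls) -> dict[str, list[str]]:
        # Mirrors the pattern lists previously hard-coded in
        # nunchaku_config.py; a real model may return different rules.
        return {
            "skip": ["norm", "embed", "rotary", "pos_embed"],
            "awq_w4a16": ["img_mod", "txt_mod"],
            "svdq_w4a4": [
                "attn.to_qkv",
                "attn.to_out",
                "attn.add_qkv_proj",
                "attn.to_add_out",
                "img_mlp",
                "txt_mlp",
            ],
        }
```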