Skip to content

Commit efcdda0

Browse files
authored
[diffusion] fix: fix fsdp (sgl-project#18187)
1 parent 49cbb46 commit efcdda0

File tree

13 files changed

+102
-8
lines changed

13 files changed

+102
-8
lines changed

python/sglang/multimodal_gen/configs/models/dits/zimage.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,17 @@
77
from sglang.multimodal_gen.configs.models.dits.base import DiTArchConfig, DiTConfig
88

99

10+
def is_zimage_layer(n: str, m) -> bool:
    """Return True if the named module should be FSDP-sharded for Z-Image.

    A module qualifies when its dotted name ends in a numeric index (i.e. it
    is one entry of a ``nn.ModuleList``) and it belongs to the transformer
    ``layers`` stack or to one of the ``noise_refiner`` / ``context_refiner``
    stacks.

    Args:
        n: Dotted module path, e.g. ``"layers.3"`` or ``"noise_refiner.0"``.
        m: The module object itself. Unused; present to satisfy the
           ``(name, module) -> bool`` shard-condition callback signature.

    Returns:
        True when the module is a numbered block of a shardable stack.
    """
    # Both original branches required a trailing numeric index, so check it
    # once up front instead of twice (also: call .isdigit() on the string
    # rather than the unidiomatic unbound str.isdigit(...)).
    if not n.split(".")[-1].isdigit():
        return False
    return "layers" in n or "noise_refiner" in n or "context_refiner" in n
19+
20+
1021
@dataclass
1122
class ZImageArchConfig(DiTArchConfig):
1223
all_patch_size: Tuple[int, ...] = (2,)
@@ -26,6 +37,8 @@ class ZImageArchConfig(DiTArchConfig):
2637
axes_dims: Tuple[int, int, int] = (32, 48, 48)
2738
axes_lens: Tuple[int, int, int] = (1024, 512, 512)
2839

40+
_fsdp_shard_conditions: list = field(default_factory=lambda: [is_zimage_layer])
41+
2942
stacked_params_mapping: list[tuple[str, str, str]] = field(
3043
default_factory=lambda: [
3144
# (param_name, shard_name, shard_id)

python/sglang/multimodal_gen/runtime/layers/layernorm.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,10 @@ def forward_cuda(
8181
if x.dtype == torch.float:
8282
# fp32
8383
out = self.forward_triton(x, residual)
84+
if residual is not None:
85+
return out[0].view(shape), out[1].view(residual_shape)
86+
out = out.view(shape)
87+
return out
8488
elif self.variance_size_override is not None:
8589
return self.forward_native(x, residual)
8690
elif residual is not None:
@@ -94,6 +98,7 @@ def forward_cuda(
9498
else:
9599
out = rmsnorm(x, self.weight.data, self.variance_epsilon)
96100
out = out.view(shape)
101+
97102
return out
98103

99104
def forward_native(

python/sglang/multimodal_gen/runtime/layers/lora/linear.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,7 @@ def __init__(
342342
super().__init__(base_layer, lora_rank, lora_alpha)
343343

344344
def slice_lora_a_weights(self, A: torch.Tensor) -> torch.Tensor:
345-
return A.to(self.base_layer.weight)
345+
return A
346346

347347
def slice_lora_b_weights(self, B: torch.Tensor) -> torch.Tensor:
348348
tp_rank = get_tp_rank()

python/sglang/multimodal_gen/runtime/layers/triton_ops.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -948,6 +948,9 @@ def forward(
948948
)
949949
)
950950
y = y.reshape(x_shape_og)
951+
if residual is not None:
952+
residual_out = residual_out.reshape(x_shape_og)
953+
return y, residual_out
951954
return y
952955

953956

python/sglang/multimodal_gen/runtime/loader/component_loaders/text_encoder_loader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,7 @@ def load_model(
279279
# if loaded_weights is not None:
280280
weights_not_loaded = weights_to_load - loaded_weights
281281
if weights_not_loaded:
282-
raise ValueError(
282+
logger.warning(
283283
"Following model weights were not initialized from "
284284
f"checkpoint: {weights_not_loaded}"
285285
)

python/sglang/multimodal_gen/runtime/loader/fsdp_load.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -231,10 +231,20 @@ def load_model_from_full_model_state_dict(
231231
custom_param_sd, reverse_param_names_mapping = hf_to_custom_state_dict(
232232
full_sd_iterator, param_names_mapping
233233
) # type: ignore
234-
for target_param_name, full_tensor in custom_param_sd.items():
234+
235+
is_fsdp_model = isinstance(model, FSDPModule) or any(
236+
hasattr(p, "device_mesh") for p in meta_sd.values()
237+
)
238+
239+
# sort parameter names to ensure all ranks process parameters in the same order
240+
sorted_param_names = sorted(custom_param_sd.keys())
241+
242+
for target_param_name in sorted_param_names:
243+
full_tensor = custom_param_sd[target_param_name]
235244
meta_sharded_param = meta_sd.get(target_param_name)
236245
if meta_sharded_param is None:
237-
if strict:
246+
# For FSDP models, ensure all ranks process parameters consistently
247+
if strict or is_fsdp_model:
238248
raise ValueError(
239249
f"Parameter {target_param_name} not found in custom model state dict. The hf to custom mapping may be incorrect."
240250
)
@@ -261,6 +271,9 @@ def load_model_from_full_model_state_dict(
261271
sharded_tensor = temp_param.data
262272
else:
263273
sharded_tensor = full_tensor
274+
275+
if cpu_offload:
276+
sharded_tensor = sharded_tensor.cpu()
264277
else:
265278
full_tensor = full_tensor.to(device=device, dtype=param_dtype)
266279
sharded_tensor = distribute_tensor(
@@ -296,6 +309,8 @@ def load_model_from_full_model_state_dict(
296309
sharded_tensor = torch.zeros_like(
297310
meta_sharded_param, device=device, dtype=param_dtype
298311
)
312+
if cpu_offload:
313+
sharded_tensor = sharded_tensor.cpu()
299314
else:
300315
# Initialize with zeros and distribute
301316
full_tensor = torch.zeros_like(

python/sglang/multimodal_gen/runtime/managers/gpu_worker.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -349,7 +349,8 @@ def list_loras(self) -> OutputBatch:
349349
- If the OOM occurs during runtime:
350350
1. Reduce the number of output tokens by lowering resolution or decreasing `--num-frames`
351351
2. Enable SP and/or TP
352-
3. Enable a sparse-attention backend
352+
3. Opt for a sparse-attention backend
353+
4. Enable FSDP by `--use-fsdp-inference` (in a multi-GPU setup)
353354
Or, open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose
354355
"""
355356

@@ -402,7 +403,7 @@ def run_scheduler_process(
402403
)
403404
scheduler.event_loop()
404405
except torch.OutOfMemoryError as _e:
405-
print(OOM_MSG)
406+
logger.warning(OOM_MSG)
406407
raise
407408
finally:
408409
# Clean up resources to speed up shutdown

python/sglang/multimodal_gen/runtime/models/dits/zimage.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -381,6 +381,7 @@ def __call__(self, ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
381381
class ZImageTransformer2DModel(CachableDiT, OffloadableDiTMixin):
382382
_supports_gradient_checkpointing = True
383383
_no_split_modules = ["ZImageTransformerBlock"]
384+
_fsdp_shard_conditions = ZImageDitConfig().arch_config._fsdp_shard_conditions
384385
param_names_mapping = ZImageDitConfig().arch_config.param_names_mapping
385386

386387
param_names_mapping = ZImageDitConfig().arch_config.param_names_mapping

python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -846,6 +846,10 @@ def _manage_device_placement(
846846
if not server_args.dit_cpu_offload:
847847
return
848848

849+
# FSDP manages offloading internally
850+
if server_args.use_fsdp_inference:
851+
return
852+
849853
# Offload the unused model if it's on CUDA
850854
if (
851855
model_to_offload is not None

python/sglang/multimodal_gen/test/scripts/gen_perf_baselines.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,9 @@ def _build_server_extra_args(case: DiffusionTestCase) -> str:
6767
a += f" --lora-path {server_args.lora_path}"
6868
if server_args.warmup:
6969
a += " --warmup"
70+
71+
for extra_arg in server_args.extras:
72+
a += f" {extra_arg}"
7073
return a
7174

7275

0 commit comments

Comments
 (0)