Skip to content

Commit 666fb8b

Browse files
committed
[diffusion] address comments
1 parent c465ec0 commit 666fb8b

File tree

8 files changed

+57
-70
lines changed

8 files changed

+57
-70
lines changed

docs/advanced_features/sglang_for_rl.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,8 @@ This path trades some I/O overhead for simplicity and flexibility. It integrates
127127
| `success` | Whether the update succeeded. | - | Type: bool |
128128
| `message` | Status / error message. | - | Type: str |
129129

130+
> **Note:** The diffusion engine (SGLang-Diffusion) does not currently support hot refit (updating weights while inference is in progress). The diffusion scheduler processes one request at a time and completes the entire inference before handling the next request, so weight updates and inference never run concurrently.
131+
130132
### Update Weights from Tensor
131133

132134
**When to use:**

python/sglang/multimodal_gen/runtime/entrypoints/post_training/utils.py renamed to python/sglang/multimodal_gen/runtime/entrypoints/post_training/io_struct.py

File renamed without changes.

python/sglang/multimodal_gen/runtime/entrypoints/post_training/weights_api.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
"""Post-training APIs: weight updates and related operations."""
1+
"""Weight update API for the diffusion engine."""
22

33
from fastapi import APIRouter, Request
44
from fastapi.responses import ORJSONResponse
55

6-
from sglang.multimodal_gen.runtime.entrypoints.post_training.utils import (
6+
from sglang.multimodal_gen.runtime.entrypoints.post_training.io_struct import (
77
UpdateWeightFromDiskReqInput,
88
)
99
from sglang.multimodal_gen.runtime.scheduler_client import async_scheduler_client

python/sglang/multimodal_gen/runtime/loader/utils.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,18 @@ def _list_safetensors_files(model_path: str) -> list[str]:
145145
return sorted(glob.glob(os.path.join(str(model_path), "*.safetensors")))
146146

147147

148+
def find_weights_dir(local_path: str, module_name: str) -> str | None:
149+
"""Locate the safetensors directory for module_name under local_path.
150+
151+
Diffusion models store weights in per-module subdirectories (e.g.
152+
transformer/, vae/, text_encoder/).
153+
"""
154+
dir_path = os.path.join(local_path, module_name)
155+
if os.path.exists(dir_path):
156+
return dir_path
157+
return None
158+
159+
148160
def get_memory_usage_of_component(module) -> float | None:
149161
"""
150162
returned value is in GB, rounded to 2 decimal digits

python/sglang/multimodal_gen/runtime/loader/weights_updater.py

Lines changed: 26 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,23 @@
55
without restarting the server. It is the diffusion-engine counterpart of the
66
LLM engine's ModelRunner.update_weights_from_disk.
77
8-
Typical usage (from GPUWorker):
8+
Typical usage (from GPUWorker.update_weights_from_disk):
99
1010
updater = WeightsUpdater(self.pipeline)
1111
success, message = updater.update_weights_from_disk(
1212
model_path,
13-
original_model_path=self.server_args.model_path,
13+
flush_cache=flush_cache,
14+
target_modules=target_modules,
1415
)
16+
if success:
17+
self.server_args.model_path = model_path
18+
return success, message
1519
1620
Key design decisions:
1721
1822
- All-or-nothing: if any module fails to load, all previously updated
1923
modules are rolled back to the original weights by reloading from
20-
original_model_path. No partial updates are left behind.
24+
pipeline.model_path. No partial updates are left behind.
2125
2226
- Rollback failures propagate: if rollback itself fails, the exception is
2327
not caught so the caller knows the model is in an inconsistent state.
@@ -41,13 +45,15 @@
4145
from __future__ import annotations
4246

4347
import gc
44-
import os
45-
import time
4648

4749
import torch
50+
from torch.distributed.tensor import DTensor, distribute_tensor
4851

4952
from sglang.multimodal_gen.runtime.cache.teacache import TeaCacheMixin
50-
from sglang.multimodal_gen.runtime.loader.utils import _list_safetensors_files
53+
from sglang.multimodal_gen.runtime.loader.utils import (
54+
_list_safetensors_files,
55+
find_weights_dir,
56+
)
5157
from sglang.multimodal_gen.runtime.loader.weight_utils import (
5258
safetensors_weights_iterator,
5359
)
@@ -56,12 +62,6 @@
5662
from sglang.multimodal_gen.runtime.utils.layerwise_offload import OffloadableDiTMixin
5763
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
5864

59-
try:
60-
from torch.distributed.tensor import DTensor, distribute_tensor
61-
except ImportError:
62-
DTensor = None
63-
distribute_tensor = None
64-
6565
logger = init_logger(__name__)
6666

6767

@@ -87,7 +87,8 @@ class WeightsUpdater:
8787
8888
Args:
8989
pipeline: A ComposedPipelineBase (or DiffusersPipeline) instance
90-
whose modules will be updated.
90+
whose modules will be updated. The pipeline's model_path
91+
attribute is used for rollback on failure.
9192
"""
9293

9394
def __init__(self, pipeline):
@@ -96,26 +97,21 @@ def __init__(self, pipeline):
9697
def update_weights_from_disk(
9798
self,
9899
model_path: str,
99-
original_model_path: str,
100100
flush_cache: bool = True,
101101
target_modules: list[str] | None = None,
102102
) -> tuple[bool, str]:
103103
"""Update model weights from disk without restarting the server.
104104
105105
Args:
106106
model_path: HF repo id or local path to the new weights.
107-
original_model_path: Path to the currently loaded weights (used
108-
for rollback on failure).
109107
flush_cache: If True, reset TeaCache state after a successful
110108
update so that stale cached residuals are not reused.
111109
target_modules: Explicit list of module names to update. None
112-
or ["all"] updates every nn.Module in the pipeline.
110+
updates every nn.Module in the pipeline.
113111
114112
Returns:
115-
(success, message) tuple.
113+
(success, message) tuple where success is True on success.
116114
"""
117-
tic = time.perf_counter()
118-
self._original_model_path = original_model_path
119115
logger.info(f"Updating weights from disk: {model_path}")
120116

121117
try:
@@ -161,10 +157,9 @@ def update_weights_from_disk(
161157

162158
if success and flush_cache:
163159
for _, module in modules_to_update:
164-
_reset_cache_state(module)
160+
if isinstance(module, TeaCacheMixin):
161+
module.reset_teacache_state()
165162

166-
elapsed = time.perf_counter() - tic
167-
message = f"{message} elapsed={elapsed:.2f}s"
168163
logger.info(message)
169164
return success, message
170165

@@ -182,7 +177,7 @@ def _collect_modules(
182177
"""
183178
components = get_updatable_modules(self.pipeline)
184179

185-
if target_modules is None or target_modules == ["all"]:
180+
if target_modules is None:
186181
names = list(components.keys())
187182
else:
188183
unknown = [n for n in target_modules if n not in components]
@@ -232,7 +227,7 @@ def _rollback(self, updated_modules: list[str]) -> None:
232227
"""
233228
if not updated_modules:
234229
return
235-
original_path = maybe_download_model(self._original_model_path)
230+
original_path = maybe_download_model(self.pipeline.model_path)
236231
for name in updated_modules:
237232
module = self.pipeline.get_module(name)
238233
if module is None:
@@ -249,23 +244,6 @@ def _rollback(self, updated_modules: list[str]) -> None:
249244
# ---------------------------------------------------------------------------
250245

251246

252-
def find_weights_dir(local_path: str, module_name: str) -> str | None:
253-
"""Locate the safetensors directory for module_name under local_path.
254-
255-
Diffusion models store weights in per-module subdirectories (e.g.
256-
transformer/, vae/, text_encoder/). This function tries
257-
<local_path>/<module_name>/ first, then falls back to local_path
258-
itself if it directly contains safetensors files (common for RL
259-
checkpoints that save weights in a flat directory).
260-
"""
261-
dir_path = os.path.join(local_path, module_name)
262-
if os.path.exists(dir_path):
263-
return dir_path
264-
if _list_safetensors_files(local_path):
265-
return local_path
266-
return None
267-
268-
269247
def _get_weights_iter(weights_dir: str):
270248
"""Return a (name, tensor) iterator over safetensors in weights_dir."""
271249
safetensors_files = _list_safetensors_files(weights_dir)
@@ -295,25 +273,21 @@ def _validate_weight_files(
295273
return weights_map, missing
296274

297275

298-
def _get_offload_managers(module: torch.nn.Module) -> list:
299-
"""Return active offload managers for the given module, if any."""
300-
if isinstance(module, OffloadableDiTMixin) and module.layerwise_offload_managers:
301-
return [m for m in module.layerwise_offload_managers if m.enabled]
302-
return []
303-
304-
305276
def _load_weights_into_module(module: torch.nn.Module, weights_iter) -> None:
306277
"""Load weights into a module, handling offload-managed parameters.
307278
308279
For offloaded modules, updates CPU buffers directly via
309280
update_cpu_weights(); non-offloaded parameters use in-place copy.
310281
"""
311-
offload_managers = _get_offload_managers(module)
282+
offload_managers: list = []
283+
if isinstance(module, OffloadableDiTMixin) and module.layerwise_offload_managers:
284+
offload_managers = [m for m in module.layerwise_offload_managers if m.enabled]
285+
312286
if offload_managers:
313287
weight_dict = dict(weights_iter)
314288
offloaded_names: set[str] = set()
315289
for manager in offload_managers:
316-
offloaded_names |= manager.update_cpu_weights(weight_dict)
290+
offloaded_names.update(manager.update_cpu_weights(weight_dict))
317291
remaining = ((n, w) for n, w in weight_dict.items() if n not in offloaded_names)
318292
load_weights_into_model(remaining, dict(module.named_parameters()))
319293
else:
@@ -330,7 +304,7 @@ def load_weights_into_model(weights_iter, model_params: dict) -> None:
330304
raise ValueError(
331305
f"Shape mismatch for {name}: model={param.shape}, loaded={loaded_weight.shape}"
332306
)
333-
if DTensor is not None and isinstance(param, DTensor):
307+
if isinstance(param, DTensor):
334308
distributed_weight = distribute_tensor(
335309
loaded_weight.to(param.dtype),
336310
param.device_mesh,
@@ -339,13 +313,3 @@ def load_weights_into_model(weights_iter, model_params: dict) -> None:
339313
param._local_tensor.copy_(distributed_weight._local_tensor)
340314
else:
341315
param.data.copy_(loaded_weight.to(param.dtype))
342-
343-
344-
def _reset_cache_state(module: torch.nn.Module) -> None:
345-
"""Reset Cache state after weight updates.
346-
347-
After weights change, any cached residuals from previous denoising steps
348-
are invalid and must be cleared.
349-
"""
350-
if isinstance(module, TeaCacheMixin):
351-
module.reset_teacache_state()

python/sglang/multimodal_gen/runtime/managers/gpu_worker.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -355,7 +355,6 @@ def update_weights_from_disk(
355355
updater = WeightsUpdater(self.pipeline)
356356
success, message = updater.update_weights_from_disk(
357357
model_path,
358-
original_model_path=self.server_args.model_path,
359358
flush_cache=flush_cache,
360359
target_modules=target_modules,
361360
)

python/sglang/multimodal_gen/runtime/managers/scheduler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
_parse_size,
2121
save_image_to_path,
2222
)
23-
from sglang.multimodal_gen.runtime.entrypoints.post_training.utils import (
23+
from sglang.multimodal_gen.runtime.entrypoints.post_training.io_struct import (
2424
UpdateWeightFromDiskReqInput,
2525
)
2626
from sglang.multimodal_gen.runtime.managers.gpu_worker import GPUWorker

python/sglang/multimodal_gen/runtime/utils/layerwise_offload.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -277,11 +277,21 @@ def sync_all_layers_to_cpu(self) -> None:
277277
self.sync_layer_to_cpu(layer_idx)
278278

279279
@torch.compiler.disable
280-
def update_cpu_weights(self, weight_dict: Dict[str, torch.Tensor]) -> Set[str]:
280+
def update_cpu_weights(
281+
self, weight_dict: Dict[str, torch.Tensor]
282+
) -> Set[str] | None:
281283
"""Update consolidated CPU buffers with new weights.
282284
283-
For layers currently on GPU, the live GPU parameter is also updated
284-
so the change takes effect immediately.
285+
When layerwise offload (--dit-layerwise-offload) is enabled, the
286+
offload manager replaces GPU parameters with small torch.empty((1,))
287+
placeholders while real weights live in consolidated pinned CPU
288+
buffers. A naive param.data.copy_() would fail with a shape
289+
mismatch. Instead, this method writes new weights directly into
290+
the CPU buffers, bypassing the placeholders entirely. For any
291+
layer that happens to be resident on GPU at update time, the live
292+
GPU tensor is also updated so the change takes effect immediately.
293+
This requires no extra GPU memory and does not disturb the offload
294+
state.
285295
286296
Args:
287297
weight_dict: Mapping of parameter name to new weight tensor.
@@ -294,7 +304,7 @@ def update_cpu_weights(self, weight_dict: Dict[str, torch.Tensor]) -> Set[str]:
294304
metadata (i.e. the real shape, not the placeholder shape).
295305
"""
296306
if not self.enabled:
297-
return set()
307+
return None
298308

299309
updated_names: Set[str] = set()
300310
for name, loaded_weight in weight_dict.items():

0 commit comments

Comments (0)