55without restarting the server. It is the diffusion-engine counterpart of the
66LLM engine's ModelRunner.update_weights_from_disk.
77
8- Typical usage (from GPUWorker):
8+ Typical usage (from GPUWorker.update_weights_from_disk ):
99
1010 updater = WeightsUpdater(self.pipeline)
1111 success, message = updater.update_weights_from_disk(
1212 model_path,
13- original_model_path=self.server_args.model_path,
13+ flush_cache=flush_cache,
14+ target_modules=target_modules,
1415 )
16+ if success:
17+ self.server_args.model_path = model_path
18+ return success, message
1519
1620Key design decisions:
1721
1822- All-or-nothing: if any module fails to load, all previously updated
1923 modules are rolled back to the original weights by reloading from
20- original_model_path. No partial updates are left behind.
24+ pipeline.model_path. No partial updates are left behind.
2125
2226- Rollback failures propagate: if rollback itself fails, the exception is
2327 not caught so the caller knows the model is in an inconsistent state.
4145from __future__ import annotations
4246
4347import gc
44- import os
45- import time
4648
4749import torch
50+ from torch .distributed .tensor import DTensor , distribute_tensor
4851
4952from sglang .multimodal_gen .runtime .cache .teacache import TeaCacheMixin
50- from sglang .multimodal_gen .runtime .loader .utils import _list_safetensors_files
53+ from sglang .multimodal_gen .runtime .loader .utils import (
54+ _list_safetensors_files ,
55+ find_weights_dir ,
56+ )
5157from sglang .multimodal_gen .runtime .loader .weight_utils import (
5258 safetensors_weights_iterator ,
5359)
5662from sglang .multimodal_gen .runtime .utils .layerwise_offload import OffloadableDiTMixin
5763from sglang .multimodal_gen .runtime .utils .logging_utils import init_logger
5864
59- try :
60- from torch .distributed .tensor import DTensor , distribute_tensor
61- except ImportError :
62- DTensor = None
63- distribute_tensor = None
64-
6565logger = init_logger (__name__ )
6666
6767
@@ -87,7 +87,8 @@ class WeightsUpdater:
8787
8888 Args:
8989 pipeline: A ComposedPipelineBase (or DiffusersPipeline) instance
90- whose modules will be updated.
90+ whose modules will be updated. The pipeline's model_path
91+ attribute is used for rollback on failure.
9192 """
9293
9394 def __init__ (self , pipeline ):
@@ -96,26 +97,21 @@ def __init__(self, pipeline):
9697 def update_weights_from_disk (
9798 self ,
9899 model_path : str ,
99- original_model_path : str ,
100100 flush_cache : bool = True ,
101101 target_modules : list [str ] | None = None ,
102102 ) -> tuple [bool , str ]:
103103 """Update model weights from disk without restarting the server.
104104
105105 Args:
106106 model_path: HF repo id or local path to the new weights.
107- original_model_path: Path to the currently loaded weights (used
108- for rollback on failure).
109107 flush_cache: If True, reset TeaCache state after a successful
110108 update so that stale cached residuals are not reused.
111109 target_modules: Explicit list of module names to update. None
112- or ["all"] updates every nn.Module in the pipeline.
110+ updates every nn.Module in the pipeline.
113111
114112 Returns:
115- (success, message) tuple.
113+ (success, message) tuple where success is True on success .
116114 """
117- tic = time .perf_counter ()
118- self ._original_model_path = original_model_path
119115 logger .info (f"Updating weights from disk: { model_path } " )
120116
121117 try :
@@ -161,10 +157,9 @@ def update_weights_from_disk(
161157
162158 if success and flush_cache :
163159 for _ , module in modules_to_update :
164- _reset_cache_state (module )
160+ if isinstance (module , TeaCacheMixin ):
161+ module .reset_teacache_state ()
165162
166- elapsed = time .perf_counter () - tic
167- message = f"{ message } elapsed={ elapsed :.2f} s"
168163 logger .info (message )
169164 return success , message
170165
@@ -182,7 +177,7 @@ def _collect_modules(
182177 """
183178 components = get_updatable_modules (self .pipeline )
184179
185- if target_modules is None or target_modules == [ "all" ] :
180+ if target_modules is None :
186181 names = list (components .keys ())
187182 else :
188183 unknown = [n for n in target_modules if n not in components ]
@@ -232,7 +227,7 @@ def _rollback(self, updated_modules: list[str]) -> None:
232227 """
233228 if not updated_modules :
234229 return
235- original_path = maybe_download_model (self ._original_model_path )
230+ original_path = maybe_download_model (self .pipeline . model_path )
236231 for name in updated_modules :
237232 module = self .pipeline .get_module (name )
238233 if module is None :
@@ -249,23 +244,6 @@ def _rollback(self, updated_modules: list[str]) -> None:
249244# ---------------------------------------------------------------------------
250245
251246
252- def find_weights_dir (local_path : str , module_name : str ) -> str | None :
253- """Locate the safetensors directory for module_name under local_path.
254-
255- Diffusion models store weights in per-module subdirectories (e.g.
256- transformer/, vae/, text_encoder/). This function tries
257- <local_path>/<module_name>/ first, then falls back to local_path
258- itself if it directly contains safetensors files (common for RL
259- checkpoints that save weights in a flat directory).
260- """
261- dir_path = os .path .join (local_path , module_name )
262- if os .path .exists (dir_path ):
263- return dir_path
264- if _list_safetensors_files (local_path ):
265- return local_path
266- return None
267-
268-
269247def _get_weights_iter (weights_dir : str ):
270248 """Return a (name, tensor) iterator over safetensors in weights_dir."""
271249 safetensors_files = _list_safetensors_files (weights_dir )
@@ -295,25 +273,21 @@ def _validate_weight_files(
295273 return weights_map , missing
296274
297275
298- def _get_offload_managers (module : torch .nn .Module ) -> list :
299- """Return active offload managers for the given module, if any."""
300- if isinstance (module , OffloadableDiTMixin ) and module .layerwise_offload_managers :
301- return [m for m in module .layerwise_offload_managers if m .enabled ]
302- return []
303-
304-
305276def _load_weights_into_module (module : torch .nn .Module , weights_iter ) -> None :
306277 """Load weights into a module, handling offload-managed parameters.
307278
308279 For offloaded modules, updates CPU buffers directly via
309280 update_cpu_weights(); non-offloaded parameters use in-place copy.
310281 """
311- offload_managers = _get_offload_managers (module )
282+ offload_managers : list = []
283+ if isinstance (module , OffloadableDiTMixin ) and module .layerwise_offload_managers :
284+ offload_managers = [m for m in module .layerwise_offload_managers if m .enabled ]
285+
312286 if offload_managers :
313287 weight_dict = dict (weights_iter )
314288 offloaded_names : set [str ] = set ()
315289 for manager in offload_managers :
316- offloaded_names |= manager .update_cpu_weights (weight_dict )
290+ offloaded_names . update ( manager .update_cpu_weights (weight_dict ) )
317291 remaining = ((n , w ) for n , w in weight_dict .items () if n not in offloaded_names )
318292 load_weights_into_model (remaining , dict (module .named_parameters ()))
319293 else :
@@ -330,7 +304,7 @@ def load_weights_into_model(weights_iter, model_params: dict) -> None:
330304 raise ValueError (
331305 f"Shape mismatch for { name } : model={ param .shape } , loaded={ loaded_weight .shape } "
332306 )
333- if DTensor is not None and isinstance (param , DTensor ):
307+ if isinstance (param , DTensor ):
334308 distributed_weight = distribute_tensor (
335309 loaded_weight .to (param .dtype ),
336310 param .device_mesh ,
@@ -339,13 +313,3 @@ def load_weights_into_model(weights_iter, model_params: dict) -> None:
339313 param ._local_tensor .copy_ (distributed_weight ._local_tensor )
340314 else :
341315 param .data .copy_ (loaded_weight .to (param .dtype ))
342-
343-
344- def _reset_cache_state (module : torch .nn .Module ) -> None :
345- """Reset Cache state after weight updates.
346-
347- After weights change, any cached residuals from previous denoising steps
348- are invalid and must be cleared.
349- """
350- if isinstance (module , TeaCacheMixin ):
351- module .reset_teacache_state ()
0 commit comments