-
Notifications
You must be signed in to change notification settings - Fork 4.7k
[Feature] Implement update_weights_from_disk for SGLang-D (Diffusion … #18306
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
zhaochenyang20
merged 37 commits into
sgl-project:main
from
dreamyang-liu:feat/diffusion-update-weights-from-disk
Feb 18, 2026
+1,307
−5
Merged
Changes from 3 commits
Commits
Show all changes
37 commits
Select commit
Hold shift + click to select a range
c93c103
[Feature] Implement update_weights_from_disk for SGLang-D (Diffusion …
dreamyang-liu 7703c69
[diffusion] refactor: extract WeightsUpdater for update_weights_from_…
dreamyang-liu 886a6bc
chore: isort lint
dreamyang-liu 3d19d61
[diffusion] offload-aware weight updates and cleanup
dreamyang-liu d0e2fec
[diffusion] refactor: move post-training API to dedicated package and…
dreamyang-liu 07cc679
[diffusion] refactor: extract get_updatable_modules and harden Weight…
dreamyang-liu c442af2
[diffusion] address comments
dreamyang-liu 881d8b3
adds doc string to diffusion rifit test
zhaochenyang20 dfc93f6
[diffusion] Add /get_weights_checksum endpoint for SHA-256 weight ver…
dreamyang-liu 68902f3
[diffusion] Add corrupted-weight rollback test for update_weights_fro…
dreamyang-liu 719d31f
[diffusion] Parametrize weight-update tests over FLUX and Qwen model …
dreamyang-liu 81d585b
[diffusion] Optimize weight-update tests
dreamyang-liu 85b646c
Merge branch 'feat/diffusion-update-weights-from-disk' of github.com:…
zhaochenyang20 32a743b
[TODO] model weights is updated only once
zhaochenyang20 c35eecd
Deduplicated tests; Should clean up
zhaochenyang20 41148c4
clean up codes with mixin; currently spanning 16mins; too long; shoul…
zhaochenyang20 14e69ec
Update docstring for GetWeightsChecksumReqInput
zhaochenyang20 b64b185
Refine docstring for weight checksum verification
zhaochenyang20 d68da5c
Simplify comments in layerwise offload method
zhaochenyang20 d3728cb
Improve documentation for iter_materialized_weights
zhaochenyang20 f831d94
fix lint
zhaochenyang20 bd23519
Refactor update_weights_from_disk tests
dreamyang-liu 22b32c4
Merge branch 'main' into feat/diffusion-update-weights-from-disk
zhaochenyang20 a9f6808
Merge branch 'feat/diffusion-update-weights-from-disk' of github.com:…
zhaochenyang20 a32aed9
Merge branch 'main' into feat/diffusion-update-weights-from-disk
zhaochenyang20 66e3c17
Merge branch 'feat/diffusion-update-weights-from-disk' of github.com:…
zhaochenyang20 52ba35b
new docs string
zhaochenyang20 c3d478e
remove one line function
zhaochenyang20 cde71fe
consolidate rollback tests
zhaochenyang20 40bba8d
finalize the test
zhaochenyang20 7467682
fix CI random choice
zhaochenyang20 3100b2f
fix paring issue
zhaochenyang20 fb87570
incline path finding
zhaochenyang20 fab939e
remove redundant comments
zhaochenyang20 9e030f0
Merge branch 'main' into feat/diffusion-update-weights-from-disk
zhaochenyang20 2b44290
Merge branch 'feat/diffusion-update-weights-from-disk' of github.com:…
zhaochenyang20 f60f638
fix isort
zhaochenyang20 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Some comments aren't visible on the classic Files Changed page.
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
zhaochenyang20 marked this conversation as resolved.
Show resolved
Hide resolved
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -87,6 +87,77 @@ def get_updatable_modules(pipeline) -> dict[str, torch.nn.Module]: | |
| return {n: m for n, m in raw.items() if isinstance(m, torch.nn.Module)} | ||
|
|
||
|
|
||
| def _get_weights_iter(weights_dir: str): | ||
| """Return a (name, tensor) iterator over safetensors in weights_dir.""" | ||
| safetensors_files = _list_safetensors_files(weights_dir) | ||
| if not safetensors_files: | ||
| raise FileNotFoundError(f"No safetensors files found in {weights_dir}") | ||
| return safetensors_weights_iterator(safetensors_files) | ||
|
|
||
|
|
||
| def _validate_weight_files( | ||
| local_model_path: str, | ||
| modules_to_update: list[tuple[str, torch.nn.Module]], | ||
| ) -> tuple[dict[str, str], list[str]]: | ||
| """Check that every module has a weights directory with safetensors files. | ||
|
|
||
| Returns: | ||
| (weights_map, missing) where weights_map maps module name to its | ||
| weights directory and missing lists modules without weight files. | ||
| """ | ||
| weights_map: dict[str, str] = {} | ||
| missing: list[str] = [] | ||
| for module_name, _ in modules_to_update: | ||
| weights_dir = find_weights_dir(local_model_path, module_name) | ||
| if weights_dir and _list_safetensors_files(weights_dir): | ||
| weights_map[module_name] = weights_dir | ||
| else: | ||
| missing.append(module_name) | ||
| return weights_map, missing | ||
|
|
||
|
|
||
| def _load_weights_into_module(module: torch.nn.Module, weights_iter) -> None: | ||
| """Load weights into a module, handling offload-managed parameters. | ||
|
|
||
| For offloaded modules, updates CPU buffers directly via | ||
| update_cpu_weights(); non-offloaded parameters use in-place copy. | ||
| """ | ||
| offload_managers: list = [] | ||
| if isinstance(module, OffloadableDiTMixin) and module.layerwise_offload_managers: | ||
| offload_managers = [m for m in module.layerwise_offload_managers if m.enabled] | ||
|
|
||
| if offload_managers: | ||
| weight_dict = dict(weights_iter) | ||
| offloaded_names: set[str] = set() | ||
| for manager in offload_managers: | ||
| offloaded_names.update(manager.update_cpu_weights(weight_dict)) | ||
| remaining = ((n, w) for n, w in weight_dict.items() if n not in offloaded_names) | ||
| load_weights_into_model(remaining, dict(module.named_parameters())) | ||
| else: | ||
| load_weights_into_model(weights_iter, dict(module.named_parameters())) | ||
|
|
||
|
|
||
| def load_weights_into_model(weights_iter, model_params: dict) -> None: | ||
| """Copy weights from weights_iter into model_params in-place.""" | ||
| for name, loaded_weight in weights_iter: | ||
| if name not in model_params: | ||
| continue | ||
| param = model_params[name] | ||
| if param.shape != loaded_weight.shape: | ||
| raise ValueError( | ||
| f"Shape mismatch for {name}: model={param.shape}, loaded={loaded_weight.shape}" | ||
| ) | ||
| if isinstance(param, DTensor): | ||
| distributed_weight = distribute_tensor( | ||
| loaded_weight.to(param.dtype), | ||
| param.device_mesh, | ||
| param.placements, | ||
| ) | ||
| param._local_tensor.copy_(distributed_weight._local_tensor) | ||
| else: | ||
| param.data.copy_(loaded_weight.to(param.dtype)) | ||
|
|
||
|
|
||
| class WeightsUpdater: | ||
| """In-place weight updates for diffusion pipeline modules. | ||
|
|
||
|
|
@@ -168,10 +239,6 @@ def update_weights_from_disk( | |
| logger.info(message) | ||
| return success, message | ||
|
|
||
| # ------------------------------------------------------------------ | ||
| # Private helpers | ||
| # ------------------------------------------------------------------ | ||
|
|
||
| def _collect_modules( | ||
| self, target_modules: list[str] | None | ||
| ) -> list[tuple[str, torch.nn.Module]]: | ||
|
|
@@ -244,79 +311,3 @@ def _rollback(self, updated_modules: list[str]) -> None: | |
| continue | ||
| weights_iter = _get_weights_iter(weights_dir) | ||
| _load_weights_into_module(module, weights_iter) | ||
|
Comment on lines
+276
to
+293
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As I said, I think we should have other methods to
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you clarify what you mean by "other methods to _rollback"? |
||
|
|
||
|
|
||
| # --------------------------------------------------------------------------- | ||
| # Module-level utility functions | ||
| # --------------------------------------------------------------------------- | ||
|
|
||
|
|
||
| def _get_weights_iter(weights_dir: str): | ||
| """Return a (name, tensor) iterator over safetensors in weights_dir.""" | ||
| safetensors_files = _list_safetensors_files(weights_dir) | ||
| if not safetensors_files: | ||
| raise FileNotFoundError(f"No safetensors files found in {weights_dir}") | ||
| return safetensors_weights_iterator(safetensors_files) | ||
|
|
||
|
|
||
| def _validate_weight_files( | ||
| local_model_path: str, | ||
| modules_to_update: list[tuple[str, torch.nn.Module]], | ||
| ) -> tuple[dict[str, str], list[str]]: | ||
| """Check that every module has a weights directory with safetensors files. | ||
|
|
||
| Returns: | ||
| (weights_map, missing) where weights_map maps module name to its | ||
| weights directory and missing lists modules without weight files. | ||
| """ | ||
| weights_map: dict[str, str] = {} | ||
| missing: list[str] = [] | ||
| for module_name, _ in modules_to_update: | ||
| weights_dir = find_weights_dir(local_model_path, module_name) | ||
| if weights_dir and _list_safetensors_files(weights_dir): | ||
| weights_map[module_name] = weights_dir | ||
| else: | ||
| missing.append(module_name) | ||
| return weights_map, missing | ||
|
|
||
|
|
||
| def _load_weights_into_module(module: torch.nn.Module, weights_iter) -> None: | ||
| """Load weights into a module, handling offload-managed parameters. | ||
|
|
||
| For offloaded modules, updates CPU buffers directly via | ||
| update_cpu_weights(); non-offloaded parameters use in-place copy. | ||
| """ | ||
| offload_managers: list = [] | ||
| if isinstance(module, OffloadableDiTMixin) and module.layerwise_offload_managers: | ||
| offload_managers = [m for m in module.layerwise_offload_managers if m.enabled] | ||
|
|
||
| if offload_managers: | ||
| weight_dict = dict(weights_iter) | ||
| offloaded_names: set[str] = set() | ||
| for manager in offload_managers: | ||
| offloaded_names.update(manager.update_cpu_weights(weight_dict)) | ||
| remaining = ((n, w) for n, w in weight_dict.items() if n not in offloaded_names) | ||
| load_weights_into_model(remaining, dict(module.named_parameters())) | ||
| else: | ||
| load_weights_into_model(weights_iter, dict(module.named_parameters())) | ||
|
|
||
|
|
||
| def load_weights_into_model(weights_iter, model_params: dict) -> None: | ||
| """Copy weights from weights_iter into model_params in-place.""" | ||
| for name, loaded_weight in weights_iter: | ||
| if name not in model_params: | ||
| continue | ||
| param = model_params[name] | ||
| if param.shape != loaded_weight.shape: | ||
| raise ValueError( | ||
| f"Shape mismatch for {name}: model={param.shape}, loaded={loaded_weight.shape}" | ||
| ) | ||
| if isinstance(param, DTensor): | ||
| distributed_weight = distribute_tensor( | ||
| loaded_weight.to(param.dtype), | ||
| param.device_mesh, | ||
| param.placements, | ||
| ) | ||
| param._local_tensor.copy_(distributed_weight._local_tensor) | ||
| else: | ||
| param.data.copy_(loaded_weight.to(param.dtype)) | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.