Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
c93c103
[Feature] Implement update_weights_from_disk for SGLang-D (Diffusion …
dreamyang-liu Feb 5, 2026
7703c69
[diffusion] refactor: extract WeightsUpdater for update_weights_from_…
dreamyang-liu Feb 6, 2026
886a6bc
chore: isort lint
dreamyang-liu Feb 6, 2026
3d19d61
[diffusion] offload-aware weight updates and cleanup
dreamyang-liu Feb 7, 2026
d0e2fec
[diffusion] refactor: move post-training API to dedicated package and…
dreamyang-liu Feb 8, 2026
07cc679
[diffusion] refactor: extract get_updatable_modules and harden Weight…
dreamyang-liu Feb 10, 2026
c442af2
[diffusion] address comments
dreamyang-liu Feb 11, 2026
881d8b3
adds doc string to diffusion rifit test
zhaochenyang20 Feb 11, 2026
dfc93f6
[diffusion] Add /get_weights_checksum endpoint for SHA-256 weight ver…
dreamyang-liu Feb 12, 2026
68902f3
[diffusion] Add corrupted-weight rollback test for update_weights_fro…
dreamyang-liu Feb 12, 2026
719d31f
[diffusion] Parametrize weight-update tests over FLUX and Qwen model …
dreamyang-liu Feb 13, 2026
81d585b
[diffusion] Optimize weight-update tests
dreamyang-liu Feb 13, 2026
85b646c
Merge branch 'feat/diffusion-update-weights-from-disk' of github.com:…
zhaochenyang20 Feb 14, 2026
32a743b
[TODO] model weights are updated only once
zhaochenyang20 Feb 14, 2026
c35eecd
Deduplicated tests; Should clean up
zhaochenyang20 Feb 14, 2026
41148c4
clean up codes with mixin; currently spanning 16mins; too long; shoul…
zhaochenyang20 Feb 14, 2026
14e69ec
Update docstring for GetWeightsChecksumReqInput
zhaochenyang20 Feb 14, 2026
b64b185
Refine docstring for weight checksum verification
zhaochenyang20 Feb 14, 2026
d68da5c
Simplify comments in layerwise offload method
zhaochenyang20 Feb 14, 2026
d3728cb
Improve documentation for iter_materialized_weights
zhaochenyang20 Feb 14, 2026
f831d94
fix lint
zhaochenyang20 Feb 15, 2026
bd23519
Refactor update_weights_from_disk tests
dreamyang-liu Feb 15, 2026
22b32c4
Merge branch 'main' into feat/diffusion-update-weights-from-disk
zhaochenyang20 Feb 16, 2026
a9f6808
Merge branch 'feat/diffusion-update-weights-from-disk' of github.com:…
zhaochenyang20 Feb 16, 2026
a32aed9
Merge branch 'main' into feat/diffusion-update-weights-from-disk
zhaochenyang20 Feb 16, 2026
66e3c17
Merge branch 'feat/diffusion-update-weights-from-disk' of github.com:…
zhaochenyang20 Feb 16, 2026
52ba35b
new docs string
zhaochenyang20 Feb 16, 2026
c3d478e
remove one line function
zhaochenyang20 Feb 16, 2026
cde71fe
consolidate rollback tests
zhaochenyang20 Feb 16, 2026
40bba8d
finalize the test
zhaochenyang20 Feb 16, 2026
7467682
fix CI random choice
zhaochenyang20 Feb 16, 2026
3100b2f
fix parsing issue
zhaochenyang20 Feb 16, 2026
fb87570
inline path finding
zhaochenyang20 Feb 17, 2026
fab939e
remove redundant comments
zhaochenyang20 Feb 17, 2026
9e030f0
Merge branch 'main' into feat/diffusion-update-weights-from-disk
zhaochenyang20 Feb 18, 2026
2b44290
Merge branch 'feat/diffusion-update-weights-from-disk' of github.com:…
zhaochenyang20 Feb 18, 2026
f60f638
fix isort
zhaochenyang20 Feb 18, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions python/sglang/multimodal_gen/runtime/loader/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import re
from collections import defaultdict
from collections.abc import Callable, Iterator
from pathlib import Path
from typing import Any, Dict, Type

import torch
Expand Down Expand Up @@ -148,14 +149,14 @@ def _list_safetensors_files(model_path: str) -> list[str]:
return sorted(glob.glob(os.path.join(str(model_path), "*.safetensors")))


def find_weights_dir(local_path: str, module_name: str) -> str | None:
def find_weights_dir(local_path: str, module_name: str) -> Path | None:
"""Locate the safetensors directory for module_name under local_path.

Diffusion models store weights in per-module subdirectories (e.g.
transformer/, vae/, text_encoder/).
"""
dir_path = os.path.join(local_path, module_name)
if os.path.exists(dir_path):
dir_path = Path(local_path) / module_name
if dir_path.exists():
return dir_path
return None

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

# SPDX-License-Identifier: Apache-2.0
# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/model_loader/weight_utils.py
"""Utilities for downloading, loading, and verifying model weights."""
"""Utilities for downloading, loading, initializing, and verifying model weights."""

import hashlib
import json
import os
Expand Down
151 changes: 71 additions & 80 deletions python/sglang/multimodal_gen/runtime/loader/weights_updater.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,77 @@ def get_updatable_modules(pipeline) -> dict[str, torch.nn.Module]:
return {n: m for n, m in raw.items() if isinstance(m, torch.nn.Module)}


def _get_weights_iter(weights_dir: str):
    """Yield (name, tensor) pairs from every safetensors file in weights_dir.

    Raises:
        FileNotFoundError: If weights_dir contains no *.safetensors files.
    """
    files = _list_safetensors_files(weights_dir)
    if not files:
        raise FileNotFoundError(f"No safetensors files found in {weights_dir}")
    return safetensors_weights_iterator(files)


def _validate_weight_files(
    local_model_path: str,
    modules_to_update: list[tuple[str, torch.nn.Module]],
) -> tuple[dict[str, str], list[str]]:
    """Check that every module has a weights directory with safetensors files.

    Args:
        local_model_path: Root directory of the on-disk model snapshot.
        modules_to_update: (module name, module) pairs slated for update.

    Returns:
        (weights_map, missing) where weights_map maps module name to its
        weights directory and missing lists modules without weight files.
    """
    weights_map: dict[str, str] = {}
    missing: list[str] = []
    for name, _module in modules_to_update:
        candidate = find_weights_dir(local_model_path, name)
        # `and` short-circuits so we never probe a nonexistent directory.
        if candidate and _list_safetensors_files(candidate):
            weights_map[name] = candidate
        else:
            missing.append(name)
    return weights_map, missing


def _load_weights_into_module(module: torch.nn.Module, weights_iter) -> None:
    """Load weights into a module, handling offload-managed parameters.

    For modules with active layerwise offload, offloaded tensors are written
    into CPU buffers via update_cpu_weights(); all remaining weights are
    copied in-place into the module's parameters.
    """
    active_managers: list = []
    if isinstance(module, OffloadableDiTMixin) and module.layerwise_offload_managers:
        active_managers = [
            mgr for mgr in module.layerwise_offload_managers if mgr.enabled
        ]

    params = dict(module.named_parameters())
    if not active_managers:
        load_weights_into_model(weights_iter, params)
        return

    # Offloaded path: materialize the iterator once so each manager can
    # claim the tensors it owns, then load the leftovers normally.
    weight_dict = dict(weights_iter)
    claimed: set[str] = set()
    for mgr in active_managers:
        claimed.update(mgr.update_cpu_weights(weight_dict))
    leftovers = ((name, w) for name, w in weight_dict.items() if name not in claimed)
    load_weights_into_model(leftovers, params)


def load_weights_into_model(weights_iter, model_params: dict) -> None:
"""Copy weights from weights_iter into model_params in-place."""
for name, loaded_weight in weights_iter:
if name not in model_params:
continue
param = model_params[name]
if param.shape != loaded_weight.shape:
raise ValueError(
f"Shape mismatch for {name}: model={param.shape}, loaded={loaded_weight.shape}"
)
if isinstance(param, DTensor):
distributed_weight = distribute_tensor(
loaded_weight.to(param.dtype),
param.device_mesh,
param.placements,
)
param._local_tensor.copy_(distributed_weight._local_tensor)
else:
param.data.copy_(loaded_weight.to(param.dtype))


class WeightsUpdater:
"""In-place weight updates for diffusion pipeline modules.

Expand Down Expand Up @@ -168,10 +239,6 @@ def update_weights_from_disk(
logger.info(message)
return success, message

# ------------------------------------------------------------------
# Private helpers
# ------------------------------------------------------------------

def _collect_modules(
self, target_modules: list[str] | None
) -> list[tuple[str, torch.nn.Module]]:
Expand Down Expand Up @@ -244,79 +311,3 @@ def _rollback(self, updated_modules: list[str]) -> None:
continue
weights_iter = _get_weights_iter(weights_dir)
_load_weights_into_module(module, weights_iter)
Comment on lines +276 to +293
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As I said, I think we should have other methods to _rollback.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you clarify what you mean by "other methods to _rollback"?



# ---------------------------------------------------------------------------
# Module-level utility functions
# ---------------------------------------------------------------------------


def _get_weights_iter(weights_dir: str):
    """Return a (name, tensor) iterator over safetensors in weights_dir.

    Raises:
        FileNotFoundError: If weights_dir contains no *.safetensors files.
    """
    safetensors_files = _list_safetensors_files(weights_dir)
    if not safetensors_files:
        raise FileNotFoundError(f"No safetensors files found in {weights_dir}")
    return safetensors_weights_iterator(safetensors_files)


def _validate_weight_files(
    local_model_path: str,
    modules_to_update: list[tuple[str, torch.nn.Module]],
) -> tuple[dict[str, str], list[str]]:
    """Check that every module has a weights directory with safetensors files.

    Args:
        local_model_path: Root directory of the on-disk model snapshot.
        modules_to_update: (module name, module) pairs slated for update.

    Returns:
        (weights_map, missing) where weights_map maps module name to its
        weights directory and missing lists modules without weight files.
    """
    weights_map: dict[str, str] = {}
    missing: list[str] = []
    for module_name, _ in modules_to_update:
        weights_dir = find_weights_dir(local_model_path, module_name)
        # `and` short-circuits so we never probe a nonexistent directory.
        if weights_dir and _list_safetensors_files(weights_dir):
            weights_map[module_name] = weights_dir
        else:
            missing.append(module_name)
    return weights_map, missing


def _load_weights_into_module(module: torch.nn.Module, weights_iter) -> None:
    """Load weights into a module, handling offload-managed parameters.

    For offloaded modules, updates CPU buffers directly via
    update_cpu_weights(); non-offloaded parameters use in-place copy.
    """
    offload_managers: list = []
    if isinstance(module, OffloadableDiTMixin) and module.layerwise_offload_managers:
        offload_managers = [m for m in module.layerwise_offload_managers if m.enabled]

    if offload_managers:
        # Materialize the iterator once so each manager can claim its tensors.
        weight_dict = dict(weights_iter)
        offloaded_names: set[str] = set()
        for manager in offload_managers:
            offloaded_names.update(manager.update_cpu_weights(weight_dict))
        # Whatever the managers did not claim is loaded as regular parameters.
        remaining = ((n, w) for n, w in weight_dict.items() if n not in offloaded_names)
        load_weights_into_model(remaining, dict(module.named_parameters()))
    else:
        load_weights_into_model(weights_iter, dict(module.named_parameters()))


def load_weights_into_model(weights_iter, model_params: dict) -> None:
    """Copy weights from weights_iter into model_params in-place.

    Names absent from model_params are skipped silently; a shape mismatch
    raises ValueError.
    """
    for name, loaded_weight in weights_iter:
        if name not in model_params:
            continue
        param = model_params[name]
        if param.shape != loaded_weight.shape:
            raise ValueError(
                f"Shape mismatch for {name}: model={param.shape}, loaded={loaded_weight.shape}"
            )
        if isinstance(param, DTensor):
            # Redistribute the full tensor onto the parameter's mesh, then
            # copy only this rank's local shard.
            distributed_weight = distribute_tensor(
                loaded_weight.to(param.dtype),
                param.device_mesh,
                param.placements,
            )
            param._local_tensor.copy_(distributed_weight._local_tensor)
        else:
            param.data.copy_(loaded_weight.to(param.dtype))
Loading