Skip to content
Merged
Changes from 1 commit
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
c93c103
[Feature] Implement update_weights_from_disk for SGLang-D (Diffusion …
dreamyang-liu Feb 5, 2026
7703c69
[diffusion] refactor: extract WeightsUpdater for update_weights_from_…
dreamyang-liu Feb 6, 2026
886a6bc
chore: isort lint
dreamyang-liu Feb 6, 2026
3d19d61
[diffusion] offload-aware weight updates and cleanup
dreamyang-liu Feb 7, 2026
d0e2fec
[diffusion] refactor: move post-training API to dedicated package and…
dreamyang-liu Feb 8, 2026
07cc679
[diffusion] refactor: extract get_updatable_modules and harden Weight…
dreamyang-liu Feb 10, 2026
c442af2
[diffusion] address comments
dreamyang-liu Feb 11, 2026
881d8b3
adds doc string to diffusion rifit test
zhaochenyang20 Feb 11, 2026
dfc93f6
[diffusion] Add /get_weights_checksum endpoint for SHA-256 weight ver…
dreamyang-liu Feb 12, 2026
68902f3
[diffusion] Add corrupted-weight rollback test for update_weights_fro…
dreamyang-liu Feb 12, 2026
719d31f
[diffusion] Parametrize weight-update tests over FLUX and Qwen model …
dreamyang-liu Feb 13, 2026
81d585b
[diffusion] Optimize weight-update tests
dreamyang-liu Feb 13, 2026
85b646c
Merge branch 'feat/diffusion-update-weights-from-disk' of github.com:…
zhaochenyang20 Feb 14, 2026
32a743b
[TODO] model weights is updated only once
zhaochenyang20 Feb 14, 2026
c35eecd
Deduplicated tests; Should clean up
zhaochenyang20 Feb 14, 2026
41148c4
clean up codes with mixin; currently spanning 16mins; too long; shoul…
zhaochenyang20 Feb 14, 2026
14e69ec
Update docstring for GetWeightsChecksumReqInput
zhaochenyang20 Feb 14, 2026
b64b185
Refine docstring for weight checksum verification
zhaochenyang20 Feb 14, 2026
d68da5c
Simplify comments in layerwise offload method
zhaochenyang20 Feb 14, 2026
d3728cb
Improve documentation for iter_materialized_weights
zhaochenyang20 Feb 14, 2026
f831d94
fix lint
zhaochenyang20 Feb 15, 2026
bd23519
Refactor update_weights_from_disk tests
dreamyang-liu Feb 15, 2026
22b32c4
Merge branch 'main' into feat/diffusion-update-weights-from-disk
zhaochenyang20 Feb 16, 2026
a9f6808
Merge branch 'feat/diffusion-update-weights-from-disk' of github.com:…
zhaochenyang20 Feb 16, 2026
a32aed9
Merge branch 'main' into feat/diffusion-update-weights-from-disk
zhaochenyang20 Feb 16, 2026
66e3c17
Merge branch 'feat/diffusion-update-weights-from-disk' of github.com:…
zhaochenyang20 Feb 16, 2026
52ba35b
new docs string
zhaochenyang20 Feb 16, 2026
c3d478e
remove one line function
zhaochenyang20 Feb 16, 2026
cde71fe
consolidate rollback tests
zhaochenyang20 Feb 16, 2026
40bba8d
finalize the test
zhaochenyang20 Feb 16, 2026
7467682
fix CI random choice
zhaochenyang20 Feb 16, 2026
3100b2f
fix paring issue
zhaochenyang20 Feb 16, 2026
fb87570
incline path finding
zhaochenyang20 Feb 17, 2026
fab939e
remove redundant comments
zhaochenyang20 Feb 17, 2026
9e030f0
Merge branch 'main' into feat/diffusion-update-weights-from-disk
zhaochenyang20 Feb 18, 2026
2b44290
Merge branch 'feat/diffusion-update-weights-from-disk' of github.com:…
zhaochenyang20 Feb 18, 2026
f60f638
fix isort
zhaochenyang20 Feb 18, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
finalize the test
  • Loading branch information
zhaochenyang20 committed Feb 16, 2026
commit 40bba8d5af4dc0efa7ebbf912c34d77418c611ac
Original file line number Diff line number Diff line change
Expand Up @@ -218,12 +218,6 @@ def _clone_model_with_modified_module(
target_module: str,
transform_safetensor: Callable[[str, str], None],
) -> None:
"""Clone a model directory via symlinks, applying transform to one module.

Everything is symlinked except the target module's first .safetensors
file, which is transformed (causing a checksum difference or corruption);
remaining files are symlinked for speed.
"""
# Symlink root-level files (model_index.json, etc.).
for fname in os.listdir(src_model):
src_path = os.path.join(src_model, fname)
Expand Down Expand Up @@ -260,7 +254,6 @@ def _clone_model_with_modified_module(


def _truncate_safetensor(src_file: str, dst_file: str) -> None:
"""Copy then truncate — produces an invalid safetensors that triggers rollback."""
shutil.copy2(src_file, dst_file)
size = os.path.getsize(dst_file)
with open(dst_file, "r+b") as f:
Expand All @@ -274,7 +267,6 @@ def _truncate_safetensor(src_file: str, dst_file: str) -> None:


def _perturb_safetensor(src_file: str, dst_file: str) -> None:
"""Load, add small perturbation to floating-point tensors, and save."""

tensors = load_file(src_file)
perturbed = {
Expand Down Expand Up @@ -327,16 +319,6 @@ def _assert_server_matches_model(
base_url: str,
expected_model: str,
) -> None:
"""Assert the server's transformer checksum matches expected_model on disk.

Only the transformer is verified because weight-name remapping and
QKV merge during model loading cause in-memory parameter names/shapes
to diverge from on-disk safetensors for other modules (e.g. vae),
making their checksums incomparable.

TODO: Extend to verify all modules once these
discrepancies are resolved.
"""
server_checksums = self._get_weights_checksum(
base_url, module_names=[_TRANSFORMER_MODULE]
)
Expand All @@ -350,30 +332,13 @@ def _assert_server_matches_model(


class TestUpdateWeightsFromDisk(_UpdateWeightsApiMixin):
"""Test suite for update_weights_from_disk API and corrupted-weight rollback.

Uses a class-scoped server fixture so the server is torn down at class end,
freeing the port and GPU memory before the offload class starts.
"""

@pytest.fixture(
scope="class",
params=_ACTIVE_MODEL_PAIRS,
ids=_PAIR_IDS,
)
def diffusion_server_no_offload(self, request):
"""Start a diffusion server (no offload) for this test class.

Builds two perturbed checkpoints from the source model:
- perturbed_vae_model_dir: source model with perturbed vae (both
transformer and vae differ from base).
- corrupted_vae_model_dir: base model with truncated vae — triggers
load failure for rollback testing.

Checksum cache warmup and perturbed checkpoints building run in background
threads while the server boots, so everything is ready by the time
tests start.
"""
default_model, source_model = request.param
port = get_dynamic_server_port()
wait_deadline = float(os.environ.get("SGLANG_TEST_WAIT_SECS", "600"))
Expand Down Expand Up @@ -457,15 +422,6 @@ def test_update_weights_from_disk_default(self, diffusion_server_no_offload):
self._assert_server_matches_model(base_url, perturbed_model_dir)

def test_update_weights_specific_modules(self, diffusion_server_no_offload):
"""Verify target_modules filtering: only the specified module is updated.

The perturbed checkpoint has different weights for both transformer and
vae. This test randomly picks ONE of them as target_modules and loads
from the perturbed checkpoint. Assertions:
(1) the targeted module's in-memory checksum changed (before != after);
(2) every non-targeted module's in-memory checksum is unchanged,
proving the server only touched what was requested.
"""
ctx, default_model, perturbed_model_dir, _ = diffusion_server_no_offload
base_url = f"http://localhost:{ctx.port}"

Expand Down Expand Up @@ -564,18 +520,6 @@ def test_update_weights_nonexistent_module(self, diffusion_server_no_offload):
self._assert_server_matches_model(base_url, default_model)

def test_corrupted_weights_rollback(self, diffusion_server_no_offload):
"""Verify all-or-nothing rollback on corrupted weights.

Steps:
1. base → perturbed (succeeds, server now on perturbed checkpoint).
2. perturbed → corrupted with target_modules=[transformer, vae].
The corrupted checkpoint has a truncated vae safetensors file.
We explicitly assert the first failed module is vae from the API
error message (which reports the failing module name), proving
transformer was attempted before the vae parse failure and that
rollback then covered both modules.
3. Assert the server rolled back to the perturbed checkpoint, not base.
"""
ctx, default_model, perturbed_model_dir, corrupted_vae_model_dir = (
diffusion_server_no_offload
)
Expand Down Expand Up @@ -650,11 +594,6 @@ class TestUpdateWeightsFromDiskWithOffload(_UpdateWeightsApiMixin):

@pytest.fixture(scope="class", params=_ACTIVE_MODEL_PAIRS, ids=_PAIR_IDS)
def diffusion_server_with_offload(self, request):
"""Start a diffusion server with layerwise offload enabled.

Also builds perturbed_vae_model_dir in a background thread
while the server boots.
"""
default_model, source_model = request.param
port = get_dynamic_server_port()
wait_deadline = float(os.environ.get("SGLANG_TEST_WAIT_SECS", "600"))
Expand Down Expand Up @@ -690,7 +629,6 @@ def diffusion_server_with_offload(self, request):
shutil.rmtree(perturbed_vae_model_dir, ignore_errors=True)

def test_update_weights_with_offload_enabled(self, diffusion_server_with_offload):
"""Offload: base→perturbed; no Shape mismatch; checksums == perturbed disk."""
ctx, _, perturbed_model_dir = diffusion_server_with_offload
base_url = f"http://localhost:{ctx.port}"

Expand Down
Loading