Merged
37 commits
c93c103
[Feature] Implement update_weights_from_disk for SGLang-D (Diffusion …
dreamyang-liu Feb 5, 2026
7703c69
[diffusion] refactor: extract WeightsUpdater for update_weights_from_…
dreamyang-liu Feb 6, 2026
886a6bc
chore: isort lint
dreamyang-liu Feb 6, 2026
3d19d61
[diffusion] offload-aware weight updates and cleanup
dreamyang-liu Feb 7, 2026
d0e2fec
[diffusion] refactor: move post-training API to dedicated package and…
dreamyang-liu Feb 8, 2026
07cc679
[diffusion] refactor: extract get_updatable_modules and harden Weight…
dreamyang-liu Feb 10, 2026
c442af2
[diffusion] address comments
dreamyang-liu Feb 11, 2026
881d8b3
adds doc string to diffusion rifit test
zhaochenyang20 Feb 11, 2026
dfc93f6
[diffusion] Add /get_weights_checksum endpoint for SHA-256 weight ver…
dreamyang-liu Feb 12, 2026
68902f3
[diffusion] Add corrupted-weight rollback test for update_weights_fro…
dreamyang-liu Feb 12, 2026
719d31f
[diffusion] Parametrize weight-update tests over FLUX and Qwen model …
dreamyang-liu Feb 13, 2026
81d585b
[diffusion] Optimize weight-update tests
dreamyang-liu Feb 13, 2026
85b646c
Merge branch 'feat/diffusion-update-weights-from-disk' of github.com:…
zhaochenyang20 Feb 14, 2026
32a743b
[TODO] model weights is updated only once
zhaochenyang20 Feb 14, 2026
c35eecd
Deduplicated tests; Should clean up
zhaochenyang20 Feb 14, 2026
41148c4
clean up codes with mixin; currently spanning 16mins; too long; shoul…
zhaochenyang20 Feb 14, 2026
14e69ec
Update docstring for GetWeightsChecksumReqInput
zhaochenyang20 Feb 14, 2026
b64b185
Refine docstring for weight checksum verification
zhaochenyang20 Feb 14, 2026
d68da5c
Simplify comments in layerwise offload method
zhaochenyang20 Feb 14, 2026
d3728cb
Improve documentation for iter_materialized_weights
zhaochenyang20 Feb 14, 2026
f831d94
fix lint
zhaochenyang20 Feb 15, 2026
bd23519
Refactor update_weights_from_disk tests
dreamyang-liu Feb 15, 2026
22b32c4
Merge branch 'main' into feat/diffusion-update-weights-from-disk
zhaochenyang20 Feb 16, 2026
a9f6808
Merge branch 'feat/diffusion-update-weights-from-disk' of github.com:…
zhaochenyang20 Feb 16, 2026
a32aed9
Merge branch 'main' into feat/diffusion-update-weights-from-disk
zhaochenyang20 Feb 16, 2026
66e3c17
Merge branch 'feat/diffusion-update-weights-from-disk' of github.com:…
zhaochenyang20 Feb 16, 2026
52ba35b
new docs string
zhaochenyang20 Feb 16, 2026
c3d478e
remove one line function
zhaochenyang20 Feb 16, 2026
cde71fe
consolidate rollback tests
zhaochenyang20 Feb 16, 2026
40bba8d
finalize the test
zhaochenyang20 Feb 16, 2026
7467682
fix CI random choice
zhaochenyang20 Feb 16, 2026
3100b2f
fix paring issue
zhaochenyang20 Feb 16, 2026
fb87570
incline path finding
zhaochenyang20 Feb 17, 2026
fab939e
remove redundant comments
zhaochenyang20 Feb 17, 2026
9e030f0
Merge branch 'main' into feat/diffusion-update-weights-from-disk
zhaochenyang20 Feb 18, 2026
2b44290
Merge branch 'feat/diffusion-update-weights-from-disk' of github.com:…
zhaochenyang20 Feb 18, 2026
f60f638
fix isort
zhaochenyang20 Feb 18, 2026
new docs string
zhaochenyang20 committed Feb 16, 2026
commit 52ba35b6d99dd180c428c10608a049eeb148a4c1
@@ -8,22 +8,36 @@
Menyang Liu, https://github.com/dreamyang-liu
Chenyang Zhao, https://github.com/zhaochenyang20

We use two model pairs for testing (base model / source model pairs):
We use two model pairs for testing (base model / instruct model pairs):

- FLUX.2-klein-base-4B / FLUX.2-klein-4B
- Qwen/Qwen-Image / Qwen/Qwen-Image-2512

These model pairs share the same architecture but differ in transformer
weights (all other modules — vae, text_encoder, … — are identical).

The source model is not used directly by any test. Instead, at fixture
setup time we clone it and perturb its vae weights, producing a synthetic
perturbed checkpoint (perturbed_vae_model_dir) where both transformer AND
vae differ from the base model. This perturbed checkpoint is used in all
tests, giving us two modules with known different checksums to verify.

NOTE: Disk-vs-server checksum verification currently ONLY covers transformer.
Other modules have weight-name remapping / QKV merge mismatches to resolve first.
weights. The basic testing logic is to refit a server running the base model
with the instruct model's weights and verify that the transformer checksums
match, which simulates the real-world RL scenario. However, since the models
in each pair differ only in transformer weights, and we also want to verify
updating a specific module with the update_weights_from_disk API, we create a
perturbed instruct model that adds noise to the vae weights. The perturbed
instruct model therefore differs from the base model in both vae and
transformer weights, while the text encoder stays the same.

To strictly verify the correctness of the refit API, we compare SHA-256
checksums computed on disk with those reported by the server.

NOTE and TODO: In the specific-module refit test, we randomly select one module
from transformer and vae, refit only that module on the server, and keep the
other modules unchanged. As described above, the vae's weights are perturbed.
If vae is selected as the target module, we should ideally assert that the
refitted vae's checksum equals the checksum computed directly from the perturbed
vae weights on disk. However, because model loading involves complex weight-name
remapping and QKV merging, it is not easy to compare server and disk checksums
for vae and text encoder directly. Therefore, if the target module is vae, we
only verify that the refitted vae's checksum differs from the base model's vae
checksum.

Adding a server-disk checksum comparison for vae and text encoder to this test
would be a good issue for the community to pick up.

=============================================================================

@@ -46,49 +60,61 @@

All tests share one class-scoped server (same process, same in-memory weights).
Tests that require "base model then update" should be explicitly reset to
default_model first so behavior is order-independent and updates are real
(base→perturbed), not no-ops (perturbed→perturbed).
base model first so behavior is order-independent and updates are real
(base -> perturbed), not no-ops (perturbed -> perturbed).

• test_update_weights_from_disk_default

base -> perturbed with flush_cache=True.
Verifies after-update checksum == perturbed checkpoint disk checksum
(implicitly confirms weights changed, since fixture guarantees
base ≠ perturbed).
base model -> perturbed model with flush_cache=True.
Verifies that the after-update transformer checksum == the perturbed model's
transformer disk checksum.


• test_update_weights_specific_modules

base -> perturbed with flush_cache=False. Randomly selects one module
from _DIFFERING_MODULES (modules whose weights differ between base and
perturbed checkpoint) as target_modules, updates only that module. Verifies:
from _DIFFERING_MODULES (transformer and vae) as target_modules, updates
only that module. Verifies that:
(1) targeted module's in-memory checksum changed;
(2) non-targeted modules' in-memory checksums are unchanged.

• test_update_weights_nonexistent_model

model_path set to a non-existent path; must fail (400, success=False).

Ensure server is healthy after failed update and server's checksums
equal base model's disk checksums.
Ensure server is healthy after failed update and server's transformer
checksums equal base model's transformer disk checksum.

• test_update_weights_missing_model_path

Request body empty (no model_path); must fail (400, success=False).

Ensure server is healthy after failed update and server's checksums
equal base model's disk checksums.
Ensure server is healthy after failed update and server's transformer
checksums equal base model's transformer disk checksum.

• test_update_weights_nonexistent_module

target_modules=["nonexistent_module"]; must fail (400, success=False).

Verify server is healthy after failed update and server's checksums
equal base model's disk checksums.
equal base model's transformer disk checksum.

• test_corrupted_weights_rollback

All-or-nothing rollback: base→perturbed succeeds, then perturbed→corrupted
fails (truncated vae), server rolls back to the perturbed checkpoint.
All-or-nothing rollback: We first refit the server from base model ->
perturbed model. We then manually truncate the vae weights of the base
model to get a corrupted model, and call refit to update the server
from perturbed model -> corrupted model. Verify that:

1. The update fails due to the truncated vae and the server rolls back to
the perturbed model, i.e., server's transformer weights == perturbed model's
transformer weights != base model's transformer weights.

2. After the rollback, server's vae weights == perturbed model's vae
weights != base model's vae weights.

3. After the rollback, server's text encoder weights == base model's
text encoder weights == perturbed model's text encoder weights.

-----------------------------------------------------------------------------

@@ -103,7 +129,8 @@
• test_update_weights_with_offload_enabled

Server with --dit-layerwise-offload (base). Load perturbed checkpoint;
must succeed (200, success=True), no "Shape mismatch". Checksums match disk.
must succeed (200, success=True), no "Shape mismatch". Server's transformer
checksum matches perturbed model's transformer disk checksum.
"""

from __future__ import annotations
@@ -146,7 +173,7 @@ class _Module(StrEnum):
VAE = "vae"


# Modules whose weights differ between the base model and the synthetic
# Modules whose weights differ between the base model and the
# perturbed checkpoint
_DIFFERING_MODULES: list[str] = [_Module.TRANSFORMER, _Module.VAE]

@@ -362,13 +389,13 @@ class TestUpdateWeightsFromDisk(_UpdateWeightsApiMixin):
def diffusion_server_no_offload(self, request):
"""Start a diffusion server (no offload) for this test class.

Builds two synthetic checkpoints from the source model:
Builds two perturbed checkpoints from the source model:
- perturbed_vae_model_dir: source model with perturbed vae (both
transformer and vae differ from base).
- corrupted_vae_model_dir: base model with truncated vae — triggers
load failure for rollback testing.

Checksum cache warmup and synthetic checkpoints building run in background
Checksum cache warmup and perturbed checkpoint building run in background
threads while the server boots, so everything is ready by the time
tests start.
"""
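The specific-module check described in the module docstring (the targeted module's checksum changes, non-targeted modules' checksums do not) can be sketched in plain Python. This is only an illustration under stated assumptions: `module_checksum` and `changed_modules` are hypothetical helpers, weights are modelled as raw `bytes` rather than tensors, and the hashing scheme is not necessarily the one the `/get_weights_checksum` endpoint uses.

```python
import hashlib


def module_checksum(named_weights: dict[str, bytes]) -> str:
    """SHA-256 over parameter names and raw bytes, visited in sorted-name order.

    Hypothetical helper: the real test suite may hash tensors differently.
    """
    digest = hashlib.sha256()
    for name in sorted(named_weights):
        digest.update(name.encode("utf-8"))
        digest.update(named_weights[name])
    return digest.hexdigest()


def changed_modules(before: dict[str, str], after: dict[str, str]) -> set[str]:
    """Return the modules whose checksum differs between two snapshots."""
    return {name for name in before if before[name] != after[name]}


# Toy stand-ins for the base model's modules (assumed layout, not real weights).
base = {
    "transformer": {"blocks.0.weight": b"\x00\x01", "blocks.0.bias": b"\x02"},
    "vae": {"decoder.weight": b"\x03\x04"},
    "text_encoder": {"embed.weight": b"\x05"},
}
perturbed_vae = {"decoder.weight": b"\x03\x05"}

before = {module: module_checksum(w) for module, w in base.items()}

# Simulate refitting only the vae; every other module keeps its weights.
server = dict(base)
server["vae"] = perturbed_vae
after = {module: module_checksum(w) for module, w in server.items()}

assert changed_modules(before, after) == {"vae"}
```

Sorting by parameter name keeps the digest independent of dict insertion order, which matters when the disk and the server enumerate weights differently.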
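The all-or-nothing behaviour that test_corrupted_weights_rollback expects can be illustrated with a small snapshot-and-restore sketch. Everything here is an assumption made for illustration: `LoadError`, `validate`, `EXPECTED_SIZES`, and `update_all_or_nothing` are hypothetical names, modules are modelled as byte blobs, and this is not how the server implements rollback internally.

```python
import copy


class LoadError(Exception):
    """Raised when a module's weights cannot be loaded (hypothetical)."""


# Assumed blob sizes for the toy modules; a truncated blob fails validation.
EXPECTED_SIZES = {"transformer": 4, "vae": 4}


def validate(module: str, blob: bytes) -> None:
    expected = EXPECTED_SIZES[module]
    if len(blob) != expected:
        raise LoadError(f"{module}: expected {expected} bytes, got {len(blob)}")


def update_all_or_nothing(server: dict[str, bytes], new: dict[str, bytes]) -> bool:
    """Apply every module in `new`, or restore the prior state if any one fails."""
    snapshot = copy.deepcopy(server)
    try:
        for module, blob in new.items():
            validate(module, blob)
            server[module] = blob
    except LoadError:
        # Undo partially applied modules so the server matches its prior state.
        server.clear()
        server.update(snapshot)
        return False
    return True


# Server currently holds the perturbed model.
server = {"transformer": b"PERT", "vae": b"PVAE"}
# Corrupted model: base transformer is intact, vae is truncated.
corrupted = {"transformer": b"BASE", "vae": b"BV"}

ok = update_all_or_nothing(server, corrupted)
assert not ok
assert server == {"transformer": b"PERT", "vae": b"PVAE"}
```

In this toy run the transformer is applied before the vae fails, so the rollback has to undo an already-applied module, which is the case the first verification step in the docstring targets.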