
[Feature] Implement update_weights_from_disk for SGLang-D (Diffusion …#18306

Merged
zhaochenyang20 merged 37 commits into sgl-project:main from dreamyang-liu:feat/diffusion-update-weights-from-disk
Feb 18, 2026

Conversation


@dreamyang-liu dreamyang-liu commented Feb 5, 2026

Motivation

Implement the update_weights_from_disk interface for SGLang-D (diffusion engine) to enable dynamic weight updates for RL workflows and iterative model fine-tuning without restarting the server.

This feature is essential for:

  • RL training loops: Update policy model weights after each training iteration
  • Iterative fine-tuning: Hot-swap checkpoints during experimentation
  • A/B testing: Switch between model versions without downtime

Mirrors the existing LLM engine's update_weights_from_disk functionality (see sglang/srt/managers/scheduler_update_weights_mixin.py).

Closes #18078

Modifications

Core Implementation (gpu_worker.py)

  • Add update_weights_from_disk() method with:
    • Support for all nn.Module components by default (transformer, vae, text_encoder, etc.)
    • Layerwise offload handling: Properly disable offload before update (load weights from CPU to GPU), then re-enable after (sync new weights back to CPU)
    • DTensor support: Handle distributed tensor parameters for tensor parallelism using distribute_tensor()
    • Atomic updates with rollback: If any module fails to update, rollback all previously updated modules
    • TeaCache state reset: Clear cached states after weight updates
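
The atomic-update-with-rollback behavior can be sketched roughly as follows. This is a minimal illustration of the idea, not the actual gpu_worker.py implementation; `update_modules_atomically` and the plain state-dict loading are assumptions for the sake of the example:

```python
import torch
import torch.nn as nn

def update_modules_atomically(modules, new_state_dicts):
    """Load new weights into each module; if any load fails, roll back all."""
    backups = {}
    try:
        for name, module in modules:
            # Snapshot current weights so we can restore them on failure.
            backups[name] = {
                k: v.detach().clone() for k, v in module.state_dict().items()
            }
            module.load_state_dict(new_state_dicts[name])  # may raise
    except Exception:
        # Roll back every module snapshotted so far, then re-raise.
        for name, module in modules:
            if name in backups:
                module.load_state_dict(backups[name])
        raise
```

If loading the second module raises, the first module's original weights are restored before the exception propagates, so the pipeline never ends up in a half-updated state.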

Request Handling

  • Add UpdateWeightsFromDiskReq dataclass in new io_struct.py
  • Add scheduler handler for update_weights requests via ZMQ
  • Add /update_weights_from_disk HTTP endpoint in http_server.py
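
As a rough sketch, the request dataclass might look like the following. Field names beyond model_path are assumptions inferred from the options exercised in the tests, not the actual io_struct.py definition:

```python
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class UpdateWeightsFromDiskReq:
    # Path to the checkpoint directory on disk (the only required field).
    model_path: str
    # Optional subset of modules to update; None means all default modules.
    target_modules: Optional[List[str]] = None
    # Whether to clear cached states (e.g. TeaCache) after the update.
    flush_cache: bool = True
```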

Tests (test_update_weights_from_disk.py)

  • Basic API tests: same model reload, flush_cache options, specific modules
  • Layerwise offload integration tests
  • End-to-end tests: verify generation works after weight update

Accuracy Tests

N/A - This feature only reloads weights from disk without modifying model forward logic. The same weights produce identical outputs.

Verified that the images generated before and after the weight update are identical.

Tested with

  • Qwen/Qwen-Image
  • Wan-AI/Wan2.1-T2V-1.3B-Diffusers
  • black-forest-labs/FLUX.2-klein-4B
  • zai-org/GLM-Image (weight update succeeded; image generation has an issue on my machine that is unrelated to this change. All other tested models can start inference after the weight update and produce the same results)

on an 8 * A100 80G machine.

Logging output example:

Updating 4 modules: text_encoder <- /sgl-workspace/sglang/glm-image/text_encoder, vae <- /sgl-workspace/sglang/glm-image/vae, vision_language_encoder <- /sgl-workspace/sglang/glm-image/vision_language_encoder, transformer <- /sgl-workspace/sglang/glm-image/transformer

This log shows which modules are being updated and their corresponding weight directories, making it easy to verify the weight loading source for each module.

Example request

curl -X POST http://localhost:30010/update_weights_from_disk \
  -H "Content-Type: application/json" \
  -d '{"model_path": "/sgl-workspace/sglang/glm-image"}'
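
For reference, an equivalent request using only Python's standard library; the endpoint and payload mirror the curl example above:

```python
import json
import urllib.request

def build_update_request(base_url: str, model_path: str) -> urllib.request.Request:
    """Build the POST request for the /update_weights_from_disk endpoint."""
    payload = json.dumps({"model_path": model_path}).encode()
    return urllib.request.Request(
        f"{base_url}/update_weights_from_disk",
        data=payload,
        headers={"Content-Type": "application/json"},
        method="POST",
    )

def update_weights(base_url: str, model_path: str) -> dict:
    """Send the request and return the parsed JSON response."""
    with urllib.request.urlopen(build_update_request(base_url, model_path)) as resp:
        return json.loads(resp.read())
```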

Benchmarking and Profiling

N/A - This is an operational API for weight management, not a performance-critical path. Weight loading time depends on model size and storage I/O.


@github-actions github-actions bot added the diffusion SGLang Diffusion label Feb 5, 2026
@gemini-code-assist
Contributor

Summary of Changes

Hello @dreamyang-liu, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a significant enhancement to the SGLang-D diffusion engine by allowing dynamic updates of model weights from disk without requiring a server restart. This functionality is designed to streamline advanced machine learning workflows such as Reinforcement Learning and iterative fine-tuning, where frequent model adjustments are necessary. The changes ensure that models can be updated efficiently and reliably, maintaining operational continuity and providing a robust mechanism for managing model evolution.

Highlights

  • Dynamic Weight Updates: Implemented update_weights_from_disk for SGLang-D (diffusion engine), enabling in-place model weight updates without server restarts. This is crucial for Reinforcement Learning (RL) workflows and iterative model fine-tuning.
  • Robust Implementation: The core update_weights_from_disk method in gpu_worker.py supports all nn.Module components, handles layerwise offloading by temporarily disabling it, includes DTensor support for tensor parallelism, and features an atomic update mechanism with rollback capabilities for stability.
  • API and Request Handling: Introduced a new UpdateWeightsFromDiskReq dataclass, integrated a scheduler handler for these requests, and exposed a new /update_weights_from_disk HTTP endpoint in http_server.py for external access.
  • Comprehensive Testing: Added extensive tests in test_update_weights_from_disk.py covering basic API functionality, flush_cache options, specific module updates, layerwise offload integration, and end-to-end generation verification after weight updates.


Changelog
  • python/sglang/multimodal_gen/runtime/entrypoints/http_server.py
    • Added a new GET endpoint /get_model_info to retrieve the current model path.
    • Introduced a new POST endpoint /update_weights_from_disk to trigger dynamic model weight updates.
    • Integrated the weight_update_router into the FastAPI application.
  • python/sglang/multimodal_gen/runtime/managers/gpu_worker.py
    • Implemented the update_weights_from_disk method to manage the loading and application of new model weights.
    • Added logic for handling layerwise offload, ensuring weights are correctly loaded to GPU before update and synced back after.
    • Included support for DTensor to manage distributed tensor parameters.
    • Incorporated a rollback mechanism to restore original weights if an update fails.
    • Defined _MODULE_WEIGHT_DIR_MAPPING and _DEFAULT_TARGET_MODULES for flexible module identification.
    • Added _reset_cache_state_after_weight_update to clear TeaCache state post-update.
  • python/sglang/multimodal_gen/runtime/managers/io_struct.py
    • Created a new file to define the UpdateWeightsFromDiskReq dataclass, specifying the structure for weight update requests.
  • python/sglang/multimodal_gen/runtime/managers/scheduler.py
    • Imported the new UpdateWeightsFromDiskReq dataclass.
    • Registered a new handler _handle_update_weights_from_disk to process weight update requests.
    • Implemented the _handle_update_weights_from_disk method to orchestrate the weight update process via the GPU worker.
  • python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py
    • Added a new test file dedicated to validating the update_weights_from_disk API.
    • Included tests for basic API functionality, flush_cache options, and updating specific modules.
    • Provided tests for graceful failure handling with non-existent model paths and missing request parameters.
    • Added specific tests to ensure correct behavior when layerwise offload is enabled.
    • Implemented end-to-end tests to confirm that image generation functions correctly after a weight update.
Activity
  • The author dreamyang-liu initiated this pull request to introduce a new feature for dynamic model weight updates in the SGLang-D diffusion engine.
  • The PR includes a detailed motivation for the feature, highlighting its importance for RL training loops, iterative fine-tuning, and A/B testing.
  • The author has implemented the core logic, API endpoints, and scheduler integration, mirroring existing LLM engine functionality.
  • Comprehensive unit and end-to-end tests have been added to validate the new feature's functionality and robustness, including handling of layerwise offload.
  • The PR description indicates that code formatting, unit tests, and code style guidance have been addressed, with documentation updates pending.


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a new feature to dynamically update model weights from disk for the diffusion engine, which is very useful for RL workflows and iterative fine-tuning. The implementation is comprehensive, covering an HTTP endpoint, scheduler handling, and the core logic in the GPU worker. It includes important features like atomic updates with rollback, handling of layerwise offload, and support for distributed tensors. The addition of a full suite of tests, including unit, integration, and end-to-end tests, is commendable and ensures the feature is robust.

I have found one critical issue in the logic for collecting modules to be updated, which could lead to partial updates. My review includes a suggestion to fix this. Overall, this is a great contribution.


@zhaochenyang20 zhaochenyang20 left a comment


One critical thing I want to mention: do not copy yourself. I think a lot of the code for update_weights_from_disk in the LLM engine can be reused in Diffusion. Share these functions and do not copy these lines, which introduces quite a maintenance burden.

If we cannot reuse it, please explain why. Also, provide the results of your test runs, including speed.


dreamyang-liu commented Feb 5, 2026

One critical thing I want to mention: do not copy yourself. I think a lot of the code for update_weights_from_disk in the LLM engine can be reused in Diffusion. Share these functions and do not copy these lines, which introduces quite a maintenance burden.

If we cannot reuse it, please explain why. Also, provide the results of your test runs, including speed.

Thanks for the comments and suggestions @zhaochenyang20! I agree we don't want to reinvent the wheel, but the following reasons prevent us from reusing more of the LLM engine's implementation.

1. Missing load_weights Interface in DiT Models

  • The LLM engine's load_weights_and_postprocess requires models to implement a load_weights() method
  • DiT models (Flux, HunyuanVideo, etc.) don't have this interface; they use param_names_mapping instead
# LLM: sglang/srt/model_loader/loader.py:683-684
def load_weights_and_postprocess(model, weights, target_device):
    model.load_weights(weights)  # Requires load_weights() method

# Diffusion DiT models have no load_weights() - check:
# sglang/multimodal_gen/runtime/models/dits/flux.py
# sglang/multimodal_gen/runtime/models/dits/hunyuanvideo.py

Unless we implement this interface in all diffusion models, it is hard to reuse the DefaultModelLoader.

2. Different Distributed Mechanisms

  • LLM: Manual sharding via narrow() in weight_loader, no DTensor
  • Diffusion: FSDP2/DTensor with automatic distribute_tensor()
# LLM: sglang/srt/layers/linear.py:370-417
def weight_loader(self, param, loaded_weight):
    param_data = param.data
    shard_size = param_data.shape[output_dim]
    start_idx = self.tp_rank * shard_size
    loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
    param_data.copy_(loaded_weight)

# Diffusion: sglang/multimodal_gen/runtime/loader/fsdp_load.py:264-270
if hasattr(meta_sharded_param, "device_mesh"):
    sharded_tensor = distribute_tensor(
        full_tensor, meta_sharded_param.device_mesh, meta_sharded_param.placements
    )

3. LLM Weight Loader Does Not Handle DTensor

  • LLM's weight_loader assumes regular torch.Tensor, no DTensor checks
  • Diffusion explicitly handles DTensor in weight loading
# Diffusion: sglang/multimodal_gen/runtime/managers/gpu_worker.py:408-415
if DTensor is not None and isinstance(param, DTensor):
    distributed_weight = distribute_tensor(
        loaded_weight.to(param.device, param.dtype),
        param.device_mesh, param.placements,
    )
    param._local_tensor.copy_(distributed_weight._local_tensor)

4. Different Weight Mapping Strategies

  • LLM: Complex stacked_params_mapping for qkv merging, quantization
  • Diffusion: Simple param_names_mapping for name translation
# LLM: sglang/srt/models/qwen2.py:556-561
stacked_params_mapping = [
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
]

# Diffusion: uses model.param_names_mapping attribute
# sglang/multimodal_gen/runtime/loader/fsdp_load.py:118
param_names_mapping_fn = get_param_names_mapping(model.param_names_mapping)
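
For illustration, a simple name-translation mapping of this kind could be implemented as follows. The regex-based form and the example mapping contents are assumptions, not the actual fsdp_load.py code:

```python
import re

def get_param_names_mapping(mapping):
    """Build a function that rewrites checkpoint parameter names.

    `mapping` maps regex patterns to replacement strings; the first pattern
    that matches wins, and unmatched names pass through unchanged.
    """
    def rename(name: str) -> str:
        for pattern, repl in mapping.items():
            new_name, n = re.subn(pattern, repl, name)
            if n:
                return new_name
        return name
    return rename
```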

5. Different Model Structures

  • LLM: Single self.model
  • Diffusion: Multi-module pipeline (transformer, vae, text_encoder, etc.)
# LLM: sglang/srt/model_executor/model_runner.py:1085
model = model_load_weights(self.model, iter)

# Diffusion: sglang/multimodal_gen/runtime/managers/gpu_worker.py:379-385
for name in module_names:
    module = self.pipeline.get_module(name)  # Multiple modules
    if module is not None and isinstance(module, torch.nn.Module):
        modules_to_update.append((name, module))

I do agree this function is long; I can try to split it up if you think that would make it easier to maintain.

Please let me know your thoughts on above.


dreamyang-liu commented Feb 5, 2026

Test

Unit Test

Tested on a single A100 80G with model black-forest-labs/FLUX.2-klein-4B

============================= test session starts ==============================
platform linux -- Python 3.12.3, pytest-9.0.2, pluggy-1.6.0 -- /usr/bin/python
cachedir: .pytest_cache
rootdir: /sgl-workspace/sglang/python
configfile: pyproject.toml
plugins: anyio-4.12.1, typeguard-4.4.4
collecting ... collected 9 items

python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py::TestUpdateWeightsFromDisk::test_get_model_info PASSED [ 11%]
python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py::TestUpdateWeightsFromDisk::test_update_weights_same_model PASSED [ 22%]
python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py::TestUpdateWeightsFromDisk::test_update_weights_with_flush_cache PASSED [ 33%]
python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py::TestUpdateWeightsFromDisk::test_update_weights_without_flush_cache PASSED [ 44%]
python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py::TestUpdateWeightsFromDisk::test_update_weights_nonexistent_model PASSED [ 55%]
python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py::TestUpdateWeightsFromDisk::test_update_weights_missing_model_path PASSED [ 66%]
python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py::TestUpdateWeightsFromDisk::test_update_weights_specific_modules PASSED [ 77%]
python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py::TestUpdateWeightsFromDiskWithOffload::test_update_weights_with_offload_enabled PASSED [ 88%]
python/sglang/multimodal_gen/test/server/test_update_weights_from_disk.py::TestUpdateWeightsEndToEnd::test_generation_after_weight_update PASSED [100%]

=============================== warnings summary ===============================
<frozen importlib._bootstrap>:488
  <frozen importlib._bootstrap>:488: DeprecationWarning: builtin type SwigPyPacked has no __module__ attribute

<frozen importlib._bootstrap>:488
  <frozen importlib._bootstrap>:488: DeprecationWarning: builtin type SwigPyObject has no __module__ attribute

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
================== 9 passed, 2 warnings in 187.21s (0:03:07) ===================

Speed Test (8 * A100):

============================================================
>>> update_weights_from_disk took 22.381 seconds <<< (black-forest-labs/FLUX.2-klein-4B)
============================================================

============================================================
>>> update_weights_from_disk took 56.255 seconds <<< (Qwen/Qwen-Image)
============================================================

@zhaochenyang20
Collaborator

@dreamyang-liu, thanks for your detailed PR. Love to see your explanation. Could you add a note that we also support updating diffusion model weights in these docs:

https://github.com/sgl-project/sglang/blob/main/docs/advanced_features/sglang_for_rl.md


@zhaochenyang20 zhaochenyang20 left a comment


Try to clean up the code and make it compact. Thanks!

@github-actions github-actions bot added the documentation Improvements or additions to documentation label Feb 6, 2026

dreamyang-liu commented Feb 6, 2026

Try to clean up the code and make it compact. Thanks!

@zhaochenyang20 Thanks for the suggestions! I just did some refactoring and cleanup; please check the revision when you get time. Thanks!

@dreamyang-liu
Contributor Author

@zhaochenyang20 @mickqian Thanks for the comments. I've made one more revision; please take a look when you get time and let me know if there are any other concerns.

@dreamyang-liu dreamyang-liu force-pushed the feat/diffusion-update-weights-from-disk branch 2 times, most recently from 7e4ff9e to 5759564 on February 10, 2026 at 03:51

@torch.compiler.disable
def update_cpu_weights(self, weight_dict: Dict[str, torch.Tensor]) -> Set[str]:
    """Update consolidated CPU buffers with new weights.


Adds this here:

when layerwise offload (--dit-layerwise-offload) is enabled, the diffusion offload manager replaces GPU parameters with small torch.empty((1,)) placeholders while real weights live in consolidated pinned CPU buffers. A naive param.data.copy_() would fail with a shape mismatch. Instead, the updater dynamically detects active offload managers and writes new weights directly into their CPU buffers, bypassing the placeholders entirely. For any layer that happens to be prefetched on GPU at update time, the live GPU tensor is also updated so the change takes effect immediately. This requires no extra GPU memory and does not disturb the offload state.
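
The idea can be sketched like this. This is a toy illustration of the described behavior; the buffer layout and function names are assumptions, not the offload manager's real API:

```python
import torch

def update_offloaded_param(cpu_buffer, gpu_param, new_weight, is_prefetched):
    """Write new weights into the consolidated CPU buffer, not the placeholder.

    `gpu_param` may be a torch.empty((1,)) placeholder while layerwise offload
    is active, so copying into it blindly would hit a shape mismatch.
    """
    cpu_buffer.copy_(new_weight)  # real weights live here
    if is_prefetched and gpu_param.shape == new_weight.shape:
        # Layer currently resident on GPU: also update the live tensor so the
        # change takes effect immediately.
        gpu_param.data.copy_(new_weight.to(gpu_param.device))
```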

@zhaochenyang20
Collaborator

After this PR is merged, I think there are two unit tests we can add in follow-up PRs:

#18306 (comment)

#18306 (comment)

@zhaochenyang20
Collaborator

A further question here:

#18572

@dreamyang-liu dreamyang-liu force-pushed the feat/diffusion-update-weights-from-disk branch 2 times, most recently from 51c83bd to 666fb8b on February 11, 2026 at 08:45
@zhaochenyang20
Collaborator

We use two model pairs for testing (base model / instruct model pairs):

  • FLUX.2-klein-base-4B / FLUX.2-klein-4B
  • Qwen/Qwen-Image / Qwen/Qwen-Image-2512

These model pairs share the same architecture but differ in transformer weights. The basic testing logic is to refit the instruct model into the base model and verify that the checksums of the transformer weights match, which simulates the real-world RL scenario. However, since these model pairs differ only in transformer weights, and we also want to verify updating a specific module via the update_weights_from_disk API, we create a perturbed instruct model by adding noise to the vae weights. The instruct model then differs from the base model in both the vae and transformer weights, while the text encoder stays the same.

To strictly verify the correctness of the refit API, we compare SHA-256 checksums computed on disk and on the server.

NOTE and TODO: In the refit-a-specific-module test, we randomly select one module from the transformer and vae to refit on the server and keep the other modules the same. As described above, the vae's weights are perturbed. If we select the vae as the target module, we should ideally assert that the refitted vae's checksum matches the one computed directly from the perturbed vae weights on disk. However, because of the complex weight-name remapping and QKV merging during model loading, it is not easy to compare the server and disk checksums for the vae and text encoder directly. Therefore, if the target module is the vae, we only verify that the refitted vae's checksum differs from the base model's vae checksum.

Adding a server-disk checksum comparison for the vae and text encoder in this test would be a good issue for the community to pick up.
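
The checksum comparison could look roughly like this. This is a sketch only; hashing raw tensor bytes in sorted parameter-name order is an assumption about how the comparison is normalized, not the actual test code:

```python
import hashlib
import torch

def state_dict_checksum(state_dict) -> str:
    """SHA-256 over all parameters, hashed in sorted-name order."""
    h = hashlib.sha256()
    for name in sorted(state_dict):
        t = state_dict[name].detach().cpu().contiguous()
        h.update(name.encode())
        h.update(t.numpy().tobytes())
    return h.hexdigest()
```

Identical weights produce identical digests regardless of device placement, while any perturbation changes the digest.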

@zhaochenyang20
Collaborator

Further profiling:

#18979

t = tensor.detach()
# DTensor doesn't support .numpy(); extract the local tensor.
if isinstance(t, DTensor):
    t = t._local_tensor


Since some of the DTensors may be sharded across devices and _local_tensor is only the shard on the current device, do we need an all-gather or some hash-merging logic here?

Contributor Author


Good question, let me check.


Labels

diffusion SGLang Diffusion documentation Improvements or additions to documentation run-ci
