Merged

37 commits
c93c103
[Feature] Implement update_weights_from_disk for SGLang-D (Diffusion …
dreamyang-liu Feb 5, 2026
7703c69
[diffusion] refactor: extract WeightsUpdater for update_weights_from_…
dreamyang-liu Feb 6, 2026
886a6bc
chore: isort lint
dreamyang-liu Feb 6, 2026
3d19d61
[diffusion] offload-aware weight updates and cleanup
dreamyang-liu Feb 7, 2026
d0e2fec
[diffusion] refactor: move post-training API to dedicated package and…
dreamyang-liu Feb 8, 2026
07cc679
[diffusion] refactor: extract get_updatable_modules and harden Weight…
dreamyang-liu Feb 10, 2026
c442af2
[diffusion] address comments
dreamyang-liu Feb 11, 2026
881d8b3
adds doc string to diffusion rifit test
zhaochenyang20 Feb 11, 2026
dfc93f6
[diffusion] Add /get_weights_checksum endpoint for SHA-256 weight ver…
dreamyang-liu Feb 12, 2026
68902f3
[diffusion] Add corrupted-weight rollback test for update_weights_fro…
dreamyang-liu Feb 12, 2026
719d31f
[diffusion] Parametrize weight-update tests over FLUX and Qwen model …
dreamyang-liu Feb 13, 2026
81d585b
[diffusion] Optimize weight-update tests
dreamyang-liu Feb 13, 2026
85b646c
Merge branch 'feat/diffusion-update-weights-from-disk' of github.com:…
zhaochenyang20 Feb 14, 2026
32a743b
[TODO] model weights is updated only once
zhaochenyang20 Feb 14, 2026
c35eecd
Deduplicated tests; Should clean up
zhaochenyang20 Feb 14, 2026
41148c4
clean up codes with mixin; currently spanning 16mins; too long; shoul…
zhaochenyang20 Feb 14, 2026
14e69ec
Update docstring for GetWeightsChecksumReqInput
zhaochenyang20 Feb 14, 2026
b64b185
Refine docstring for weight checksum verification
zhaochenyang20 Feb 14, 2026
d68da5c
Simplify comments in layerwise offload method
zhaochenyang20 Feb 14, 2026
d3728cb
Improve documentation for iter_materialized_weights
zhaochenyang20 Feb 14, 2026
f831d94
fix lint
zhaochenyang20 Feb 15, 2026
bd23519
Refactor update_weights_from_disk tests
dreamyang-liu Feb 15, 2026
22b32c4
Merge branch 'main' into feat/diffusion-update-weights-from-disk
zhaochenyang20 Feb 16, 2026
a9f6808
Merge branch 'feat/diffusion-update-weights-from-disk' of github.com:…
zhaochenyang20 Feb 16, 2026
a32aed9
Merge branch 'main' into feat/diffusion-update-weights-from-disk
zhaochenyang20 Feb 16, 2026
66e3c17
Merge branch 'feat/diffusion-update-weights-from-disk' of github.com:…
zhaochenyang20 Feb 16, 2026
52ba35b
new docs string
zhaochenyang20 Feb 16, 2026
c3d478e
remove one line function
zhaochenyang20 Feb 16, 2026
cde71fe
consolidate rollback tests
zhaochenyang20 Feb 16, 2026
40bba8d
finalize the test
zhaochenyang20 Feb 16, 2026
7467682
fix CI random choice
zhaochenyang20 Feb 16, 2026
3100b2f
fix paring issue
zhaochenyang20 Feb 16, 2026
fb87570
incline path finding
zhaochenyang20 Feb 17, 2026
fab939e
remove redundant comments
zhaochenyang20 Feb 17, 2026
9e030f0
Merge branch 'main' into feat/diffusion-update-weights-from-disk
zhaochenyang20 Feb 18, 2026
2b44290
Merge branch 'feat/diffusion-update-weights-from-disk' of github.com:…
zhaochenyang20 Feb 18, 2026
f60f638
fix isort
zhaochenyang20 Feb 18, 2026
23 changes: 23 additions & 0 deletions docs/advanced_features/sglang_for_rl.md
@@ -106,6 +106,29 @@ This path trades some I/O overhead for simplicity and flexibility. It integrates

**Python Engine API:** `engine.update_weights_from_disk(model_path, load_format=None)`

**Diffusion engine (SGLang-Diffusion):** The diffusion engine exposes the same `POST /update_weights_from_disk` endpoint with the following behavior:

- **All-or-nothing with rollback:** if any module fails to load, all previously updated modules are rolled back to the original weights by reloading from the original model path. No partial updates are left behind. If rollback itself fails, the exception propagates so the caller knows the model is in an inconsistent state.
- **Offload-aware:** when layerwise offload (`--dit-layerwise-offload`) is enabled, the diffusion offload manager replaces GPU parameters with small `torch.empty((1,))` placeholders while real weights live in consolidated pinned CPU buffers. A naive `param.data.copy_()` would fail with a shape mismatch. Instead, the updater dynamically detects active offload managers and writes new weights directly into their CPU buffers, bypassing the placeholders entirely. For any layer that happens to be prefetched on GPU at update time, the live GPU tensor is also updated so the change takes effect immediately. This requires no extra GPU memory and does not disturb the offload state.
- **DTensor-aware:** parameters distributed via `torch.distributed.tensor` (tensor parallelism) are updated through `distribute_tensor` so that each shard is correctly placed on the right device mesh.
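
The offload-aware dispatch described above can be sketched in plain Python. The `OffloadManagerSketch` interface here (`owns`, `cpu_buffers`, `prefetched`) is hypothetical and not the actual SGLang API; plain lists stand in for tensors so that only the routing logic is shown:

```python
class OffloadManagerSketch:
    """Stand-in for a layerwise offload manager (illustrative only)."""

    def __init__(self):
        self.cpu_buffers = {}   # name -> consolidated pinned-CPU-buffer stand-in
        self.prefetched = {}    # name -> live "GPU" copy, if currently prefetched

    def owns(self, name):
        return name in self.cpu_buffers


def route_weight_update(name, new_weight, gpu_params, manager):
    """Write a new weight where it actually lives (CPU buffer vs. GPU param)."""
    if manager is not None and manager.owns(name):
        # Offloaded layer: the GPU param is only a (1,)-shaped placeholder,
        # so write into the consolidated CPU buffer instead.
        manager.cpu_buffers[name][:] = new_weight
        if name in manager.prefetched:
            # Layer happens to be resident on GPU right now; keep the live
            # copy in sync so the update takes effect immediately.
            manager.prefetched[name][:] = new_weight
    else:
        # Ordinary (non-offloaded) parameter: plain in-place copy.
        gpu_params[name][:] = new_weight
```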

**Request body:**

| Field | Description | Defaults | Options |
| --- | --- | --- | --- |
| `model_path` | The model path with the new weights. | Required | Type: str |
| `flush_cache` | Flush TeaCache state after update. | `True` | Type: bool |
| `target_modules` | List of module names to update (e.g. `["transformer"]`). If omitted, all `nn.Module` components are updated. | `None` | Type: list[str] |

**Response body:**

| Field | Description | Defaults | Options |
| --- | --- | --- | --- |
| `success` | Whether the update succeeded. | - | Type: bool |
| `message` | Status / error message. | - | Type: str |

> **Note:** The diffusion engine (SGLang-Diffusion) does not currently support hot refit (updating weights while inference is in progress). The diffusion scheduler processes one request at a time and completes the entire inference before handling the next request, so weight updates and inference never run concurrently.
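
A minimal client sketch for the endpoint above, using only the standard library. The base URL and model path are placeholders for your deployment; only the request shape comes from the tables above:

```python
import json
import urllib.request


def build_update_request(model_path, flush_cache=True, target_modules=None):
    """Assemble the /update_weights_from_disk request body."""
    body = {"model_path": model_path, "flush_cache": flush_cache}
    if target_modules is not None:
        body["target_modules"] = target_modules
    return body


def post_update(base_url, body):
    """POST the body and return the parsed {"success", "message"} response."""
    req = urllib.request.Request(
        f"{base_url}/update_weights_from_disk",
        data=json.dumps(body).encode(),
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(req) as resp:
        return json.load(resp)


if __name__ == "__main__":
    # Placeholder paths/URL; adjust for your server.
    body = build_update_request("/path/to/new_ckpt", target_modules=["transformer"])
    print(post_update("http://localhost:30000", body))
```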

### Update Weights from Tensor

**When to use:**
@@ -17,6 +17,7 @@
VertexGenerateReqInput,
)
from sglang.multimodal_gen.runtime.entrypoints.openai.utils import build_sampling_params
from sglang.multimodal_gen.runtime.entrypoints.post_training import weights_api
from sglang.multimodal_gen.runtime.entrypoints.utils import (
prepare_request,
save_outputs,
@@ -214,6 +215,7 @@ def create_app(server_args: ServerArgs):
app.include_router(common_api.router)
app.include_router(image_api.router)
app.include_router(video_api.router)
app.include_router(weights_api.router)

app.state.server_args = server_args
return app
@@ -0,0 +1,19 @@
"""Request/response data structures for post-training APIs."""

from dataclasses import dataclass


@dataclass
class UpdateWeightFromDiskReqInput:
"""Request to update model weights from disk for diffusion models."""

model_path: str
flush_cache: bool = True
target_modules: list[str] | None = None


@dataclass
class GetWeightsChecksumReqInput:
"""Compute SHA-256 checksum of loaded module weights for verification."""

module_names: list[str] | None = None
@@ -0,0 +1,62 @@
"""Weight update API for the diffusion engine."""

from fastapi import APIRouter, Request
from fastapi.responses import ORJSONResponse

from sglang.multimodal_gen.runtime.entrypoints.post_training.io_struct import (
GetWeightsChecksumReqInput,
UpdateWeightFromDiskReqInput,
)
from sglang.multimodal_gen.runtime.scheduler_client import async_scheduler_client

router = APIRouter()


@router.post("/update_weights_from_disk")
async def update_weights_from_disk(request: Request):
"""Update model weights from disk inplace without restarting the server."""
body = await request.json()
model_path = body.get("model_path")
if not model_path:
return ORJSONResponse(
{"success": False, "message": "model_path is required"},
status_code=400,
)

req = UpdateWeightFromDiskReqInput(
model_path=model_path,
flush_cache=body.get("flush_cache", True),
target_modules=body.get("target_modules"),
)

try:
response = await async_scheduler_client.forward(req)
except Exception as e:
return ORJSONResponse(
{"success": False, "message": str(e)},
status_code=500,
)

result = response.output
success = result.get("success", False)
message = result.get("message", "Unknown status")
return ORJSONResponse(
{"success": success, "message": message},
status_code=200 if success else 400,
)


@router.post("/get_weights_checksum")
async def get_weights_checksum(request: Request):
"""Return SHA-256 checksum of each requested module's weights."""
body = await request.json()
req = GetWeightsChecksumReqInput(
module_names=body.get("module_names"),
)

try:
response = await async_scheduler_client.forward(req)
except Exception as e:
return ORJSONResponse({"error": str(e)}, status_code=500)

return ORJSONResponse(response.output, status_code=200)
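
A hypothetical verification helper built on `/get_weights_checksum`: compare the server-reported per-module checksums against expected values computed offline from the new checkpoint. The `{module_name: hexdigest}` response shape is assumed from the handler above:

```python
def find_checksum_mismatches(expected, reported):
    """Return module names whose reported checksum differs or is missing.

    expected: {module_name: hexdigest} computed from the new checkpoint.
    reported: {module_name: hexdigest} from /get_weights_checksum.
    """
    return sorted(
        name
        for name, digest in expected.items()
        if reported.get(name) != digest
    )
```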
25 changes: 23 additions & 2 deletions python/sglang/multimodal_gen/runtime/loader/weight_utils.py
@@ -2,19 +2,20 @@

# SPDX-License-Identifier: Apache-2.0
# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/model_loader/weight_utils.py
"""Utilities for downloading and initializing model weights."""
"""Utilities for downloading, loading, initializing and verifying model weights."""

import hashlib
import json
import os
import tempfile
from collections.abc import Generator
from collections.abc import Generator, Iterable
from pathlib import Path

import filelock
import huggingface_hub.constants
import torch
from safetensors.torch import safe_open
from torch.distributed.tensor import DTensor
from tqdm.auto import tqdm

try:
@@ -336,3 +337,23 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> str | None:

# If there were no matches, return the untouched param name
return name


def compute_weights_checksum(
named_params: Iterable[tuple[str, torch.Tensor]],
) -> str:
"""Compute a SHA-256 checksum for a set of (name, tensor) pairs.

Used to verify the correctness of weight refitting. After a refit,
compare the checksum of the in-GPU model weights against the checksum
of the on-disk tensors or the tensors in the training engine.
"""
hasher = hashlib.sha256()
for name, tensor in sorted(named_params, key=lambda x: x[0]):
hasher.update(name.encode())
t = tensor.detach()
# DTensor doesn't support .numpy(); extract the local tensor.
if isinstance(t, DTensor):
t = t._local_tensor
**Review comment:** Since some of the DTensors may be sharded across devices and `_local_tensor` is only the shard on the current device, do we need an all-gather or some hash-merging logic here?

**Author reply:** Good question, let me check.

hasher.update(t.cpu().contiguous().reshape(-1).view(torch.uint8).numpy().data)
return hasher.hexdigest()
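
The hashing scheme above can be mirrored with plain bytes instead of tensors: entries are sorted by name, then each name and its raw buffer are fed to a single SHA-256 hasher. This is an illustration of the order-insensitivity property, not the library function itself:

```python
import hashlib


def checksum_named_buffers(named_buffers):
    """SHA-256 over sorted (name, bytes) pairs; insensitive to input order."""
    hasher = hashlib.sha256()
    for name, buf in sorted(named_buffers, key=lambda x: x[0]):
        hasher.update(name.encode())  # bind each buffer to its name
        hasher.update(buf)            # raw byte content
    return hasher.hexdigest()
```

Because the pairs are sorted before hashing, two engines can iterate their parameters in different orders and still produce the same digest, as long as names and contents match.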