-
Notifications
You must be signed in to change notification settings - Fork 4.7k
[Feature] Implement update_weights_from_disk for SGLang-D (Diffusion … #18306
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 32 commits
c93c103
7703c69
886a6bc
3d19d61
d0e2fec
07cc679
c442af2
881d8b3
dfc93f6
68902f3
719d31f
81d585b
85b646c
32a743b
c35eecd
41148c4
14e69ec
b64b185
d68da5c
d3728cb
f831d94
bd23519
22b32c4
a9f6808
a32aed9
66e3c17
52ba35b
c3d478e
cde71fe
40bba8d
7467682
3100b2f
fb87570
fab939e
9e030f0
2b44290
f60f638
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
dreamyang-liu marked this conversation as resolved.
Show resolved
Hide resolved
|
| Original file line number | Diff line number | Diff line change |
|---|---|---|
"""Request/response data structures for post-training APIs."""

from dataclasses import dataclass


@dataclass
class UpdateWeightFromDiskReqInput:
    """Request to update model weights from disk for diffusion models."""

    # Checkpoint location to load the new weights from (required).
    model_path: str
    # Whether to flush cached state after the refit; presumably the
    # scheduler-side cache — TODO confirm against the scheduler handler.
    flush_cache: bool = True
    # Optional subset of module names to refit; None means update all modules.
    target_modules: list[str] | None = None


@dataclass
class GetWeightsChecksumReqInput:
    """Compute SHA-256 checksum of loaded module weights for verification."""

    # Modules to checksum; None presumably means every loaded module —
    # NOTE(review): confirm against the scheduler-side handler.
    module_names: list[str] | None = None
zhaochenyang20 marked this conversation as resolved.
Show resolved
Hide resolved
|
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,62 @@ | ||
| """Weight update API for the diffusion engine.""" | ||
|
|
||
| from fastapi import APIRouter, Request | ||
| from fastapi.responses import ORJSONResponse | ||
|
|
||
| from sglang.multimodal_gen.runtime.entrypoints.post_training.io_struct import ( | ||
| GetWeightsChecksumReqInput, | ||
| UpdateWeightFromDiskReqInput, | ||
| ) | ||
| from sglang.multimodal_gen.runtime.scheduler_client import async_scheduler_client | ||
|
|
||
| router = APIRouter() | ||
|
|
||
|
|
||
@router.post("/update_weights_from_disk")
async def update_weights_from_disk(request: Request):
    """Reload model weights from a checkpoint on disk without a server restart.

    Expected JSON body:
        model_path (str, required): checkpoint location to load from.
        flush_cache (bool, default True): forwarded to the scheduler request.
        target_modules (list[str] | None): restrict the refit to these modules.

    Responds with ``{"success": bool, "message": str}``: 200 when the
    scheduler reports success, 400 for a missing path or a rejected update,
    500 when the scheduler round-trip itself raises.
    """
    payload = await request.json()
    checkpoint_path = payload.get("model_path")
    if not checkpoint_path:
        # Fail fast: there is nothing to load without a path.
        return ORJSONResponse(
            {"success": False, "message": "model_path is required"},
            status_code=400,
        )

    update_req = UpdateWeightFromDiskReqInput(
        model_path=checkpoint_path,
        flush_cache=payload.get("flush_cache", True),
        target_modules=payload.get("target_modules"),
    )

    try:
        scheduler_reply = await async_scheduler_client.forward(update_req)
    except Exception as exc:  # top-level API boundary: report, don't crash
        return ORJSONResponse(
            {"success": False, "message": str(exc)},
            status_code=500,
        )

    outcome = scheduler_reply.output
    ok = outcome.get("success", False)
    return ORJSONResponse(
        {"success": ok, "message": outcome.get("message", "Unknown status")},
        status_code=200 if ok else 400,
    )
|
|
||
|
|
||
@router.post("/get_weights_checksum")
async def get_weights_checksum(request: Request):
    """Return the SHA-256 checksum of each requested module's weights.

    JSON body: ``module_names`` (list[str] | None); None is forwarded as-is
    to the scheduler. On a scheduler failure responds 500 with
    ``{"error": ...}``; otherwise relays the scheduler output with status 200.
    """
    payload = await request.json()
    checksum_req = GetWeightsChecksumReqInput(
        module_names=payload.get("module_names"),
    )

    try:
        scheduler_reply = await async_scheduler_client.forward(checksum_req)
    except Exception as exc:  # surface scheduler errors to the caller
        return ORJSONResponse({"error": str(exc)}, status_code=500)

    return ORJSONResponse(scheduler_reply.output, status_code=200)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,19 +2,20 @@ | |
|
|
||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/model_loader/weight_utils.py | ||
| """Utilities for downloading and initializing model weights.""" | ||
| """Utilities for downloading, loading, initializing and verifying model weights.""" | ||
|
|
||
| import hashlib | ||
| import json | ||
| import os | ||
| import tempfile | ||
| from collections.abc import Generator | ||
| from collections.abc import Generator, Iterable | ||
| from pathlib import Path | ||
|
|
||
| import filelock | ||
| import huggingface_hub.constants | ||
| import torch | ||
| from safetensors.torch import safe_open | ||
| from torch.distributed.tensor import DTensor | ||
| from tqdm.auto import tqdm | ||
|
|
||
| try: | ||
|
|
@@ -336,3 +337,23 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> str | None: | |
|
|
||
| # If there were no matches, return the untouched param name | ||
| return name | ||
|
|
||
|
|
||
| def compute_weights_checksum( | ||
| named_params: Iterable[tuple[str, torch.Tensor]], | ||
| ) -> str: | ||
| """Compute a SHA-256 checksum for a set of (name, tensor) pairs. | ||
|
|
||
| Used to verify the correctness of weight refitting. After a refit, | ||
| compare the checksum of the in-GPU model weights against the checksum | ||
| of the on-disk tensors or the tensors in the training engine. | ||
| """ | ||
zhaochenyang20 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| hasher = hashlib.sha256() | ||
| for name, tensor in sorted(named_params, key=lambda x: x[0]): | ||
| hasher.update(name.encode()) | ||
| t = tensor.detach() | ||
| # DTensor doesn't support .numpy(); extract the local tensor. | ||
| if isinstance(t, DTensor): | ||
| t = t._local_tensor | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since some of the DTensors may be sharded across the devices and local_tensor is only the tensor on current device, do we need all-gather or some hash value merging logics here?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good question, let me check. |
||
| hasher.update(t.cpu().contiguous().reshape(-1).view(torch.uint8).numpy().data) | ||
| return hasher.hexdigest() | ||
Uh oh!
There was an error while loading. Please reload this page.