
[diffusion] feat: support saving videos directly on the server to avoid the overhead of tensor transfer #18253

Merged

mickqian merged 3 commits into sgl-project:main from nono-Sang:optim_video_save on Feb 7, 2026

Conversation

@nono-Sang nono-Sang commented Feb 4, 2026

Motivation

The scheduler_client connects to the gpu_worker via ZMQ. The workflow is as follows:

  • sglang generate: diffusion_generate --(sync_scheduler_client.forward)--> gpu_worker

  • sglang serve: http_server --(async_scheduler_client.forward)--> gpu_worker

Old method: the scheduler_client and the gpu_worker exchanged output tensors over ZMQ, which introduced serialization and deserialization overhead.

New method: the gpu_worker post-processes and saves the output tensor as a video file directly on the server, then returns only the file path to the scheduler_client.
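
A minimal sketch of the difference, assuming a simplified OutputBatch (the real dataclass in schedule_batch.py carries more fields; the tensor shape and path below are illustrative):

# Minimal sketch of the two response shapes exchanged over ZMQ,
# assuming a simplified OutputBatch dataclass.
from dataclasses import dataclass

import torch


@dataclass
class OutputBatch:
    # Old path: decoded frames travel over ZMQ and must be (de)serialized.
    output: list[torch.Tensor] | None = None
    # New path (this PR): the worker saves files and returns paths only.
    output_file_paths: list[str] | None = None


# Old method: pickling a (frames, H, W, 3) tensor dominates the round trip.
old_response = OutputBatch(output=[torch.zeros(81, 720, 1280, 3)])

# New method: the ZMQ message shrinks to a handful of short strings.
new_response = OutputBatch(output_file_paths=["/outputs/sample_0.mp4"])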

Modifications

Main code changes:

  1. A common save function is implemented in multimodal_gen/runtime/entrypoints/utils.py. The saving logic in the following files now uses this common function (see the usage sketch after this list):
  • multimodal_gen/runtime/entrypoints/openai/utils.py

  • multimodal_gen/runtime/managers/gpu_worker.py

  • multimodal_gen/runtime/entrypoints/http_server.py

  • multimodal_gen/runtime/entrypoints/diffusion_generator.py

  2. The default value of return_file_paths_only is set to True.
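
For illustration, a call into the shared save helper might look as follows. This is a sketch: the import path follows the PR description, the keyword arguments mirror the post_process_sample call in the gpu_worker diff quoted later in this thread, and req/output_batch stand in for a real request/output pair.

# Sketch of one call site for the shared save helper; req and
# output_batch are assumed to be a Req/OutputBatch pair.
from sglang.multimodal_gen.runtime.entrypoints.utils import post_process_sample

# Build the destination path for output 0 of 1 (num_outputs, output_idx).
save_file_path = req.output_file_path(1, 0)
post_process_sample(
    output_batch.output[0],        # decoded sample (video frames, image, ...)
    fps=req.fps,
    save_output=True,              # write to disk on the server
    save_file_path=save_file_path,
    data_type=req.data_type,       # e.g. DataType.VIDEO
    audio_sample_rate=output_batch.audio_sample_rate,
)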


@github-actions github-actions bot added the diffusion SGLang Diffusion label Feb 4, 2026
@gemini-code-assist

Summary of Changes

Hello @nono-Sang, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request implements a significant optimization for multimodal generation by allowing video and audio outputs to be saved directly on the server. This change introduces a mechanism to return only the file paths to the client, thereby eliminating the need to transfer potentially large tensor data across the network. The modification enhances the system's efficiency and responsiveness, particularly for applications involving high-volume video generation.

Highlights

  • New Feature: Server-Side Video Saving: Introduced a new return_file_paths_only parameter and corresponding CLI argument. When enabled, generated video and audio content is saved directly on the server, and only the file paths are returned to the client.
  • Reduced Data Transfer Overhead: The primary goal of this feature is to avoid the overhead of transferring large video/audio tensors over the network, improving performance and efficiency for multimodal generation tasks.
  • GPU Worker Direct File Saving: The GPU worker is now responsible for performing the post_process_sample operation and saving the output to files directly, then nullifying the output tensors before sending the response.
  • Client-Side Adaptations: The diffusion generator, HTTP server, and OpenAI utility entry points have been updated to handle the new output_file_paths field, prioritizing it over raw tensor outputs when the feature is active (sketched below).
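
A hedged sketch of that priority check; response mirrors the OutputBatch fields named in the changelog below, and save_tensors_locally is a hypothetical stand-in for the pre-existing tensor post-processing path:

# Sketch of the client-side branch: prefer server-saved file paths and
# fall back to tensor post-processing only when no paths were returned.
def collect_outputs(response) -> list[str]:
    if response.output_file_paths:
        # Server-side saving was active; nothing left to decode or write.
        return list(response.output_file_paths)
    if response.output is None:
        raise RuntimeError("response contains neither tensors nor file paths")
    # Old path: post-process raw tensors on the client.
    return save_tensors_locally(response.output)  # hypothetical helper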


Changelog
  • python/sglang/multimodal_gen/configs/sample/sampling_params.py
    • Added return_file_paths_only: bool = False to the SamplingParams class.
    • Integrated --return-file-paths-only as a new command-line argument (see the sketch after this changelog).
  • python/sglang/multimodal_gen/runtime/entrypoints/diffusion_generator.py
    • Modified the generate function to conditionally process and return output_file_paths when req.return_file_paths_only is true.
    • Updated the check for empty output to also consider output_batch.output_file_paths.
  • python/sglang/multimodal_gen/runtime/entrypoints/http_server.py
    • Adjusted forward_to_scheduler to check for and utilize response.output_file_paths if available, bypassing tensor processing.
    • Updated the condition for raising a RuntimeError to include response.output_file_paths.
  • python/sglang/multimodal_gen/runtime/entrypoints/openai/utils.py
    • Updated process_generation_batch to handle result.output_file_paths directly, appending them to save_file_path_list.
    • Modified the empty output check to account for result.output_file_paths.
  • python/sglang/multimodal_gen/runtime/managers/gpu_worker.py
    • Implemented logic within execute_forward to save generated samples and audio to files directly on the GPU worker when req.return_file_paths_only is active.
    • Populated output_batch.output_file_paths with the saved file paths and set output_batch.output = None to prevent tensor transfer.
    • Added imports for numpy, DataType, and post_process_sample.
    • Refactored the error handling by removing the finally block.
  • python/sglang/multimodal_gen/runtime/pipelines_core/schedule_batch.py
    • Removed the default fps value from the Req class definition.
    • Added output_file_paths: list[str] | None = None to the OutputBatch class.
    • Included fps in the logging output of the Req object.
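
For orientation, a minimal sketch of how the new knob wires through SamplingParams and the CLI, assuming an argparse-style interface; the attribute and flag names come from the changelog above, while the class body, helper name, and help text are simplified:

import argparse
from dataclasses import dataclass


# Simplified sketch; the real SamplingParams carries many more fields.
@dataclass
class SamplingParams:
    return_file_paths_only: bool = False  # default as listed in the changelog


def add_cli_args(parser: argparse.ArgumentParser) -> None:
    # Hypothetical registration helper mirroring the new CLI flag.
    parser.add_argument(
        "--return-file-paths-only",
        action="store_true",
        help="Save outputs on the server and return only file paths.",
    )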

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces a valuable optimization by allowing videos to be saved directly on the server, which avoids the performance overhead of transferring large tensors. The changes are implemented across several files, including sampling parameters, the diffusion generator, and the GPU worker, and appear to correctly implement the new return_file_paths_only feature. My review includes two suggestions to address code duplication, which would improve the overall maintainability of the code.

Comment on lines +243 to +267
if req.save_output and req.return_file_paths_only:
    for output_idx, output_path in enumerate(
        output_batch.output_file_paths
    ):
        result_item: dict[str, Any] = {
            "samples": None,
            "frames": None,
            "audio": None,
            "prompts": req.prompt,
            "size": (req.height, req.width, req.num_frames),
            "generation_time": timer.duration,
            "peak_memory_mb": output_batch.peak_memory_mb,
            "timings": (
                output_batch.timings.to_dict()
                if output_batch.timings
                else {}
            ),
            "trajectory": output_batch.trajectory_latents,
            "trajectory_timesteps": output_batch.trajectory_timesteps,
            "trajectory_decoded": output_batch.trajectory_decoded,
            "prompt_index": output_idx,
            "output_file_path": output_path,
        }
        results.append(result_item)
    continue
Severity: medium

There is significant duplication in how result_item is constructed in this new block and in the existing code that handles the case where return_file_paths_only is false (located at lines 301-318 in the full file). Many fields are identical across both dictionary creations.

To make the code more maintainable and adhere to the DRY (Don't Repeat Yourself) principle, I recommend refactoring this. You could define a base dictionary containing all the common fields before this if statement, and then use dictionary unpacking to add the specific fields for each case.

For example:

# Create a base dictionary with common fields
base_result_item = {
    "prompts": req.prompt,
    "size": (req.height, req.width, req.num_frames),
    "generation_time": timer.duration,
    # ... other common fields
}

if req.save_output and req.return_file_paths_only:
    for output_idx, output_path in enumerate(output_batch.output_file_paths):
        result_item = {
            **base_result_item,
            "samples": None,
            "frames": None,
            "audio": None,
            "prompt_index": output_idx,
            "output_file_path": output_path,
        }
        results.append(result_item)
    continue

# Later in the function for the other case:
for output_idx, sample in enumerate(output_batch.output):
    # ...
    result_item = {
        **base_result_item,
        "samples": sample,
        "frames": frames,
        "audio": audio,
        "prompt_index": output_idx,
    }
    results.append(result_item)

Comment on lines +191 to +228
# Save output to file and return file path only if requested. Avoid the serialization
# and deserialization overhead between scheduler_client and GPU worker.
if req.save_output and req.return_file_paths_only:
    output_paths: list[str] = []
    # Copied from DiffGenerator.generate()
    audio_sample_rate = output_batch.audio_sample_rate
    for output_idx, sample in enumerate(output_batch.output):
        num_outputs = len(output_batch.output)
        audio = output_batch.audio
        if req.data_type == DataType.VIDEO:
            if isinstance(audio, torch.Tensor) and audio.ndim >= 2:
                audio = (
                    audio[output_idx]
                    if audio.shape[0] > output_idx
                    else None
                )
            elif isinstance(audio, np.ndarray) and audio.ndim >= 2:
                audio = (
                    audio[output_idx]
                    if audio.shape[0] > output_idx
                    else None
                )
            if audio is not None and not (
                isinstance(sample, (tuple, list)) and len(sample) == 2
            ):
                sample = (sample, audio)
        save_file_path = req.output_file_path(num_outputs, output_idx)
        post_process_sample(
            sample,
            fps=req.fps,
            save_output=True,
            save_file_path=save_file_path,
            data_type=req.data_type,
            audio_sample_rate=audio_sample_rate,
        )
        output_paths.append(save_file_path)
    output_batch.output_file_paths = output_paths
    output_batch.output = None
Severity: medium

The logic for preparing video samples by pairing them with audio (lines 199-216) is duplicated from other parts of the codebase, as acknowledged by the comment # Copied from DiffGenerator.generate(). This same logic can also be found in diffusion_generator.py and runtime/entrypoints/openai/utils.py.

To improve code maintainability and avoid future inconsistencies, this duplicated logic should be extracted into a single utility function. A good location for this function would be in python/sglang/multimodal_gen/runtime/entrypoints/utils.py, alongside post_process_sample.

For example, you could create a helper function:

def pair_video_sample_with_audio(sample, audio_batch, output_idx):
    """Pairs a video sample with its corresponding audio from a batch if available."""
    audio_for_sample = audio_batch
    if isinstance(audio_batch, (torch.Tensor, np.ndarray)) and audio_batch.ndim >= 2:
        audio_for_sample = audio_batch[output_idx] if audio_batch.shape[0] > output_idx else None

    if audio_for_sample is not None and not (isinstance(sample, (tuple, list)) and len(sample) == 2):
        return (sample, audio_for_sample)
    return sample

Using this helper would simplify the code here and in the other locations where this logic is repeated.

@mickqian mickqian commented Feb 7, 2026

/tag-and-rerun-ci

@github-actions github-actions bot added the run-ci label Feb 7, 2026
@mickqian mickqian merged commit 64950d8 into sgl-project:main Feb 7, 2026
103 of 107 checks passed