
[diffusion] feat: support parallel wan-vae decode #18179

Merged
mickqian merged 7 commits into sgl-project:main from nono-Sang:parallel_wanvae_encode
Feb 10, 2026

Conversation

@nono-Sang
Contributor

@nono-Sang nono-Sang commented Feb 3, 2026

credits @Songrui625

Close #16510

Motivation

Implement parallel wan-vae encoding based on url

Modifications

Accuracy Tests

Benchmarking and Profiling

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

@github-actions github-actions bot added the diffusion (SGLang Diffusion) label on Feb 3, 2026
@gemini-code-assist
Contributor

Summary of Changes

Hello @nono-Sang, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the Wan-VAE model by integrating spatial parallelism for both its encoder and decoder components. The primary goal is to improve performance and scalability by distributing computational load across multiple devices. This is achieved through the introduction of specialized distributed layers and utility functions that handle data partitioning, padding, and efficient communication of boundary data (halo exchange) between parallel processes. The changes are designed to be configurable, allowing users to enable or disable parallel operations as needed.

Highlights

  • Parallel Encoding Support: Introduced the capability for parallel encoding within the Wan-VAE model, allowing for distributed processing of input data across multiple devices.
  • Distributed Module Implementation: Added new distributed versions of core neural network layers, including convolutional layers (WanDistConv2d, WanDistCausalConv3d), resampling (WanDistResample), residual blocks (WanDistResidualBlock), attention blocks (WanDistAttentionBlock), and various up/downsampling blocks.
  • Spatial Parallelism Utilities: Implemented helper functions such as sp_should_padding, sp_pad, sp_chunk, and halo_exchange to manage data distribution, padding, and boundary data exchange for spatial parallelism (an illustrative halo-exchange sketch follows this list).
  • Configurability: Added use_parallel_encode and use_parallel_decode flags to the WanVAEConfig to enable or disable parallel processing for encoding and decoding paths, respectively.


Changelog
  • python/sglang/multimodal_gen/configs/models/vaes/wanvae.py
    • Added use_parallel_decode and use_parallel_encode boolean fields to WanVAEConfig.
  • python/sglang/multimodal_gen/runtime/models/vaes/wanvae.py
    • Imported math, torch.distributed, and parallel state utilities.
    • Implemented sp_should_padding, sp_pad, sp_chunk for data partitioning and padding (a hedged sketch of these helpers follows this changelog).
    • Introduced halo_exchange function for boundary data communication in distributed settings.
    • Defined WanDistConv2d and WanDistCausalConv3d as distributed convolutional layers.
    • Added WanDistZeroPad2d for distributed padding operations.
    • Refactored WanResample and WanResidualBlock forward logic into reusable helper functions (resample_forward, _residual_block_forward).
    • Created distributed counterparts for various model components: WanDistResample, WanDistResidualBlock, WanDistAttentionBlock, WanDistMidBlock, WanDistResidualDownBlock, WanDistResidualUpBlock, and WanDistUpBlock.
    • Modified WanEncoder3d and WanDecoder3d to conditionally instantiate and utilize these new distributed modules based on the use_parallel_encode and use_parallel_decode flags.
    • Added logic within WanEncoder3d.forward to handle input data preparation (padding, chunking) for parallel encoding and output gathering.
    • Added logic within WanDecoder3d.forward to handle input data preparation (chunking) for parallel decoding and output gathering.
    • Updated the _count_conv3d utility to recognize both standard and distributed causal convolution layers.
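
As context for the changelog entries above, padding-then-chunking along the height dimension can be sketched as follows. The helper names mirror the changelog, but the exact signatures and the choice to split on dim=-2 are assumptions rather than the PR's actual implementation.

import torch
import torch.nn.functional as F


def sp_pad(x: torch.Tensor, world_size: int) -> torch.Tensor:
    """Pad the height dim (-2) so it divides evenly across the SP ranks."""
    pad = (-x.shape[-2]) % world_size
    if pad == 0:
        return x
    # F.pad addresses the last two dims as (w_left, w_right, h_top, h_bottom)
    return F.pad(x, (0, 0, 0, pad))


def sp_chunk(x: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    """Return this rank's contiguous height slab of the (already padded) tensor."""
    return torch.chunk(x, world_size, dim=-2)[rank].contiguous()

On each rank the local input would then be obtained as local = sp_chunk(sp_pad(x, world_size), rank, world_size), with the original height remembered so any padding can be trimmed after gathering.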
Activity
  • The author, nono-Sang, has opened this pull request to implement parallel wan-vae encoding.
  • The motivation for the changes has been provided, referencing a related pull request.
  • Details regarding specific modifications, accuracy tests, and benchmarking results are currently not filled out in the PR description.
  • The checklist for code formatting, unit tests, documentation, and benchmarks remains unchecked.


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request adds support for parallel VAE encoding by introducing distributed versions of several modules, such as Conv2d, CausalConv3d, Resample, ResidualBlock, and AttentionBlock. The implementation of the distributed logic appears sound, utilizing halo exchanges and all-gather operations where appropriate. However, the current approach of using large if/else blocks in WanEncoder3d and WanDecoder3d to switch between parallel and non-parallel implementations has resulted in significant code duplication. This makes the code harder to maintain and read. I've provided suggestions to refactor this to improve code quality. There are also a couple of smaller issues noted for improvement.
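
For orientation, gathering the per-rank height slabs back into a full tensor after a parallel encode/decode pass typically uses an all-gather along the split dimension. The sketch below is illustrative only; the helper name and the trimming of sp_pad rows are assumptions, not this PR's code.

import torch
import torch.distributed as dist


def gather_height_slabs(local: torch.Tensor, world_size: int, expected_height=None, group=None) -> torch.Tensor:
    """All-gather the height slabs (dim=-2) from every rank and trim padded rows."""
    slabs = [torch.empty_like(local) for _ in range(world_size)]
    dist.all_gather(slabs, local.contiguous(), group=group)
    full = torch.cat(slabs, dim=-2)
    if expected_height is not None:
        full = full[..., :expected_height, :]
    return full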

Comment on lines +1175 to +1272
if use_parallel_encode and world_size > 1:
    # init block
    self.conv_in = WanDistCausalConv3d(in_channels, dims[0], 3, padding=1)

    # downsample blocks
    self.down_blocks = nn.ModuleList([])
    for i, (in_dim, out_dim) in enumerate(
        zip(dims[:-1], dims[1:], strict=True)
    ):
        # residual (+attention) blocks
        if is_residual:
            self.down_blocks.append(
                WanDistResidualDownBlock(
                    in_dim,
                    out_dim,
                    dropout,
                    num_res_blocks,
                    temperal_downsample=(
                        temperal_downsample[i]
                        if i != len(dim_mult) - 1
                        else False
                    ),
                    down_flag=i != len(dim_mult) - 1,
                )
            )
        else:
            for _ in range(num_res_blocks):
                self.down_blocks.append(
                    WanDistResidualBlock(in_dim, out_dim, dropout)
                )
                if scale in attn_scales:
                    self.down_blocks.append(WanDistAttentionBlock(out_dim))
                in_dim = out_dim

            # downsample block
            if i != len(dim_mult) - 1:
                mode = (
                    "downsample3d" if temperal_downsample[i] else "downsample2d"
                )
                self.down_blocks.append(WanDistResample(out_dim, mode=mode))
                scale /= 2.0

    # middle blocks
    self.mid_block = WanDistMidBlock(
        out_dim, dropout, non_linearity, num_layers=1
    )

    # output blocks
    self.norm_out = WanRMS_norm(out_dim, images=False)
    self.conv_out = WanDistCausalConv3d(out_dim, z_dim, 3, padding=1)
else:
    # init block
    self.conv_in = WanCausalConv3d(in_channels, dims[0], 3, padding=1)

    # downsample blocks
    self.down_blocks = nn.ModuleList([])
    for i, (in_dim, out_dim) in enumerate(
        zip(dims[:-1], dims[1:], strict=True)
    ):
        # residual (+attention) blocks
        if is_residual:
            self.down_blocks.append(
                WanResidualDownBlock(
                    in_dim,
                    out_dim,
                    dropout,
                    num_res_blocks,
                    temperal_downsample=(
                        temperal_downsample[i]
                        if i != len(dim_mult) - 1
                        else False
                    ),
                    down_flag=i != len(dim_mult) - 1,
                )
            )
        else:
            for _ in range(num_res_blocks):
                self.down_blocks.append(
                    WanResidualBlock(in_dim, out_dim, dropout)
                )
                if scale in attn_scales:
                    self.down_blocks.append(WanAttentionBlock(out_dim))
                in_dim = out_dim

            # downsample block
            if i != len(dim_mult) - 1:
                mode = (
                    "downsample3d" if temperal_downsample[i] else "downsample2d"
                )
                self.down_blocks.append(WanResample(out_dim, mode=mode))
                scale /= 2.0

    # middle blocks
    self.mid_block = WanMidBlock(out_dim, dropout, non_linearity, num_layers=1)

    # output blocks
    self.norm_out = WanRMS_norm(out_dim, images=False)
    self.conv_out = WanCausalConv3d(out_dim, z_dim, 3, padding=1)

Severity: high

This if/else block introduces a large amount of code duplication. The logic for building the encoder is nearly identical in both branches, with the only difference being the module classes used (Wan... vs WanDist...). This significantly impacts readability and maintainability, as any change to the model architecture must be applied in two places.

A better approach is to dynamically select the module classes based on use_parallel_encode and then use a single, non-duplicated block of code for model construction. For example:

if use_parallel_encode and world_size > 1:
    CausalConv3d = WanDistCausalConv3d
    ResidualDownBlock = WanDistResidualDownBlock
    ResidualBlock = WanDistResidualBlock
    AttentionBlock = WanDistAttentionBlock
    Resample = WanDistResample
    MidBlock = WanDistMidBlock
else:
    CausalConv3d = WanCausalConv3d
    ResidualDownBlock = WanResidualDownBlock
    ResidualBlock = WanResidualBlock
    AttentionBlock = WanAttentionBlock
    Resample = WanResample
    MidBlock = WanMidBlock

# init block
self.conv_in = CausalConv3d(in_channels, dims[0], 3, padding=1)

# downsample blocks
self.down_blocks = nn.ModuleList([])
# ... (rest of the logic using the dynamically chosen classes)

This refactoring would make the code much cleaner and easier to maintain.

Comment on lines +1690 to +1796
if use_parallel_decode and world_size > 1:
    # init block
    self.conv_in = WanDistCausalConv3d(z_dim, dims[0], 3, padding=1)

    # middle blocks
    self.mid_block = WanDistMidBlock(
        dims[0], dropout, non_linearity, num_layers=1
    )

    # upsample blocks
    self.upsample_count = 0
    self.up_blocks = nn.ModuleList([])
    for i, (in_dim, out_dim) in enumerate(
        zip(dims[:-1], dims[1:], strict=True)
    ):
        # residual (+attention) blocks
        if i > 0 and not is_residual:
            # wan vae 2.1
            in_dim = in_dim // 2

        # determine if we need upsampling
        up_flag = i != len(dim_mult) - 1
        # determine upsampling mode, if not upsampling, set to None
        upsample_mode = None
        if up_flag and temperal_upsample[i]:
            upsample_mode = "upsample3d"
        elif up_flag:
            upsample_mode = "upsample2d"

        # Create and add the upsampling block
        if is_residual:
            up_block = WanDistResidualUpBlock(
                in_dim=in_dim,
                out_dim=out_dim,
                num_res_blocks=num_res_blocks,
                dropout=dropout,
                temperal_upsample=temperal_upsample[i] if up_flag else False,
                up_flag=up_flag,
                non_linearity=non_linearity,
            )
        else:
            up_block = WanDistUpBlock(
                in_dim=in_dim,
                out_dim=out_dim,
                num_res_blocks=num_res_blocks,
                dropout=dropout,
                upsample_mode=upsample_mode,
                non_linearity=non_linearity,
            )
        self.up_blocks.append(up_block)
        if up_flag:
            self.upsample_count += 1

    # output blocks
    self.norm_out = WanRMS_norm(out_dim, images=False)
    self.conv_out = WanDistCausalConv3d(out_dim, out_channels, 3, padding=1)
else:
    # init block
    self.conv_in = WanCausalConv3d(z_dim, dims[0], 3, padding=1)

    # middle blocks
    self.mid_block = WanMidBlock(dims[0], dropout, non_linearity, num_layers=1)

    # upsample blocks
    self.up_blocks = nn.ModuleList([])
    for i, (in_dim, out_dim) in enumerate(
        zip(dims[:-1], dims[1:], strict=True)
    ):
        # residual (+attention) blocks
        if i > 0 and not is_residual:
            # wan vae 2.1
            in_dim = in_dim // 2

        # determine if we need upsampling
        up_flag = i != len(dim_mult) - 1
        # determine upsampling mode, if not upsampling, set to None
        upsample_mode = None
        if up_flag and temperal_upsample[i]:
            upsample_mode = "upsample3d"
        elif up_flag:
            upsample_mode = "upsample2d"

        # Create and add the upsampling block
        if is_residual:
            up_block = WanResidualUpBlock(
                in_dim=in_dim,
                out_dim=out_dim,
                num_res_blocks=num_res_blocks,
                dropout=dropout,
                temperal_upsample=temperal_upsample[i] if up_flag else False,
                up_flag=up_flag,
                non_linearity=non_linearity,
            )
        else:
            up_block = WanUpBlock(
                in_dim=in_dim,
                out_dim=out_dim,
                num_res_blocks=num_res_blocks,
                dropout=dropout,
                upsample_mode=upsample_mode,
                non_linearity=non_linearity,
            )
        self.up_blocks.append(up_block)

    # output blocks
    self.norm_out = WanRMS_norm(out_dim, images=False)
    self.conv_out = WanCausalConv3d(out_dim, out_channels, 3, padding=1)

Severity: high

Similar to WanEncoder3d, this if/else block in WanDecoder3d.__init__ duplicates a significant amount of code for model construction. This harms maintainability and readability.

Please refactor this by dynamically selecting the module classes based on use_parallel_decode to eliminate the duplicated model definition logic, as suggested for WanEncoder3d.

Comment on lines +425 to +435
padding = list(self._padding)
if cache_x is not None and self._padding[4] > 0:
    cache_x = cache_x.to(x.device)
    x = torch.cat([cache_x, x], dim=2)
    padding[4] -= cache_x.shape[2]

x = F.pad(x, padding)

x = (
    x.to(self.weight.dtype) if current_platform.is_mps() else x
)  # casting needed for mps since amp isn't supported

Severity: medium

This block of code for handling cache_x and MPS-specific dtype casting is duplicated from WanCausalConv3d.forward. To improve maintainability and reduce redundancy, consider extracting this shared logic into a common helper function. This function could be called by both WanCausalConv3d and WanDistCausalConv3d to prepare the input tensor x before the main convolution operation.
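
A shared helper along these lines could serve both classes. This is a sketch only: the helper name (_prepare_causal_conv_input) is hypothetical, and it assumes the module-level imports (F, current_platform) already present in wanvae.py.

def _prepare_causal_conv_input(self, x, cache_x):
    """Shared pre-convolution input prep for WanCausalConv3d and WanDistCausalConv3d.

    Prepends the temporal cache, shrinks the causal padding accordingly, and
    applies the MPS dtype cast (AMP is not supported on MPS).
    """
    padding = list(self._padding)
    if cache_x is not None and self._padding[4] > 0:
        cache_x = cache_x.to(x.device)
        x = torch.cat([cache_x, x], dim=2)
        padding[4] -= cache_x.shape[2]

    x = F.pad(x, padding)
    if current_platform.is_mps():
        x = x.to(self.weight.dtype)
    return x

Both forward methods could then call x = self._prepare_causal_conv_input(x, cache_x) before the convolution itself.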

@mickqian
Collaborator

mickqian commented Feb 5, 2026

@Songrui625 could you take a review?

parser.add_argument(
    f"--{prefix}.use-parallel-encode",
    action=StoreBoolean,
    dest=f"{prefix.replace('-', '_')}.use_parallel_encode",
Collaborator

could we make it true as default?

Contributor Author

done
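
Enabling the flag by default could look roughly like the following; this is a hedged sketch of one way to do it, and the exact placement of the default (WanVAEConfig field versus the parser argument) is an assumption rather than what was committed.

# In WanVAEConfig: opt-out instead of opt-in
use_parallel_encode: bool = True
use_parallel_decode: bool = True

# Or directly on the CLI argument
parser.add_argument(
    f"--{prefix}.use-parallel-encode",
    action=StoreBoolean,
    default=True,
    dest=f"{prefix.replace('-', '_')}.use_parallel_encode",
)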

@mickqian mickqian left a comment

Great job! I added several comments

x = self.resnets[0](x)

# Process through attention and residual blocks
for idx, (attn, resnet) in enumerate(
Collaborator

nit: idx is redundant



def residual_down_block_forward(self, x):
    x_copy = x.clone()
Collaborator

.clone() is redundant


def residual_up_block_forward(self, x):
    if self.avg_shortcut is not None:
        x_copy = x.clone()
Collaborator

here too

def forward(self, x):
    world_size = 1
    if dist.is_initialized():
        world_size = get_sp_world_size()
Collaborator

could we move it out of hot path?
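
One way to take this out of the hot path is to resolve the parallel topology once at construction time and reuse it in forward. The sketch below is illustrative and uses plain torch.distributed rank/world-size queries in place of the repo's get_sp_world_size/get_sp_parallel_rank helpers; it also assumes the process group is initialized before the model is built.

import torch
import torch.distributed as dist
from torch import nn


class SPModuleMixin(nn.Module):
    """Resolve the sequence-parallel topology once, at construction time."""

    def __init__(self):
        super().__init__()
        # Cached here instead of being queried on every forward() call.
        self.sp_world_size = dist.get_world_size() if dist.is_initialized() else 1
        self.sp_rank = dist.get_rank() if dist.is_initialized() else 0


class DistBlock(SPModuleMixin):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.sp_world_size > 1:
            # halo exchange / gather would go here, using the cached values
            pass
        return x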

expected_local_height = None
expected_height = None
if self.use_parallel_encode and world_size > 1:
    rank = get_sp_parallel_rank()
Collaborator

ditto

top_row = x[..., :height_halo_size, :].contiguous()
bottom_row = x[..., -height_halo_size:, :].contiguous()

recv_top_buf = torch.empty_like(top_row)
Collaborator

could this be cached, since the shape of the top_row is static?
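
Since the halo rows keep a fixed shape, dtype, and device across iterations, the receive buffers could indeed be allocated once and reused. A sketch of that idea follows; the cache class and its use are hypothetical, not part of this PR.

import torch


class HaloBufferCache:
    """Reuse irecv staging buffers when the halo shape/dtype/device never change."""

    def __init__(self):
        self._bufs = {}

    def get(self, like: torch.Tensor) -> torch.Tensor:
        key = (tuple(like.shape), like.dtype, like.device)
        if key not in self._bufs:
            self._bufs[key] = torch.empty_like(like)  # allocated once per unique shape
        return self._bufs[key]


# Inside halo_exchange (sketch):
#   recv_top_buf = cache.get(top_row)        # instead of torch.empty_like(top_row) per call
#   recv_bottom_buf = cache.get(bottom_row)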

group = sp_group.device_group
group_ranks = sp_group.ranks

top_row = x[..., :height_halo_size, :].contiguous()
Collaborator

.contiguous() might be redundant, since isend and irecv support non-contiguous tensors for most backends

Contributor Author

@mickqian contiguous() isn’t strictly required since most backends can handle non‑contiguous tensors in isend/irecv, but they typically make an internal copy anyway. Keeping it makes the behavior explicit and avoids backend‑specific corner cases; if x is already contiguous it’s a no‑op. So it’s more about clarity/stability than necessity.

@mickqian mickqian changed the title from "[diffusion] feat: support parallel wan-vae encode" to "[diffusion] feat: support parallel wan-vae decode" on Feb 9, 2026
@mickqian
Collaborator

/tag-and-rerun-ci

@mickqian mickqian merged commit 47978ee into sgl-project:main Feb 10, 2026
132 of 139 checks passed
Johnsonms pushed a commit to Johnsonms/sglang that referenced this pull request Feb 14, 2026

Labels

diffusion (SGLang Diffusion), run-ci
