[diffusion] feat: support parallel wan-vae decode #18179
mickqian merged 7 commits into sgl-project:main from
Conversation
Summary of Changes
Hello @nono-Sang, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request enhances the Wan-VAE model by integrating spatial parallelism for both its encoder and decoder components. The primary goal is to improve performance and scalability by distributing the computational load across multiple devices. This is achieved through specialized distributed layers and utility functions that handle data partitioning, padding, and efficient communication of boundary data (halo exchange) between parallel processes. The changes are configurable, allowing users to enable or disable the parallel path as needed.
Activity
Code Review
This pull request adds support for parallel VAE encoding by introducing distributed versions of several modules, such as Conv2d, CausalConv3d, Resample, ResidualBlock, and AttentionBlock. The implementation of the distributed logic appears sound, utilizing halo exchanges and all-gather operations where appropriate. However, the current approach of using large if/else blocks in WanEncoder3d and WanDecoder3d to switch between parallel and non-parallel implementations has resulted in significant code duplication. This makes the code harder to maintain and read. I've provided suggestions to refactor this to improve code quality. There are also a couple of smaller issues noted for improvement.
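For readers unfamiliar with the pattern, the halo exchange the summary refers to can be simulated in a few lines. This is a single-process sketch with plain Python lists standing in for tensors and ranks; sizes and names are illustrative, not taken from the PR.

```python
# Each "rank" holds a horizontal slice of the image; before a convolution it
# needs `halo` boundary rows from each neighbor so the stencil sees valid data
# at the shard edges. In the real PR this is done with point-to-point sends.
def halo_exchange(shards, halo):
    padded = []
    for rank, local in enumerate(shards):
        top = shards[rank - 1][-halo:] if rank > 0 else []
        bottom = shards[rank + 1][:halo] if rank < len(shards) - 1 else []
        padded.append(top + local + bottom)
    return padded

rows = list(range(8))                              # 8 image rows
shards = [rows[i:i + 2] for i in range(0, 8, 2)]   # 4 "ranks", 2 rows each
padded = halo_exchange(shards, halo=1)
# interior ranks now see one extra row from each neighbor
```

Edge ranks only receive from the side that exists, which matches the asymmetric padding a spatially sharded convolution needs.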
        if use_parallel_encode and world_size > 1:
            # init block
            self.conv_in = WanDistCausalConv3d(in_channels, dims[0], 3, padding=1)

            # downsample blocks
            self.down_blocks = nn.ModuleList([])
            for i, (in_dim, out_dim) in enumerate(
                zip(dims[:-1], dims[1:], strict=True)
            ):
                # residual (+attention) blocks
                if is_residual:
                    self.down_blocks.append(
                        WanDistResidualDownBlock(
                            in_dim,
                            out_dim,
                            dropout,
                            num_res_blocks,
                            temperal_downsample=(
                                temperal_downsample[i]
                                if i != len(dim_mult) - 1
                                else False
                            ),
                            down_flag=i != len(dim_mult) - 1,
                        )
                    )
                else:
                    for _ in range(num_res_blocks):
                        self.down_blocks.append(
                            WanDistResidualBlock(in_dim, out_dim, dropout)
                        )
                        if scale in attn_scales:
                            self.down_blocks.append(WanDistAttentionBlock(out_dim))
                        in_dim = out_dim

                # downsample block
                if i != len(dim_mult) - 1:
                    mode = (
                        "downsample3d" if temperal_downsample[i] else "downsample2d"
                    )
                    self.down_blocks.append(WanDistResample(out_dim, mode=mode))
                    scale /= 2.0

            # middle blocks
            self.mid_block = WanDistMidBlock(
                out_dim, dropout, non_linearity, num_layers=1
            )

            # output blocks
            self.norm_out = WanRMS_norm(out_dim, images=False)
            self.conv_out = WanDistCausalConv3d(out_dim, z_dim, 3, padding=1)
        else:
            # init block
            self.conv_in = WanCausalConv3d(in_channels, dims[0], 3, padding=1)

            # downsample blocks
            self.down_blocks = nn.ModuleList([])
            for i, (in_dim, out_dim) in enumerate(
                zip(dims[:-1], dims[1:], strict=True)
            ):
                # residual (+attention) blocks
                if is_residual:
                    self.down_blocks.append(
                        WanResidualDownBlock(
                            in_dim,
                            out_dim,
                            dropout,
                            num_res_blocks,
                            temperal_downsample=(
                                temperal_downsample[i]
                                if i != len(dim_mult) - 1
                                else False
                            ),
                            down_flag=i != len(dim_mult) - 1,
                        )
                    )
                else:
                    for _ in range(num_res_blocks):
                        self.down_blocks.append(
                            WanResidualBlock(in_dim, out_dim, dropout)
                        )
                        if scale in attn_scales:
                            self.down_blocks.append(WanAttentionBlock(out_dim))
                        in_dim = out_dim

                # downsample block
                if i != len(dim_mult) - 1:
                    mode = (
                        "downsample3d" if temperal_downsample[i] else "downsample2d"
                    )
                    self.down_blocks.append(WanResample(out_dim, mode=mode))
                    scale /= 2.0

            # middle blocks
            self.mid_block = WanMidBlock(out_dim, dropout, non_linearity, num_layers=1)

            # output blocks
            self.norm_out = WanRMS_norm(out_dim, images=False)
            self.conv_out = WanCausalConv3d(out_dim, z_dim, 3, padding=1)
This if/else block introduces a large amount of code duplication. The logic for building the encoder is nearly identical in both branches, with the only difference being the module classes used (Wan... vs WanDist...). This significantly impacts readability and maintainability, as any change to the model architecture must be applied in two places.
A better approach is to dynamically select the module classes based on use_parallel_encode and then use a single, non-duplicated block of code for model construction. For example:
if use_parallel_encode and world_size > 1:
CausalConv3d = WanDistCausalConv3d
ResidualDownBlock = WanDistResidualDownBlock
ResidualBlock = WanDistResidualBlock
AttentionBlock = WanDistAttentionBlock
Resample = WanDistResample
MidBlock = WanDistMidBlock
else:
CausalConv3d = WanCausalConv3d
ResidualDownBlock = WanResidualDownBlock
ResidualBlock = WanResidualBlock
AttentionBlock = WanAttentionBlock
Resample = WanResample
MidBlock = WanMidBlock
# init block
self.conv_in = CausalConv3d(in_channels, dims[0], 3, padding=1)
# downsample blocks
self.down_blocks = nn.ModuleList([])
# ... (rest of the logic using the dynamically chosen classes)

This refactoring would make the code much cleaner and easier to maintain.
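A runnable miniature of that pattern, with stub classes standing in for the real Wan modules (the stubs and `build_encoder_convs` are illustrative, not part of the PR):

```python
# Stand-ins for WanCausalConv3d / WanDistCausalConv3d; the only point is
# that the class is chosen once, so construction runs through one code path.
class WanCausalConv3dStub:
    dist = False

class WanDistCausalConv3dStub:
    dist = True

def build_encoder_convs(use_parallel_encode: bool, world_size: int):
    if use_parallel_encode and world_size > 1:
        CausalConv3d = WanDistCausalConv3dStub
    else:
        CausalConv3d = WanCausalConv3dStub
    # single, non-duplicated construction path
    conv_in = CausalConv3d()
    conv_out = CausalConv3d()
    return conv_in, conv_out
```

With a single construction path, an architecture change only needs to be made once, and the parallel/serial decision is isolated to the class-selection block.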
        if use_parallel_decode and world_size > 1:
            # init block
            self.conv_in = WanDistCausalConv3d(z_dim, dims[0], 3, padding=1)

            # middle blocks
            self.mid_block = WanDistMidBlock(
                dims[0], dropout, non_linearity, num_layers=1
            )

            # upsample blocks
            self.upsample_count = 0
            self.up_blocks = nn.ModuleList([])
            for i, (in_dim, out_dim) in enumerate(
                zip(dims[:-1], dims[1:], strict=True)
            ):
                # residual (+attention) blocks
                if i > 0 and not is_residual:
                    # wan vae 2.1
                    in_dim = in_dim // 2

                # determine if we need upsampling
                up_flag = i != len(dim_mult) - 1
                # determine upsampling mode, if not upsampling, set to None
                upsample_mode = None
                if up_flag and temperal_upsample[i]:
                    upsample_mode = "upsample3d"
                elif up_flag:
                    upsample_mode = "upsample2d"

                # Create and add the upsampling block
                if is_residual:
                    up_block = WanDistResidualUpBlock(
                        in_dim=in_dim,
                        out_dim=out_dim,
                        num_res_blocks=num_res_blocks,
                        dropout=dropout,
                        temperal_upsample=temperal_upsample[i] if up_flag else False,
                        up_flag=up_flag,
                        non_linearity=non_linearity,
                    )
                else:
                    up_block = WanDistUpBlock(
                        in_dim=in_dim,
                        out_dim=out_dim,
                        num_res_blocks=num_res_blocks,
                        dropout=dropout,
                        upsample_mode=upsample_mode,
                        non_linearity=non_linearity,
                    )
                self.up_blocks.append(up_block)
                if up_flag:
                    self.upsample_count += 1

            # output blocks
            self.norm_out = WanRMS_norm(out_dim, images=False)
            self.conv_out = WanDistCausalConv3d(out_dim, out_channels, 3, padding=1)
        else:
            # init block
            self.conv_in = WanCausalConv3d(z_dim, dims[0], 3, padding=1)

            # middle blocks
            self.mid_block = WanMidBlock(dims[0], dropout, non_linearity, num_layers=1)

            # upsample blocks
            self.up_blocks = nn.ModuleList([])
            for i, (in_dim, out_dim) in enumerate(
                zip(dims[:-1], dims[1:], strict=True)
            ):
                # residual (+attention) blocks
                if i > 0 and not is_residual:
                    # wan vae 2.1
                    in_dim = in_dim // 2

                # determine if we need upsampling
                up_flag = i != len(dim_mult) - 1
                # determine upsampling mode, if not upsampling, set to None
                upsample_mode = None
                if up_flag and temperal_upsample[i]:
                    upsample_mode = "upsample3d"
                elif up_flag:
                    upsample_mode = "upsample2d"

                # Create and add the upsampling block
                if is_residual:
                    up_block = WanResidualUpBlock(
                        in_dim=in_dim,
                        out_dim=out_dim,
                        num_res_blocks=num_res_blocks,
                        dropout=dropout,
                        temperal_upsample=temperal_upsample[i] if up_flag else False,
                        up_flag=up_flag,
                        non_linearity=non_linearity,
                    )
                else:
                    up_block = WanUpBlock(
                        in_dim=in_dim,
                        out_dim=out_dim,
                        num_res_blocks=num_res_blocks,
                        dropout=dropout,
                        upsample_mode=upsample_mode,
                        non_linearity=non_linearity,
                    )
                self.up_blocks.append(up_block)

            # output blocks
            self.norm_out = WanRMS_norm(out_dim, images=False)
            self.conv_out = WanCausalConv3d(out_dim, out_channels, 3, padding=1)
Similar to WanEncoder3d, this if/else block in WanDecoder3d.__init__ duplicates a significant amount of code for model construction. This harms maintainability and readability.
Please refactor this by dynamically selecting the module classes based on use_parallel_decode to eliminate the duplicated model definition logic, as suggested for WanEncoder3d.
        padding = list(self._padding)
        if cache_x is not None and self._padding[4] > 0:
            cache_x = cache_x.to(x.device)
            x = torch.cat([cache_x, x], dim=2)
            padding[4] -= cache_x.shape[2]

        x = F.pad(x, padding)

        x = (
            x.to(self.weight.dtype) if current_platform.is_mps() else x
        )  # casting needed for mps since amp isn't supported
This block of code for handling cache_x and MPS-specific dtype casting is duplicated from WanCausalConv3d.forward. To improve maintainability and reduce redundancy, consider extracting this shared logic into a common helper function. This function could be called by both WanCausalConv3d and WanDistCausalConv3d to prepare the input tensor x before the main convolution operation.
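One possible shape for that shared helper, sketched with plain Python lists standing in for tensors (scalars replace real frames; `prepare_causal_input` is a hypothetical name, and the real code would use `torch.cat` along dim=2 plus `F.pad`):

```python
def prepare_causal_input(frames, front_pad, cache_frames=None, zero_frame=0):
    """Prepend cached frames along the temporal axis and shrink the causal
    front padding by the same amount -- the logic currently duplicated
    between WanCausalConv3d.forward and WanDistCausalConv3d.forward."""
    pad = front_pad
    if cache_frames:
        # cached frames replace part of the zero padding at the front
        frames = cache_frames + frames
        pad -= len(cache_frames)
    return [zero_frame] * max(pad, 0) + frames
```

Both conv classes could call this before their convolution, so a fix to the cache handling lands in one place. The MPS dtype cast could live in the same helper by also passing the target dtype.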
@Songrui625 could you take a review?
        parser.add_argument(
            f"--{prefix}.use-parallel-encode",
            action=StoreBoolean,
            dest=f"{prefix.replace('-', '_')}.use_parallel_encode",
could we make it true as default?
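Defaulting the flag on could look like the following sketch, using the stdlib `argparse.BooleanOptionalAction` in place of the project's `StoreBoolean` action (the `vae` prefix and dest name here are illustrative):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--vae.use-parallel-encode",
    action=argparse.BooleanOptionalAction,
    dest="vae.use_parallel_encode",
    default=True,  # enabled unless --no-vae.use-parallel-encode is passed
)

args = parser.parse_args([])                             # default: enabled
off = parser.parse_args(["--no-vae.use-parallel-encode"])  # explicit opt-out
```

`BooleanOptionalAction` auto-generates the `--no-...` negation, so an on-by-default flag still has an explicit off switch.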
        x = self.resnets[0](x)

        # Process through attention and residual blocks
        for idx, (attn, resnet) in enumerate(

    def residual_down_block_forward(self, x):
        x_copy = x.clone()

    def residual_up_block_forward(self, x):
        if self.avg_shortcut is not None:
            x_copy = x.clone()
    def forward(self, x):
        world_size = 1
        if dist.is_initialized():
            world_size = get_sp_world_size()
could we move it out of hot path?
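One way to hoist that lookup out of the per-call path is to resolve it once at construction time; a sketch with a stand-in for `get_sp_world_size` (names and the stub value are illustrative):

```python
def get_sp_world_size():
    return 4  # stand-in for the real distributed query

class Block:
    def __init__(self, dist_initialized: bool):
        # resolved once here, instead of on every forward() call
        self.world_size = get_sp_world_size() if dist_initialized else 1

    def forward(self, x):
        if self.world_size > 1:
            pass  # parallel path (halo exchange, etc.)
        return x
```

This assumes the process group does not change after the module is built, which holds for the usual init-once setup.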
        expected_local_height = None
        expected_height = None
        if self.use_parallel_encode and world_size > 1:
            rank = get_sp_parallel_rank()
        top_row = x[..., :height_halo_size, :].contiguous()
        bottom_row = x[..., -height_halo_size:, :].contiguous()

        recv_top_buf = torch.empty_like(top_row)
could this be cached, since the shape of the top_row is static?
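A sketch of such a cache, keyed by shape and element size so the receive buffer is allocated once and reused across forward passes (bytearrays stand in for tensor buffers; names are illustrative):

```python
class HaloBufferCache:
    """Reuse one buffer per (shape, itemsize) instead of calling the
    equivalent of torch.empty_like on every forward pass."""

    def __init__(self):
        self._buffers = {}

    def get(self, shape, itemsize):
        key = (tuple(shape), itemsize)
        if key not in self._buffers:
            n = itemsize
            for dim in shape:
                n *= dim
            self._buffers[key] = bytearray(n)  # allocated once
        return self._buffers[key]

cache = HaloBufferCache()
buf1 = cache.get((1, 4, 2, 64), itemsize=4)
buf2 = cache.get((1, 4, 2, 64), itemsize=4)  # same allocation reused
```

Since the halo shape only changes when the input resolution changes, the cache stays small; keying by shape also keeps it correct if a different resolution does come through.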
        group = sp_group.device_group
        group_ranks = sp_group.ranks

        top_row = x[..., :height_halo_size, :].contiguous()
.contiguous() might be redundant, since isend and irecv support non-contiguous tensors for most backends
@mickqian contiguous() isn’t strictly required since most backends can handle non‑contiguous tensors in isend/irecv, but they typically make an internal copy anyway. Keeping it makes the behavior explicit and avoids backend‑specific corner cases; if x is already contiguous it’s a no‑op. So it’s more about clarity/stability than necessity.
/tag-and-rerun-ci
credits @Songrui625
Close #16510
Motivation
Implement parallel wan-vae encoding based on url
Modifications
Accuracy Tests
Benchmarking and Profiling
Checklist
Review Process
/tag-run-ci-label,/rerun-failed-ci,/tag-and-rerun-ci