[diffusion] feat: allow T5's TP Group to reuse the transformer's SP Group#17818

Merged
mickqian merged 10 commits into sgl-project:main from nono-Sang:t5_parallel_folding
Feb 5, 2026

Conversation

@nono-Sang
Contributor

Motivation

Currently, the TP group of the text_encoder matches that of the transformer. This PR introduces an optional feature that allows the text_encoder to adopt the transformer's SP, Ulysses, or Ring group as its own TP group (referred to as "parallel folding").

Use case: when running inference on eight GPUs with ulysses=8 and tp=1, each GPU stores a complete copy of the text_encoder. With parallel folding enabled, the text_encoder can run with tp=8 instead, sharding its weights across the Ulysses group.
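To make the memory argument concrete, here is a small illustrative sketch (not sglang code; the 4.7B parameter count is a hypothetical T5-XXL-like figure) of the per-GPU encoder footprint with and without folding:

```python
# Illustrative sketch: memory effect of parallel folding on an 8-GPU run
# with ulysses=8, tp=1. Parameter count is a hypothetical assumption.

def encoder_params_per_gpu(total_params: int, tp_size: int) -> int:
    """Parameters each GPU holds when the encoder is sharded over tp_size ranks."""
    return total_params // tp_size

total = 4_700_000_000  # hypothetical T5-XXL-like encoder size

without_folding = encoder_params_per_gpu(total, 1)  # full replica on every GPU
with_folding = encoder_params_per_gpu(total, 8)     # Ulysses group reused as TP

print(without_folding // with_folding)  # folding cuts per-GPU encoder weights 8x
```

The same group of 8 ranks does the work either way; folding just changes whether the encoder is replicated across it or sharded over it.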

Modifications

Accuracy Tests

Benchmarking and Profiling

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

@github-actions github-actions bot added the diffusion SGLang Diffusion label Jan 27, 2026
@gemini-code-assist
Contributor

Summary of Changes

Hello @nono-Sang, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a significant optimization for T5 text encoders in distributed environments. It enables the text encoder's tensor parallelism to leverage existing sequence, Ulysses, or ring parallel groups from the main transformer model. This "parallel folding" can lead to more efficient resource utilization, particularly memory, by preventing redundant model copies on GPUs when other parallelism strategies are already in use. The changes involve modifying T5 configuration, refactoring distributed communication primitives, and updating parallel linear and embedding layers to support dynamic process group assignment.

Highlights

  • T5 Parallel Folding Feature: Introduced an optional parallel_folding mechanism for T5 text encoders, allowing them to reuse existing transformer parallel groups (Sequence Parallel, Ulysses, or Ring) as their own Tensor Parallel (TP) group for improved resource utilization.
  • Flexible Distributed Operations: Refactored core distributed communication operations (tensor_model_parallel_all_reduce, tensor_model_parallel_all_gather) and parallel linear/embedding layers to accept an explicit tp_group argument, enhancing flexibility in distributed setups.
  • Unified Group Utilities: Added get_group_size and get_group_rank utility functions to abstract away the underlying ProcessGroup or GroupCoordinator types, simplifying group-related operations across the codebase.




@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a new 'parallel folding' feature for T5 text encoders, allowing them to utilize existing sequence parallelism groups (SP, Ulysses, or Ring Group) as their tensor parallelism (TP) group. This is achieved by adding parallel_folding and parallel_folding_mode configurations to T5Config. The changes involve refactoring distributed communication operations to accept an optional tp_group argument, which defaults to the global TP group if not specified. This refactoring is consistently applied across various linear and embedding layers to enhance flexibility in group assignment. Additionally, new utility functions get_group_size and get_group_rank are introduced for abstracting process group properties. A notable change in wanvideo.py updates proj_out to use ColumnParallelLinear, enabling its parallelization.

self.n_heads = config.num_heads // tp_world_size
self.tp_group = _get_folding_tp_group(config)
self.tp_world_size = get_group_size(self.tp_group)
assert config.num_heads % self.tp_world_size == 0

critical

The assertion assert config.num_heads % self.tp_world_size == 0 is crucial for ensuring the correct distribution of attention heads in a tensor parallel setup. If config.num_heads is not divisible by self.tp_world_size, it would lead to an uneven distribution or errors in parallel processing.
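A minimal sketch of the failure mode the assertion guards against (hypothetical values, not the sglang layer itself): if the head count does not divide evenly, some rank would get a different number of heads than the others.

```python
# Sketch: attention heads must split evenly across the TP group.
# num_heads=64 below is an assumption (T5-XXL-like), used only for illustration.

def heads_per_rank(num_heads: int, tp_world_size: int) -> int:
    if num_heads % tp_world_size != 0:
        raise ValueError(
            f"num_heads={num_heads} is not divisible by tp_world_size={tp_world_size}"
        )
    return num_heads // tp_world_size

print(heads_per_rank(64, 8))  # 8 heads per rank on an 8-way folded group
```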

Comment on lines +750 to +754
self.proj_out = ColumnParallelLinear(
inner_dim,
config.out_channels * math.prod(config.patch_size),
bias=True,
gather_output=True,

high

Changing self.proj_out from nn.Linear to ColumnParallelLinear implies that this projection layer will now be parallelized. This is a significant functional change that should be thoroughly tested to ensure correctness and to verify any performance implications.
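As a sanity check of the semantics, here is a toy pure-Python sketch (hypothetical sizes, not the sglang layer) of what a column-parallel projection with gather_output=True computes: each rank owns a shard of the output features, and gathering the shards reproduces the full projection.

```python
# Toy sketch of ColumnParallelLinear semantics with gather_output=True.
# Sizes and values are hypothetical; real code shards torch weights.

def matvec(rows, x):
    """Multiply a row-major matrix by a vector."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in rows]

# Full weight: 4 output features x 2 input features (one row per output feature).
W = [[1, 0], [0, 1], [2, 3], [4, 5]]
x = [10, 1]

full = matvec(W, x)

# Two "ranks" each own half of the output features.
shard0, shard1 = W[:2], W[2:]
gathered = matvec(shard0, x) + matvec(shard1, x)  # all-gather along output dim

print(full == gathered)  # True: gather_output reassembles the full projection
```

So the change should be numerically equivalent to the old nn.Linear, which is exactly what the reviewer asks to verify with tests.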

Comment on lines +16 to +17
def tensor_model_parallel_all_reduce(
input_: torch.Tensor, tp_group: dist.ProcessGroup = None

medium

The addition of tp_group: dist.ProcessGroup = None as an optional argument and defaulting to get_tp_group() significantly enhances the flexibility of these communication operations. This is crucial for supporting the new parallel folding feature, allowing specific process groups to be used for tensor parallelism.
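The defaulting pattern can be sketched with mocked groups (the strings below stand in for real process groups; the real function calls dist.all_reduce with group=...):

```python
# Minimal sketch of the tp_group defaulting pattern, with mocked groups.
# "_GLOBAL_TP_GROUP" and the returned tuple are illustrative assumptions.

_GLOBAL_TP_GROUP = "tp"  # stand-in for the real global TP GroupCoordinator

def get_tp_group():
    return _GLOBAL_TP_GROUP

def tensor_model_parallel_all_reduce(x, tp_group=None):
    group = tp_group if tp_group is not None else get_tp_group()
    # real code would call dist.all_reduce(x, group=group);
    # here we just report which group was selected
    return x, group

print(tensor_model_parallel_all_reduce(1.0)[1])             # global "tp" group
print(tensor_model_parallel_all_reduce(1.0, "ulysses")[1])  # the folded group
```

Callers that never pass tp_group keep the old behavior, so the refactor is backward compatible.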

Comment on lines +325 to +330
tp_group: dist.ProcessGroup = None,
):
# Divide the weight matrix along the last dimension.
self.tp_group = tp_group or get_tp_group()
self.tp_size = get_group_size(self.tp_group)
self.tp_rank = get_group_rank(self.tp_group)

medium

Accepting tp_group in ColumnParallelLinear.__init__ and deriving tp_size and tp_rank from it makes the layer more modular and independent of global state. This is a good architectural improvement for flexibility and testability.

Comment on lines +480 to +481
self.output_sizes = output_sizes
assert all(output_size % self.tp_size == 0 for output_size in output_sizes)

medium

Moving the assertion assert all(output_size % self.tp_size == 0 for output_size in output_sizes) after the super().__init__ call is correct. This ensures that self.tp_size has been properly initialized by the superclass before being used in the assertion.

Comment on lines +15 to +30
def get_group_size(group) -> int:
if hasattr(group, "world_size"):
return group.world_size # GroupCoordinator
elif hasattr(group, "size") and callable(getattr(group, "size", None)):
return group.size() # ProcessGroup
else:
raise ValueError(f"Unsupported group type: {type(group)}")


def get_group_rank(group) -> int:
if hasattr(group, "rank_in_group"):
return group.rank_in_group # GroupCoordinator
elif hasattr(group, "rank") and callable(getattr(group, "rank", None)):
return group.rank() # ProcessGroup
else:
raise ValueError(f"Unsupported group type: {type(group)}")

medium

The get_group_size and get_group_rank functions provide a clean and robust abstraction for querying properties of different process group types (GroupCoordinator and ProcessGroup). This enhances code readability and maintainability by centralizing this logic.
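The duck typing can be exercised with stand-in classes (the fakes below are hypothetical; the real types are sglang's GroupCoordinator and torch.distributed's ProcessGroup):

```python
# Fakes mimicking the two group interfaces the helpers must accept.

class FakeGroupCoordinator:          # GroupCoordinator-like: attributes
    world_size = 4
    rank_in_group = 2

class FakeProcessGroup:              # ProcessGroup-like: methods
    def size(self):
        return 8
    def rank(self):
        return 3

# Same helpers as in the PR:
def get_group_size(group) -> int:
    if hasattr(group, "world_size"):
        return group.world_size      # GroupCoordinator
    elif hasattr(group, "size") and callable(getattr(group, "size", None)):
        return group.size()          # ProcessGroup
    raise ValueError(f"Unsupported group type: {type(group)}")

def get_group_rank(group) -> int:
    if hasattr(group, "rank_in_group"):
        return group.rank_in_group   # GroupCoordinator
    elif hasattr(group, "rank") and callable(getattr(group, "rank", None)):
        return group.rank()          # ProcessGroup
    raise ValueError(f"Unsupported group type: {type(group)}")

print(get_group_size(FakeGroupCoordinator()), get_group_rank(FakeGroupCoordinator()))
print(get_group_size(FakeProcessGroup()), get_group_rank(FakeProcessGroup()))
```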

Comment on lines +76 to +84
def _get_folding_tp_group(config: T5Config) -> dist.ProcessGroup | None:
if config.parallel_folding:
if config.parallel_folding_mode == "sp":
return get_sp_group()
elif config.parallel_folding_mode == "ulysses":
return get_sp_group().ulysses_group
elif config.parallel_folding_mode == "ring":
return get_sp_group().ring_group
return get_tp_group()

medium

The _get_folding_tp_group function effectively centralizes the logic for determining the appropriate tp_group based on the parallel_folding configuration. This is a good design pattern for managing conditional logic related to distributed groups, improving clarity and reducing redundancy.
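The dispatch logic can be sketched in isolation with mocked groups (the real function calls get_sp_group()/get_tp_group() from sglang's distributed module; the strings here are placeholders):

```python
# Sketch of the folding dispatch with mocked groups (assumptions, not sglang code).

MOCK_GROUPS = {  # hypothetical placeholders for the real process groups
    "sp": "sp_group",
    "ulysses": "ulysses_subgroup",
    "ring": "ring_subgroup",
}

def pick_folding_tp_group(parallel_folding, mode):
    if parallel_folding and mode in MOCK_GROUPS:
        return MOCK_GROUPS[mode]
    return "global_tp_group"  # default: encoder keeps the global TP group

print(pick_folding_tp_group(True, "ulysses"))  # the folded Ulysses subgroup
print(pick_folding_tp_group(False, None))      # unchanged behavior
```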

@BBuf
Collaborator

BBuf commented Feb 2, 2026

@mickqian Any advice?

attn_bias: torch.Tensor


def _get_folding_tp_group(config: T5Config) -> dist.ProcessGroup | None:
Collaborator

could we have a @lru_cache(maxsize=1) here?
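For illustration, a minimal sketch of what the suggested memoization would do. One caveat (an assumption, since T5Config's hashability is not shown in this PR): lru_cache requires hashable arguments, so the cached function would need a hashable key, or would have to close over the config rather than take it as a parameter.

```python
from functools import lru_cache

# Sketch of @lru_cache(maxsize=1) memoizing a group lookup; the counter
# and string return value are illustrative stand-ins.

calls = {"n": 0}

@lru_cache(maxsize=1)
def cached_lookup(mode: str):
    calls["n"] += 1                 # count how often the real lookup runs
    return f"group_for_{mode}"      # stand-in for the resolved process group

cached_lookup("ulysses")
cached_lookup("ulysses")
print(calls["n"])  # the underlying lookup ran only once
```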

attn_bias: torch.Tensor


def _get_folding_tp_group(config: T5Config) -> dist.ProcessGroup | None:
Collaborator

consider moving to somewhere like distributed/util.py


@mickqian mickqian left a comment


brilliant. we should document this change

@BBuf
Collaborator

BBuf commented Feb 3, 2026

cc @nono-Sang

@nono-Sang nono-Sang force-pushed the t5_parallel_folding branch from 6db6152 to 429f0f9 Compare February 3, 2026 16:13


_seen_keys = set()  # use a set to record keys that have already appeared
_seen_keys = set()
Collaborator


could you also clean this? seems redundant 😂

@mickqian mickqian force-pushed the t5_parallel_folding branch from e34df17 to 2af63e2 Compare February 5, 2026 01:37
@mickqian
Collaborator

mickqian commented Feb 5, 2026

/tag-and-rerun-ci

@github-actions github-actions bot added the run-ci label Feb 5, 2026
@mickqian mickqian merged commit b639779 into sgl-project:main Feb 5, 2026
80 of 82 checks passed
charlesHsuGG pushed a commit to charlesHsuGG/sglang that referenced this pull request Feb 9, 2026
Johnsonms pushed a commit to Johnsonms/sglang that referenced this pull request Feb 14, 2026

Labels

diffusion SGLang Diffusion run-ci
