[diffusion] feat: allow T5's TP Group to reuse the transformer's SP Group #17818
Merged
Commits (10):
- d4ea7c1 [diffusion] feat: allow T5's TP Group to reuse the transformer's SP Group (nono-Sang)
- e260e74 fix (nono-Sang)
- 7e44509 update (nono-Sang)
- 663c60c fix (nono-Sang)
- 2af63e2 upd (mickqian)
- 240f c6f upd (mickqian)
- 8129d41 upd (mickqian)
- 505e332 upd (mickqian)
- 28b153e upd (mickqian)
- 926df7e upd (mickqian)
File 1 of 2 — parallel linear layers (`linear.py`):

```diff
@@ -6,13 +6,13 @@
 from abc import abstractmethod
 
 import torch
+import torch.distributed as dist
 import torch.nn.functional as F
 from torch.nn.parameter import Parameter
 
 from sglang.multimodal_gen.runtime.distributed import (
     divide,
-    get_tp_rank,
-    get_tp_world_size,
+    get_tp_group,
     split_tensor_along_last_dim,
     tensor_model_parallel_all_gather,
     tensor_model_parallel_all_reduce,
@@ -21,6 +21,7 @@
     QuantizationConfig,
     QuantizeMethodBase,
 )
+from sglang.multimodal_gen.runtime.layers.utils import get_group_rank, get_group_size
 
 # yapf: disable
 from sglang.multimodal_gen.runtime.models.parameter import (
@@ -321,9 +322,12 @@ def __init__(
         quant_config: QuantizationConfig | None = None,
         output_sizes: list[int] | None = None,
         prefix: str = "",
+        tp_group: dist.ProcessGroup = None,
     ):
         # Divide the weight matrix along the last dimension.
-        self.tp_size = get_tp_world_size()
+        self.tp_group = tp_group or get_tp_group()
+        self.tp_size = get_group_size(self.tp_group)
+        self.tp_rank = get_group_rank(self.tp_group)
         self.input_size_per_partition = input_size
         self.output_size_per_partition = divide(output_size, self.tp_size)
         self.output_partition_sizes = [self.output_size_per_partition]
@@ -374,7 +378,7 @@ def __init__(
         self.register_parameter("bias", None)
 
     def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor) -> None:
-        tp_rank = get_tp_rank()
+        tp_rank = self.tp_rank
         output_dim = getattr(param, "output_dim", None)
 
         is_sharded_weight = getattr(param, "is_sharded_weight", False)
@@ -410,7 +414,9 @@ def forward(self, input_: torch.Tensor) -> tuple[torch.Tensor, Parameter | None]:
         output_parallel = self.quant_method.apply(self, input_, bias)
         if self.gather_output:
             # All-gather across the partitions.
-            output = tensor_model_parallel_all_gather(output_parallel)
+            output = tensor_model_parallel_all_gather(
+                output_parallel, tp_group=self.tp_group
+            )
         else:
             output = output_parallel
         output_bias = self.bias if self.skip_bias_add else None
@@ -420,7 +426,7 @@ def extra_repr(self) -> str:
         s = f"in_features={self.input_size}"
         s += f", output_features={self.output_size_per_partition}"
         s += f", bias={self.bias is not None}"
-        s += f", tp_size={get_tp_world_size()}"
+        s += f", tp_size={self.tp_size}"
         s += f", gather_output={self.gather_output}"
         return s
@@ -458,10 +464,8 @@ def __init__(
         params_dtype: torch.dtype | None = None,
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
+        tp_group: dist.ProcessGroup = None,
     ):
-        self.output_sizes = output_sizes
-        tp_size = get_tp_world_size()
-        assert all(output_size % tp_size == 0 for output_size in output_sizes)
         super().__init__(
             input_size=input_size,
             output_size=sum(output_sizes),
@@ -471,7 +475,10 @@ def __init__(
             params_dtype=params_dtype,
             quant_config=quant_config,
             prefix=prefix,
+            tp_group=tp_group,
         )
+        self.output_sizes = output_sizes
+        assert all(output_size % self.tp_size == 0 for output_size in output_sizes)
 
     def weight_loader(
         self,
@@ -512,8 +519,8 @@ def weight_loader(
             return
 
         assert loaded_shard_id < len(self.output_sizes)
-        tp_rank = get_tp_rank()
-        tp_size = get_tp_world_size()
+        tp_rank = self.tp_rank
+        tp_size = self.tp_size
         if output_dim is not None:
             shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
             shard_size = self.output_sizes[loaded_shard_id] // tp_size
@@ -607,7 +614,7 @@ def weight_loader_v2(
         assert loaded_shard_id < len(self.output_sizes)
 
-        tp_size = get_tp_world_size()
+        tp_size = self.tp_size
 
         if isinstance(param, BlockQuantScaleParameter):
             raise NotImplementedError("FP8 is not implemented yet")
@@ -674,6 +681,7 @@ def __init__(
         params_dtype: torch.dtype | None = None,
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
+        tp_group: dist.ProcessGroup = None,
     ):
         self.hidden_size = hidden_size
         self.head_size = head_size
@@ -682,7 +690,8 @@
             total_num_kv_heads = total_num_heads
         self.total_num_kv_heads = total_num_kv_heads
         # Divide the weight matrix along the last dimension.
-        tp_size = get_tp_world_size()
+        tp_group = tp_group or get_tp_group()
+        tp_size = get_group_size(tp_group)
         self.num_heads = divide(self.total_num_heads, tp_size)
         if tp_size >= self.total_num_kv_heads:
             self.num_kv_heads = 1
@@ -709,6 +718,7 @@
             params_dtype=params_dtype,
             quant_config=quant_config,
             prefix=prefix,
+            tp_group=tp_group,
         )
 
     def _get_shard_offset_mapping(self, loaded_shard_id: str) -> int | None:
@@ -852,7 +862,7 @@ def weight_loader(
                 self.weight_loader(param, loaded_weight_shard, shard_id)
             return
 
-        tp_rank = get_tp_rank()
+        tp_rank = self.tp_rank
         assert loaded_shard_id in ["q", "k", "v"]
 
         # If output dim is defined, use the default loading process.
@@ -944,10 +954,12 @@ def __init__(
         reduce_results: bool = True,
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
+        tp_group: dist.ProcessGroup = None,
     ):
         # Divide the weight matrix along the first dimension.
-        self.tp_rank = get_tp_rank()
-        self.tp_size = get_tp_world_size()
+        self.tp_group = tp_group or get_tp_group()
+        self.tp_rank = get_group_rank(self.tp_group)
+        self.tp_size = get_group_size(self.tp_group)
         self.input_size_per_partition = divide(input_size, self.tp_size)
         self.output_size_per_partition = output_size
         self.output_partition_sizes = [output_size]
@@ -992,7 +1004,7 @@ def __init__(
         self.register_parameter("bias", None)
 
     def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
-        tp_rank = get_tp_rank()
+        tp_rank = self.tp_rank
         input_dim = getattr(param, "input_dim", None)
         is_sharded_weight = getattr(param, "is_sharded_weight", False)
         # bitsandbytes loads the weights of the specific portion
@@ -1027,7 +1039,7 @@ def forward(self, input_) -> tuple[torch.Tensor, Parameter | None]:
         if self.input_is_parallel:
             input_parallel = input_
         else:
-            tp_rank = get_tp_rank()
+            tp_rank = self.tp_rank
             splitted_input = split_tensor_along_last_dim(
                 input_, num_partitions=self.tp_size
             )
@@ -1040,7 +1052,9 @@ def forward(self, input_) -> tuple[torch.Tensor, Parameter | None]:
         bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
         output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_)
         if self.reduce_results and self.tp_size > 1:
-            output = tensor_model_parallel_all_reduce(output_parallel)
+            output = tensor_model_parallel_all_reduce(
+                output_parallel, tp_group=self.tp_group
+            )
         else:
             output = output_parallel
```
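Taken together, these changes let a caller shard any of the parallel linear layers over an arbitrary process group, while the `tp_group or get_tp_group()` fallback preserves the old behavior when no group is passed. Below is a minimal usage sketch of the new parameter; the import path, the `sp_group` handle, and the `build_t5_ffn` helper are illustrative assumptions, not part of this PR:

```python
import torch.distributed as dist

# Assumed import path for the layers touched by this PR.
from sglang.multimodal_gen.runtime.layers.linear import (
    ColumnParallelLinear,
    RowParallelLinear,
)


def build_t5_ffn(d_model: int, d_ff: int, sp_group: dist.ProcessGroup):
    """Hypothetical helper: a T5 FFN whose TP collectives run over sp_group."""
    # Column-parallel up-projection: the weight is sharded across the ranks
    # of sp_group, so any gather of the output stays inside that group.
    wi = ColumnParallelLinear(d_model, d_ff, tp_group=sp_group)
    # Row-parallel down-projection: partial sums are all-reduced over
    # sp_group rather than the global TP group.
    wo = RowParallelLinear(d_ff, d_model, tp_group=sp_group)
    return wi, wo
```

Passing the transformer's SP group here is exactly the reuse the PR title describes: T5's tensor parallelism rides on ranks that already exist for sequence parallelism, so no extra group needs to be created.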
File 2 of 2 — `sglang/multimodal_gen/runtime/layers/utils.py`:

```diff
@@ -12,6 +12,24 @@
 from sglang.multimodal_gen.runtime.platforms import current_platform
 
 
+def get_group_size(group) -> int:
+    if hasattr(group, "world_size"):
+        return group.world_size  # GroupCoordinator
+    elif hasattr(group, "size") and callable(getattr(group, "size", None)):
+        return group.size()  # ProcessGroup
+    else:
+        raise ValueError(f"Unsupported group type: {type(group)}")
+
+
+def get_group_rank(group) -> int:
+    if hasattr(group, "rank_in_group"):
+        return group.rank_in_group  # GroupCoordinator
+    elif hasattr(group, "rank") and callable(getattr(group, "rank", None)):
+        return group.rank()  # ProcessGroup
+    else:
+        raise ValueError(f"Unsupported group type: {type(group)}")
+
+
 def get_token_bin_counts_and_mask(
     tokens: torch.Tensor,
     vocab_size: int,
```
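The two helpers duck-type over their argument, so the linear layers accept either sglang's `GroupCoordinator` (which exposes `world_size`/`rank_in_group` attributes) or a raw `torch.distributed.ProcessGroup` (which exposes `size()`/`rank()` methods). A small sketch of how they resolve, assuming an initialized distributed job:

```python
import torch.distributed as dist

from sglang.multimodal_gen.runtime.layers.utils import get_group_rank, get_group_size


def describe(group) -> str:
    # For a raw ProcessGroup this resolves via group.rank() / group.size();
    # for a GroupCoordinator via rank_in_group / world_size.
    return f"rank {get_group_rank(group)} of {get_group_size(group)}"


# Example (requires torch.distributed to be initialized first):
# sub_group = dist.new_group(ranks=[0, 1])  # TP over the first two ranks
# print(describe(sub_group))
```

Note that the `hasattr` dispatch checks the `GroupCoordinator` branch first, so it relies on the two types not overlapping in attribute names.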
Review comment:

The addition of `tp_group: dist.ProcessGroup = None` as an optional argument, defaulting to `get_tp_group()`, significantly enhances the flexibility of these communication operations. This is crucial for supporting the new parallel folding feature, allowing specific process groups to be used for tensor parallelism.
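For intuition, the group-aware collective that `RowParallelLinear.forward` now calls behaves like a plain `torch.distributed` all-reduce restricted to the given group. A sketch of the assumed semantics follows; the real `tensor_model_parallel_all_reduce` may take optimized custom-allreduce paths:

```python
import torch
import torch.distributed as dist


def all_reduce_in_group(x: torch.Tensor, tp_group: dist.ProcessGroup) -> torch.Tensor:
    # Sum partial outputs only across the ranks of tp_group. When T5's
    # layers are built with the transformer's SP group, this reduction
    # never leaves those SP ranks.
    dist.all_reduce(x, op=dist.ReduceOp.SUM, group=tp_group)
    return x
```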