Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,11 @@ class EncoderConfig(ModelConfig):
class TextEncoderConfig(EncoderConfig):
arch_config: ArchConfig = field(default_factory=TextEncoderArchConfig)

# Use the SP Group of the transformer as the TP Group of T5.
parallel_folding: bool = False
# "sp" or "ulysses" or "ring"
parallel_folding_mode: str = "sp"


@dataclass
class ImageEncoderConfig(EncoderConfig):
Expand Down
27 changes: 27 additions & 0 deletions python/sglang/multimodal_gen/configs/models/encoders/t5.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo

# SPDX-License-Identifier: Apache-2.0
import argparse
from dataclasses import dataclass, field

from sglang.multimodal_gen.configs.models.encoders.base import (
TextEncoderArchConfig,
TextEncoderConfig,
)
from sglang.multimodal_gen.utils import StoreBoolean


def _is_transformer_layer(n: str, m) -> bool:
Expand Down Expand Up @@ -84,3 +86,28 @@ class T5Config(TextEncoderConfig):
arch_config: TextEncoderArchConfig = field(default_factory=T5ArchConfig)

prefix: str = "t5"
# Use the SP Group of the transformer as the TP Group of T5.
parallel_folding: bool = False
# "sp" or "ulysses" or "ring"
parallel_folding_mode: str = "sp"

@staticmethod
def add_cli_args(
parser: argparse.ArgumentParser, prefix: str = "t5-config"
) -> argparse.ArgumentParser:
parser.add_argument(
f"--{prefix}.parallel-folding",
action=StoreBoolean,
dest=f"{prefix.replace('-', '_')}.parallel_folding",
default=T5Config.parallel_folding,
help="Whether to use parallel folding for T5",
)
parser.add_argument(
f"--{prefix}.parallel-folding-mode",
type=str,
choices=["sp", "ulysses", "ring"],
dest=f"{prefix.replace('-', '_')}.parallel_folding_mode",
default=T5Config.parallel_folding_mode,
help="Parallel folding mode for T5",
)
return parser
20 changes: 18 additions & 2 deletions python/sglang/multimodal_gen/configs/pipeline_configs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,15 @@
VAEConfig,
)
from sglang.multimodal_gen.configs.models.encoders import BaseEncoderOutput
from sglang.multimodal_gen.configs.models.encoders.t5 import T5Config
from sglang.multimodal_gen.configs.sample.sampling_params import DataType
from sglang.multimodal_gen.configs.utils import update_config_from_args
from sglang.multimodal_gen.runtime.distributed import (
from sglang.multimodal_gen.runtime.distributed.communication_op import (
sequence_model_parallel_all_gather,
)
from sglang.multimodal_gen.runtime.distributed.parallel_state import (
get_sp_parallel_rank,
get_sp_world_size,
sequence_model_parallel_all_gather,
)
from sglang.multimodal_gen.runtime.models.vision_utils import get_default_height_width
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
Expand Down Expand Up @@ -492,6 +495,11 @@ def add_cli_args(

DiTConfig.add_cli_args(parser, prefix=f"{prefix_with_dot}dit-config")

# Add T5 configuration arguments
from sglang.multimodal_gen.configs.models.encoders.t5 import T5Config

T5Config.add_cli_args(parser, prefix=f"{prefix_with_dot}t5-config")

return parser

def update_config_from_dict(self, args: dict[str, Any], prefix: str = "") -> None:
Expand All @@ -503,6 +511,14 @@ def update_config_from_dict(self, args: dict[str, Any], prefix: str = "") -> Non
update_config_from_args(
self.dit_config, args, f"{prefix_with_dot}dit_config", pop_args=True
)
for text_encoder_config in self.text_encoder_configs:
if isinstance(text_encoder_config, T5Config):
update_config_from_args(
text_encoder_config,
args,
f"{prefix_with_dot}t5_config",
pop_args=True,
)

@classmethod
def from_kwargs(
Expand Down
20 changes: 18 additions & 2 deletions python/sglang/multimodal_gen/runtime/distributed/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
from functools import lru_cache

# SPDX-License-Identifier: Apache-2.0

from sglang.multimodal_gen.configs.models.encoders import TextEncoderConfig
from sglang.multimodal_gen.runtime.distributed.communication_op import *
from sglang.multimodal_gen.runtime.distributed.group_coordinator import (
get_local_torch_device,
Expand All @@ -27,6 +27,9 @@
)
from sglang.multimodal_gen.runtime.distributed.utils import *

# SPDX-License-Identifier: Apache-2.0


__all__ = [
# Initialization
"init_distributed_environment",
Expand All @@ -53,3 +56,16 @@
# Get torch device
"get_local_torch_device",
]


def _get_folding_tp_group(
config: TextEncoderConfig,
) -> torch.distributed.ProcessGroup | None:
if config.parallel_folding:
if config.parallel_folding_mode == "sp":
return get_sp_group()
elif config.parallel_folding_mode == "ulysses":
return get_sp_group().ulysses_group
elif config.parallel_folding_mode == "ring":
return get_sp_group().ring_group
return get_tp_group()
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# Adapted from https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/distributed/communication_op.py

import torch
import torch.distributed
import torch.distributed as dist

from sglang.multimodal_gen.runtime.distributed.parallel_state import (
get_cfg_group,
Expand All @@ -13,16 +13,20 @@
)


def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
def tensor_model_parallel_all_reduce(
input_: torch.Tensor, tp_group: dist.ProcessGroup = None
Comment on lines +16 to +17
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The addition of tp_group: dist.ProcessGroup = None as an optional argument and defaulting to get_tp_group() significantly enhances the flexibility of these communication operations. This is crucial for supporting the new parallel folding feature, allowing specific process groups to be used for tensor parallelism.

) -> torch.Tensor:
"""All-reduce the input tensor across model parallel group."""
return get_tp_group().all_reduce(input_)
tp_group = tp_group or get_tp_group()
return tp_group.all_reduce(input_)


def tensor_model_parallel_all_gather(
input_: torch.Tensor, dim: int = -1
input_: torch.Tensor, dim: int = -1, tp_group: dist.ProcessGroup = None
) -> torch.Tensor:
"""All-gather the input tensor across model parallel group."""
return get_tp_group().all_gather(input_, dim)
tp_group = tp_group or get_tp_group()
return tp_group.all_gather(input_, dim)


# TODO: remove model, make it sequence_parallel
Expand Down
52 changes: 33 additions & 19 deletions python/sglang/multimodal_gen/runtime/layers/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
from abc import abstractmethod

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.nn.parameter import Parameter

from sglang.multimodal_gen.runtime.distributed import (
divide,
get_tp_rank,
get_tp_world_size,
get_tp_group,
split_tensor_along_last_dim,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce,
Expand All @@ -21,6 +21,7 @@
QuantizationConfig,
QuantizeMethodBase,
)
from sglang.multimodal_gen.runtime.layers.utils import get_group_rank, get_group_size

# yapf: disable
from sglang.multimodal_gen.runtime.models.parameter import (
Expand Down Expand Up @@ -321,9 +322,12 @@ def __init__(
quant_config: QuantizationConfig | None = None,
output_sizes: list[int] | None = None,
prefix: str = "",
tp_group: dist.ProcessGroup = None,
):
# Divide the weight matrix along the last dimension.
self.tp_size = get_tp_world_size()
self.tp_group = tp_group or get_tp_group()
self.tp_size = get_group_size(self.tp_group)
self.tp_rank = get_group_rank(self.tp_group)
Comment on lines +325 to +330
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Accepting tp_group in ColumnParallelLinear.__init__ and deriving tp_size and tp_rank from it makes the layer more modular and independent of global state. This is a good architectural improvement for flexibility and testability.

self.input_size_per_partition = input_size
self.output_size_per_partition = divide(output_size, self.tp_size)
self.output_partition_sizes = [self.output_size_per_partition]
Expand Down Expand Up @@ -374,7 +378,7 @@ def __init__(
self.register_parameter("bias", None)

def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor) -> None:
tp_rank = get_tp_rank()
tp_rank = self.tp_rank
output_dim = getattr(param, "output_dim", None)

is_sharded_weight = getattr(param, "is_sharded_weight", False)
Expand Down Expand Up @@ -410,7 +414,9 @@ def forward(self, input_: torch.Tensor) -> tuple[torch.Tensor, Parameter | None]
output_parallel = self.quant_method.apply(self, input_, bias)
if self.gather_output:
# All-gather across the partitions.
output = tensor_model_parallel_all_gather(output_parallel)
output = tensor_model_parallel_all_gather(
output_parallel, tp_group=self.tp_group
)
else:
output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
Expand All @@ -420,7 +426,7 @@ def extra_repr(self) -> str:
s = f"in_features={self.input_size}"
s += f", output_features={self.output_size_per_partition}"
s += f", bias={self.bias is not None}"
s += f", tp_size={get_tp_world_size()}"
s += f", tp_size={self.tp_size}"
s += f", gather_output={self.gather_output}"
return s

Expand Down Expand Up @@ -458,10 +464,8 @@ def __init__(
params_dtype: torch.dtype | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
tp_group: dist.ProcessGroup = None,
):
self.output_sizes = output_sizes
tp_size = get_tp_world_size()
assert all(output_size % tp_size == 0 for output_size in output_sizes)
super().__init__(
input_size=input_size,
output_size=sum(output_sizes),
Expand All @@ -471,7 +475,10 @@ def __init__(
params_dtype=params_dtype,
quant_config=quant_config,
prefix=prefix,
tp_group=tp_group,
)
self.output_sizes = output_sizes
assert all(output_size % self.tp_size == 0 for output_size in output_sizes)
Comment on lines +480 to +481
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Moving the assertion assert all(output_size % self.tp_size == 0 for output_size in output_sizes) after the super().__init__ call is correct. This ensures that self.tp_size has been properly initialized by the superclass before being used in the assertion.


def weight_loader(
self,
Expand Down Expand Up @@ -512,8 +519,8 @@ def weight_loader(
return

assert loaded_shard_id < len(self.output_sizes)
tp_rank = get_tp_rank()
tp_size = get_tp_world_size()
tp_rank = self.tp_rank
tp_size = self.tp_size
if output_dim is not None:
shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
shard_size = self.output_sizes[loaded_shard_id] // tp_size
Expand Down Expand Up @@ -607,7 +614,7 @@ def weight_loader_v2(

assert loaded_shard_id < len(self.output_sizes)

tp_size = get_tp_world_size()
tp_size = self.tp_size

if isinstance(param, BlockQuantScaleParameter):
raise NotImplementedError("FP8 is not implemented yet")
Expand Down Expand Up @@ -674,6 +681,7 @@ def __init__(
params_dtype: torch.dtype | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
tp_group: dist.ProcessGroup = None,
):
self.hidden_size = hidden_size
self.head_size = head_size
Expand All @@ -682,7 +690,8 @@ def __init__(
total_num_kv_heads = total_num_heads
self.total_num_kv_heads = total_num_kv_heads
# Divide the weight matrix along the last dimension.
tp_size = get_tp_world_size()
tp_group = tp_group or get_tp_group()
tp_size = get_group_size(tp_group)
self.num_heads = divide(self.total_num_heads, tp_size)
if tp_size >= self.total_num_kv_heads:
self.num_kv_heads = 1
Expand All @@ -709,6 +718,7 @@ def __init__(
params_dtype=params_dtype,
quant_config=quant_config,
prefix=prefix,
tp_group=tp_group,
)

def _get_shard_offset_mapping(self, loaded_shard_id: str) -> int | None:
Expand Down Expand Up @@ -852,7 +862,7 @@ def weight_loader(
self.weight_loader(param, loaded_weight_shard, shard_id)
return

tp_rank = get_tp_rank()
tp_rank = self.tp_rank
assert loaded_shard_id in ["q", "k", "v"]

# If output dim is defined, use the default loading process.
Expand Down Expand Up @@ -944,10 +954,12 @@ def __init__(
reduce_results: bool = True,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
tp_group: dist.ProcessGroup = None,
):
# Divide the weight matrix along the first dimension.
self.tp_rank = get_tp_rank()
self.tp_size = get_tp_world_size()
self.tp_group = tp_group or get_tp_group()
self.tp_rank = get_group_rank(self.tp_group)
self.tp_size = get_group_size(self.tp_group)
self.input_size_per_partition = divide(input_size, self.tp_size)
self.output_size_per_partition = output_size
self.output_partition_sizes = [output_size]
Expand Down Expand Up @@ -992,7 +1004,7 @@ def __init__(
self.register_parameter("bias", None)

def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
tp_rank = get_tp_rank()
tp_rank = self.tp_rank
input_dim = getattr(param, "input_dim", None)
is_sharded_weight = getattr(param, "is_sharded_weight", False)
# bitsandbytes loads the weights of the specific portion
Expand Down Expand Up @@ -1027,7 +1039,7 @@ def forward(self, input_) -> tuple[torch.Tensor, Parameter | None]:
if self.input_is_parallel:
input_parallel = input_
else:
tp_rank = get_tp_rank()
tp_rank = self.tp_rank
splitted_input = split_tensor_along_last_dim(
input_, num_partitions=self.tp_size
)
Expand All @@ -1040,7 +1052,9 @@ def forward(self, input_) -> tuple[torch.Tensor, Parameter | None]:
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_)
if self.reduce_results and self.tp_size > 1:
output = tensor_model_parallel_all_reduce(output_parallel)
output = tensor_model_parallel_all_reduce(
output_parallel, tp_group=self.tp_group
)
else:
output = output_parallel

Expand Down
18 changes: 18 additions & 0 deletions python/sglang/multimodal_gen/runtime/layers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,24 @@
from sglang.multimodal_gen.runtime.platforms import current_platform


def get_group_size(group) -> int:
if hasattr(group, "world_size"):
return group.world_size # GroupCoordinator
elif hasattr(group, "size") and callable(getattr(group, "size", None)):
return group.size() # ProcessGroup
else:
raise ValueError(f"Unsupported group type: {type(group)}")


def get_group_rank(group) -> int:
if hasattr(group, "rank_in_group"):
return group.rank_in_group # GroupCoordinator
elif hasattr(group, "rank") and callable(getattr(group, "rank", None)):
return group.rank() # ProcessGroup
else:
raise ValueError(f"Unsupported group type: {type(group)}")
Comment on lines +15 to +30
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The get_group_size and get_group_rank functions provide a clean and robust abstraction for querying properties of different process group types (GroupCoordinator and ProcessGroup). This enhances code readability and maintainability by centralizing this logic.



def get_token_bin_counts_and_mask(
tokens: torch.Tensor,
vocab_size: int,
Expand Down
Loading
Loading