Skip to content

Commit 00248d8

Browse files
Makcum888e, DHX98, ping1jing2, DHX98, yhyang201
authored
[diffusion] platform: support WAN/FLUX/Qwen-Image/Qwen-Image-edit on Ascend (sgl-project#13662)
Co-authored-by: dhx98 <haox.dai@gmail.com>
Co-authored-by: DHX98 <haoxiand@andrew.cmu.edu>
Co-authored-by: ronnie_zheng <zl19940307@163.com>
Co-authored-by: DHX98 <DHX98@noreply.gitcode.com>
Co-authored-by: Yuhao Yang <47235274+yhyang201@users.noreply.github.com>
1 parent 7b83659 commit 00248d8

File tree

25 files changed

+476
-30
lines changed

25 files changed

+476
-30
lines changed

.github/workflows/pr-test-npu.yml

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ jobs:
6464
multimodal_gen:
6565
- "python/sglang/multimodal_gen/**"
6666
- "python/pyproject_npu.toml"
67-
- "scripts/ci/npu_ci_install_dependency.sh"
67+
- "scripts/ci/npu/npu_ci_install_dependency.sh"
6868
- ".github/workflows/pr-test-npu.yml"
6969
7070
# ==================== PR Gate ==================== #
@@ -241,3 +241,42 @@ jobs:
241241
run: |
242242
cd test/srt
243243
python3 run_suite.py --suite per-commit-16-npu-a3 --timeout-per-file 3600 --auto-partition-id ${{ matrix.part }} --auto-partition-size 2
244+
245+
multimodal-gen-test-1-npu-a3:
246+
needs: [check-changes, pr-gate]
247+
if: needs.check-changes.outputs.multimodal_gen == 'true'
248+
runs-on: linux-aarch64-a3-16
249+
container:
250+
image: swr.cn-southwest-2.myhuaweicloud.com/base_image/ascend-ci/cann:8.3.rc2-a3-ubuntu22.04-py3.11
251+
steps:
252+
- name: Checkout code
253+
uses: actions/checkout@v4
254+
255+
- name: Install dependencies
256+
run: |
257+
# speed up by using infra cache services
258+
CACHING_URL="cache-service.nginx-pypi-cache.svc.cluster.local"
259+
sed -Ei "s@(ports|archive).ubuntu.com@${CACHING_URL}:8081@g" /etc/apt/sources.list
260+
pip config set global.index-url http://${CACHING_URL}/pypi/simple
261+
pip config set global.extra-index-url "https://pypi.tuna.tsinghua.edu.cn/simple"
262+
pip config set global.trusted-host "${CACHING_URL} pypi.tuna.tsinghua.edu.cn"
263+
264+
bash scripts/ci/npu/npu_ci_install_dependency.sh a3
265+
# copy required file from our daily cache
266+
cp ~/.cache/modelscope/hub/datasets/otavia/ShareGPT_Vicuna_unfiltered/ShareGPT_V3_unfiltered_cleaned_split.json /tmp
267+
# copy download through proxy
268+
curl -o /tmp/test.jsonl -L https://gh-proxy.test.osinfra.cn/https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl
269+
270+
- name: Run test
271+
timeout-minutes: 60
272+
env:
273+
SGLANG_USE_MODELSCOPE: true
274+
SGLANG_IS_IN_CI: true
275+
HF_ENDPOINT: https://hf-mirror.com
276+
TORCH_EXTENSIONS_DIR: /tmp/torch_extensions
277+
PYTORCH_NPU_ALLOC_CONF: "expandable_segments:True"
278+
STREAMS_PER_DEVICE: 32
279+
run: |
280+
export PATH="/usr/local/Ascend/8.3.RC1/compiler/bishengir/bin:${PATH}"
281+
cd python
282+
python3 sglang/multimodal_gen/test/run_suite.py --suite 1-npu

python/pyproject_npu.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,8 @@ diffusion = [
7777
"moviepy>=2.0.0",
7878
"opencv-python==4.10.0.84",
7979
"remote-pdb",
80-
"cache-dit==1.1.8"
80+
"cache-dit==1.2.1",
81+
"addict"
8182
]
8283

8384
tracing = [

python/sglang/multimodal_gen/runtime/distributed/group_coordinator.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
from torch.cuda import synchronize
1717
from torch.distributed import Backend, ProcessGroup
1818

19-
from sglang.multimodal_gen import envs
2019
from sglang.multimodal_gen.runtime.distributed.device_communicators.base_device_communicator import (
2120
DeviceCommunicatorBase,
2221
)
@@ -46,11 +45,7 @@
4645
def get_local_torch_device() -> torch.device:
4746
"""Return the torch device for the current rank."""
4847

49-
return (
50-
torch.device(f"cuda:{envs.LOCAL_RANK}")
51-
if current_platform.is_cuda_alike()
52-
else torch.device("mps")
53-
)
48+
return current_platform.get_local_torch_device()
5449

5550

5651
def _get_unique_name(name: str) -> str:
@@ -190,8 +185,6 @@ def __init__(
190185
# TODO: fix it for other platforms
191186
self.device = get_local_torch_device()
192187

193-
from sglang.multimodal_gen.runtime.platforms import current_platform
194-
195188
self.use_device_communicator = use_device_communicator
196189

197190
self.device_communicator: DeviceCommunicatorBase = None # type: ignore
@@ -287,9 +280,6 @@ def group_skip_rank(self):
287280

288281
@contextmanager
289282
def graph_capture(self, graph_capture_context: GraphCaptureContext | None = None):
290-
# Platform-aware graph capture
291-
from sglang.multimodal_gen.runtime.platforms import current_platform
292-
293283
if current_platform.is_cuda_alike():
294284
if graph_capture_context is None:
295285
stream = torch.cuda.Stream()

python/sglang/multimodal_gen/runtime/distributed/parallel_state.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,11 @@ def init_distributed_environment(
248248
# For MPS and MUSA, don't pass device_id as it doesn't support device indices
249249
extra_args = (
250250
{}
251-
if (current_platform.is_mps() or current_platform.is_musa())
251+
if (
252+
current_platform.is_mps()
253+
or current_platform.is_musa()
254+
or current_platform.is_npu()
255+
)
252256
else dict(device_id=device_id)
253257
)
254258

@@ -618,6 +622,7 @@ def maybe_init_distributed_environment_and_model_parallel(
618622
local_rank=local_rank,
619623
distributed_init_method=distributed_init_method,
620624
device_id=device,
625+
backend=current_platform.get_torch_distributed_backend_str(),
621626
timeout=dist_timeout,
622627
)
623628
initialize_model_parallel(

python/sglang/multimodal_gen/runtime/layers/activation.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,12 @@
1414

1515
_is_cuda = current_platform.is_cuda()
1616
_is_hip = current_platform.is_hip()
17+
_is_npu = current_platform.is_npu()
1718
if _is_cuda or _is_hip:
1819
from sgl_kernel import silu_and_mul
20+
21+
if _is_npu:
22+
import torch_npu
1923
# TODO (will): remove this dependency
2024
from sglang.multimodal_gen.runtime.layers.custom_op import CustomOp
2125

@@ -46,6 +50,10 @@ def forward_native(self, x: torch.Tensor) -> torch.Tensor:
4650
d = x.shape[-1] // 2
4751
return F.silu(x[..., :d]) * x[..., d:]
4852

53+
def forward_npu(self, x: torch.Tensor) -> torch.Tensor:
54+
out = torch_npu.npu_swiglu(x)
55+
return out
56+
4957

5058
@CustomOp.register("gelu_and_mul")
5159
class GeluAndMul(CustomOp):

python/sglang/multimodal_gen/runtime/layers/custom_op.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,11 @@ def forward_oot(self, *args, **kwargs) -> Any:
6464
# PyTorch-native implementation.
6565
return self.forward_native(*args, **kwargs)
6666

67+
def forward_npu(self, *args, **kwargs) -> Any:
68+
# By default, we assume that NPU ops are compatible with the
69+
# PyTorch-native implementation.
70+
return self.forward_native(*args, **kwargs)
71+
6772
def dispatch_forward(self) -> Callable:
6873
if _is_cuda:
6974
return self.forward_cuda

python/sglang/multimodal_gen/runtime/layers/layernorm.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,13 @@
1212
from sglang.multimodal_gen.runtime.platforms import current_platform
1313

1414
_is_cuda = current_platform.is_cuda()
15+
_is_npu = current_platform.is_npu()
1516
if _is_cuda:
1617
from sgl_kernel import fused_add_rmsnorm, rmsnorm
1718

19+
if _is_npu:
20+
import torch_npu
21+
1822
from sglang.jit_kernel.norm import can_use_fused_inplace_qknorm, fused_inplace_qknorm
1923
from sglang.multimodal_gen.runtime.distributed.parallel_state import (
2024
get_tensor_model_parallel_rank,
@@ -28,11 +32,8 @@
2832
rms_norm_fn,
2933
triton_one_pass_rms_norm,
3034
)
31-
from sglang.multimodal_gen.runtime.platforms import current_platform
3235
from sglang.multimodal_gen.runtime.utils.common import get_bool_env_var
3336

34-
_is_cuda = current_platform.is_cuda()
35-
3637

3738
# Copied and adapted from sglang
3839
@CustomOp.register("rms_norm")
@@ -141,6 +142,18 @@ def forward_cpu(
141142
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
142143
return self.forward_native(x, residual)
143144

145+
def forward_npu(
146+
self,
147+
x: torch.Tensor,
148+
residual: Optional[torch.Tensor] = None,
149+
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
150+
if residual is not None:
151+
out, _, residual_out = torch_npu.npu_add_rms_norm(
152+
residual, x, self.weight.data, self.variance_epsilon
153+
)
154+
return out, residual_out
155+
return torch_npu.npu_rms_norm(x, self.weight.data, self.variance_epsilon)[0]
156+
144157
def forward_hip(
145158
self,
146159
x: torch.Tensor,
@@ -214,7 +227,7 @@ def forward_cuda(
214227
x = x.view(-1, self.hidden_size)
215228
return self.forward_triton(x).view(shape)
216229

217-
@torch.compile(backend="inductor")
230+
@torch.compile(backend="inductor", disable=current_platform.is_npu())
218231
def forward_native(
219232
self,
220233
x: torch.Tensor,

python/sglang/multimodal_gen/runtime/layers/linear.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535

3636
# yapf: enable
3737
from sglang.multimodal_gen.runtime.models.utils import set_weight_attrs
38+
from sglang.multimodal_gen.runtime.platforms import current_platform
3839
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
3940

4041
logger = init_logger(__name__)
@@ -152,7 +153,7 @@ def apply(
152153
) -> torch.Tensor:
153154
output = (
154155
F.linear(x, layer.weight, bias)
155-
if torch.cuda.is_available() or bias is None
156+
if current_platform.is_amp_supported() or bias is None
156157
else F.linear(x, layer.weight, bias.to(x.dtype))
157158
) # NOTE: this line assumes that we are using amp when using cuda and is needed to account for the fact that amp isn't supported in mps
158159
return output

python/sglang/multimodal_gen/runtime/layers/triton_ops.py

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
import triton.language as tl # type: ignore
99
from torch import Tensor
1010

11+
from sglang.multimodal_gen.runtime.platforms import current_platform
12+
1113

1214
@triton.autotune(
1315
configs=[
@@ -524,8 +526,14 @@ def triton_autotune_configs():
524526
max_threads_per_block = 1024
525527
# Default to warp size 32 if not defined by device
526528
warp_size = getattr(
527-
torch.cuda.get_device_properties(torch.cuda.current_device()), "warp_size", 32
529+
torch.get_device_module().get_device_properties(
530+
torch.get_device_module().current_device()
531+
),
532+
"warp_size",
533+
32,
528534
)
535+
if warp_size is None:
536+
warp_size = 32
529537
# Autotune for warp counts which are powers of 2 and do not exceed thread per block limit
530538
return [
531539
triton.Config({}, num_warps=warp_count)
@@ -820,7 +828,7 @@ def _layer_norm_fwd_impl(
820828
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
821829
if N > BLOCK_N:
822830
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
823-
with torch.cuda.device(x.device.index):
831+
with torch.get_device_module().device(x.device.index):
824832
torch.library.wrap_triton(_layer_norm_fwd_1pass_kernel)[(M,)](
825833
x,
826834
out,
@@ -1166,3 +1174,31 @@ def triton_one_pass_rms_norm(x: torch.Tensor, w: torch.Tensor, eps: float = 1e-6
11661174
BLOCK_SIZE_SEQ=BLOCK_SIZE_SEQ,
11671175
)
11681176
return y
1177+
1178+
1179+
if current_platform.is_npu():
1180+
# TODO: remove this when triton ascend bug is fixed
1181+
def fuse_scale_shift_native(
1182+
x: torch.Tensor,
1183+
scale: torch.Tensor,
1184+
shift: torch.Tensor,
1185+
block_l: int = 128,
1186+
block_c: int = 128,
1187+
):
1188+
return x * (1 + scale) + shift
1189+
1190+
fuse_scale_shift_kernel = fuse_scale_shift_native
1191+
1192+
# TODO: remove this when triton ascend bug is fixed
1193+
def apply_rotary_embedding_native(
1194+
x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False
1195+
) -> torch.Tensor:
1196+
cos = cos.unsqueeze(-2).to(x.dtype)
1197+
sin = sin.unsqueeze(-2).to(x.dtype)
1198+
x1 = x[..., ::2]
1199+
x2 = x[..., 1::2]
1200+
o1 = x1 * cos - x2 * sin
1201+
o2 = x2 * cos + x1 * sin
1202+
return torch.stack((o1, o2), dim=-1).flatten(-2)
1203+
1204+
apply_rotary_embedding = apply_rotary_embedding_native

python/sglang/multimodal_gen/runtime/layers/vocab_parallel_embedding.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,11 @@ def __post_init__(self):
145145
assert self.num_added_elements <= self.num_added_elements_padded
146146

147147

148-
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
148+
@torch.compile(
149+
dynamic=True,
150+
backend=current_platform.simple_compile_backend,
151+
disable=current_platform.is_npu(),
152+
)
149153
def get_masked_input_and_mask(
150154
input_: torch.Tensor,
151155
org_vocab_start_index: int,

0 commit comments

Comments (0)