|
8 | 8 | import triton.language as tl # type: ignore |
9 | 9 | from torch import Tensor |
10 | 10 |
|
| 11 | +from sglang.multimodal_gen.runtime.platforms import current_platform |
| 12 | + |
11 | 13 |
|
12 | 14 | @triton.autotune( |
13 | 15 | configs=[ |
@@ -524,8 +526,14 @@ def triton_autotune_configs(): |
524 | 526 | max_threads_per_block = 1024 |
525 | 527 | # Default to warp size 32 if not defined by device |
526 | 528 | warp_size = getattr( |
527 | | - torch.cuda.get_device_properties(torch.cuda.current_device()), "warp_size", 32 |
| 529 | + torch.get_device_module().get_device_properties( |
| 530 | + torch.get_device_module().current_device() |
| 531 | + ), |
| 532 | + "warp_size", |
| 533 | + 32, |
528 | 534 | ) |
| 535 | + if warp_size is None: |
| 536 | + warp_size = 32 |
529 | 537 | # Autotune for warp counts which are powers of 2 and do not exceed thread per block limit |
530 | 538 | return [ |
531 | 539 | triton.Config({}, num_warps=warp_count) |
@@ -820,7 +828,7 @@ def _layer_norm_fwd_impl( |
820 | 828 | BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) |
821 | 829 | if N > BLOCK_N: |
822 | 830 | raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") |
823 | | - with torch.cuda.device(x.device.index): |
| 831 | + with torch.get_device_module().device(x.device.index): |
824 | 832 | torch.library.wrap_triton(_layer_norm_fwd_1pass_kernel)[(M,)]( |
825 | 833 | x, |
826 | 834 | out, |
@@ -1166,3 +1174,31 @@ def triton_one_pass_rms_norm(x: torch.Tensor, w: torch.Tensor, eps: float = 1e-6 |
1166 | 1174 | BLOCK_SIZE_SEQ=BLOCK_SIZE_SEQ, |
1167 | 1175 | ) |
1168 | 1176 | return y |
| 1177 | + |
| 1178 | + |
if current_platform.is_npu():
    # TODO: remove this fallback once the triton-ascend bug is fixed.
    def fuse_scale_shift_native(
        x: torch.Tensor,
        scale: torch.Tensor,
        shift: torch.Tensor,
        block_l: int = 128,
        block_c: int = 128,
    ):
        """Pure-PyTorch modulation: ``x * (1 + scale) + shift``.

        ``block_l`` / ``block_c`` are never read here — presumably kept so the
        signature matches the triton kernel's tiling knobs.
        NOTE(review): confirm against the triton kernel's interface.
        """
        modulated = torch.mul(x, scale + 1)
        return modulated + shift

    fuse_scale_shift_kernel = fuse_scale_shift_native
# TODO: remove this fallback once the triton-ascend bug is fixed.
def apply_rotary_embedding_native(
    x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False
) -> torch.Tensor:
    """Rotary position embedding computed with plain tensor ops.

    Splits the last dim of ``x`` into even/odd pairs and rotates each pair by
    the per-position angle given by ``cos`` / ``sin`` (both are broadcast over
    a head dimension inserted at ``-2`` and cast to ``x``'s dtype).

    NOTE(review): ``interleaved`` is accepted but never read — the output
    always uses the even/odd (GPT-J style) pairing. Confirm this matches the
    triton kernel's behavior for ``interleaved=False`` callers.
    """
    cos_b = cos.to(x.dtype).unsqueeze(-2)
    sin_b = sin.to(x.dtype).unsqueeze(-2)
    even = x[..., 0::2]
    odd = x[..., 1::2]
    rot_even = even * cos_b - odd * sin_b
    rot_odd = odd * cos_b + even * sin_b
    return torch.stack((rot_even, rot_odd), dim=-1).flatten(start_dim=-2)


apply_rotary_embedding = apply_rotary_embedding_native
0 commit comments