|
7 | 7 | import torch |
8 | 8 | from einops import rearrange |
9 | 9 |
|
| 10 | +from sglang.srt.environ import envs |
10 | 11 | from sglang.srt.layers.layernorm import LayerNorm |
11 | 12 | from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz |
12 | 13 | from sglang.srt.layers.utils import MultiPlatformOp |
@@ -1190,13 +1191,17 @@ def forward_npu( |
1190 | 1191 | ) # [bs, n, d] |
1191 | 1192 | q = torch.cat([q_pe, q_nope], dim=-1) |
1192 | 1193 |
|
1193 | | - indexer_weight_stream = get_indexer_weight_stream() |
1194 | | - indexer_weight_stream.wait_stream(torch.npu.current_stream()) |
1195 | | - with torch.npu.stream(indexer_weight_stream): |
| 1194 | + if envs.SGLANG_NPU_USE_MULTI_STREAM.get(): |
| 1195 | + indexer_weight_stream = get_indexer_weight_stream() |
| 1196 | + indexer_weight_stream.wait_stream(torch.npu.current_stream()) |
| 1197 | + with torch.npu.stream(indexer_weight_stream): |
| 1198 | + x = x.view(-1, self.hidden_size) |
| 1199 | + weights = self.weights_proj(x.float())[0].to(torch.bfloat16) |
| 1200 | + weights.record_stream(indexer_weight_stream) |
| 1201 | + weights_event = indexer_weight_stream.record_event() |
| 1202 | + else: |
1196 | 1203 | x = x.view(-1, self.hidden_size) |
1197 | 1204 | weights = self.weights_proj(x.float())[0].to(torch.bfloat16) |
1198 | | - weights.record_stream(indexer_weight_stream) |
1199 | | - weights_event = indexer_weight_stream.record_event() |
1200 | 1205 |
|
1201 | 1206 | k_proj = self.wk(x)[0] # [b, s, 7168] @ [7168, 128] = [b, s, 128] |
1202 | 1207 | k = self.k_norm(k_proj) |
@@ -1278,7 +1283,8 @@ def forward_npu( |
1278 | 1283 |
|
1279 | 1284 | if self.alt_stream is not None: |
1280 | 1285 | torch.npu.current_stream().wait_event(q_rope_event) |
1281 | | - torch.npu.current_stream().wait_event(weights_event) |
| 1286 | + if envs.SGLANG_NPU_USE_MULTI_STREAM.get(): |
| 1287 | + torch.npu.current_stream().wait_event(weights_event) |
1282 | 1288 |
|
1283 | 1289 | block_table = forward_batch.attn_backend.forward_metadata.block_tables |
1284 | 1290 | if ( |
|
0 commit comments