Skip to content

Commit 2cc235e

Browse files
Fix Bug on dsv3.2 (sgl-project#18553)
This PR affects only the NPU. If any issues arise, please contact iforgetmyname.
1 parent d84d206 commit 2cc235e

File tree

2 files changed

+16
-8
lines changed

2 files changed

+16
-8
lines changed

python/sglang/srt/layers/attention/nsa/nsa_indexer.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import torch
88
from einops import rearrange
99

10+
from sglang.srt.environ import envs
1011
from sglang.srt.layers.layernorm import LayerNorm
1112
from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz
1213
from sglang.srt.layers.utils import MultiPlatformOp
@@ -1190,13 +1191,17 @@ def forward_npu(
11901191
) # [bs, n, d]
11911192
q = torch.cat([q_pe, q_nope], dim=-1)
11921193

1193-
indexer_weight_stream = get_indexer_weight_stream()
1194-
indexer_weight_stream.wait_stream(torch.npu.current_stream())
1195-
with torch.npu.stream(indexer_weight_stream):
1194+
if envs.SGLANG_NPU_USE_MULTI_STREAM.get():
1195+
indexer_weight_stream = get_indexer_weight_stream()
1196+
indexer_weight_stream.wait_stream(torch.npu.current_stream())
1197+
with torch.npu.stream(indexer_weight_stream):
1198+
x = x.view(-1, self.hidden_size)
1199+
weights = self.weights_proj(x.float())[0].to(torch.bfloat16)
1200+
weights.record_stream(indexer_weight_stream)
1201+
weights_event = indexer_weight_stream.record_event()
1202+
else:
11961203
x = x.view(-1, self.hidden_size)
11971204
weights = self.weights_proj(x.float())[0].to(torch.bfloat16)
1198-
weights.record_stream(indexer_weight_stream)
1199-
weights_event = indexer_weight_stream.record_event()
12001205

12011206
k_proj = self.wk(x)[0] # [b, s, 7168] @ [7168, 128] = [b, s, 128]
12021207
k = self.k_norm(k_proj)
@@ -1278,7 +1283,8 @@ def forward_npu(
12781283

12791284
if self.alt_stream is not None:
12801285
torch.npu.current_stream().wait_event(q_rope_event)
1281-
torch.npu.current_stream().wait_event(weights_event)
1286+
if envs.SGLANG_NPU_USE_MULTI_STREAM.get():
1287+
torch.npu.current_stream().wait_event(weights_event)
12821288

12831289
block_table = forward_batch.attn_backend.forward_metadata.block_tables
12841290
if (

python/sglang/srt/managers/overlap_utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,18 @@
66
import torch
77

88
from sglang.srt.speculative.spec_utils import spec_need_hidden_states
9-
from sglang.srt.utils import get_compiler_backend
9+
from sglang.srt.utils import get_compiler_backend, is_npu
1010

1111
if TYPE_CHECKING:
1212
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
1313
from sglang.srt.managers.scheduler import GenerationBatchResult
1414
from sglang.srt.speculative.eagle_info import EagleDraftInput
1515
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
1616

17+
_is_npu = is_npu()
1718

18-
@torch.compile(dynamic=True, backend=get_compiler_backend())
19+
20+
@torch.compile(dynamic=True, backend=get_compiler_backend(), disable=_is_npu)
1921
def _resolve_future_token_ids(input_ids, future_token_ids_map):
2022
input_ids[:] = torch.where(
2123
input_ids < 0,

0 commit comments

Comments
 (0)