|
45 | 45 |
|
46 | 46 | if not is_cpu() and not is_npu(): |
47 | 47 | # Avoid an import error on CPU devices; no impact on non-CPU paths. |
48 | | - from sglang.jit_kernel.cutedsl_gdn import ( |
49 | | - cutedsl_fused_sigmoid_gating_delta_rule_update, |
50 | | - ) |
| 48 | + try: |
| 49 | + from sglang.jit_kernel.cutedsl_gdn import ( |
| 50 | + cutedsl_fused_sigmoid_gating_delta_rule_update, |
| 51 | + ) |
| 52 | + except ModuleNotFoundError: |
| 53 | + # CuTe DSL path requires cuda-python (cuda.bindings.*). Keep runtime usable |
| 54 | + # by falling back to non-CuTe kernels when it's unavailable. |
| 55 | + cutedsl_fused_sigmoid_gating_delta_rule_update = None |
51 | 56 | from sglang.srt.layers.attention.fla.chunk import chunk_gated_delta_rule |
52 | 57 | from sglang.srt.layers.attention.fla.chunk_delta_h import ( |
53 | 58 | CHUNK_SIZE as FLA_CHUNK_SIZE, |
@@ -830,6 +835,12 @@ def __init__(self, model_runner: ModelRunner): |
830 | 835 | ), f"{self.conv_states_shape[-1]=} should be less than {FLA_CHUNK_SIZE}" |
831 | 836 |
|
832 | 837 | use_cutedsl = Envs.SGLANG_USE_CUTEDSL_GDN_DECODE.get() |
| 838 | + if use_cutedsl and cutedsl_fused_sigmoid_gating_delta_rule_update is None: |
| 839 | + rank0_log( |
| 840 | + "CuTe DSL GDN decode requested but unavailable " |
| 841 | + "(missing cuda.bindings). Falling back to FLA decode kernel." |
| 842 | + ) |
| 843 | + use_cutedsl = False |
833 | 844 | rank0_log(f"CuTe DSL GDN decode enabled: {use_cutedsl}") |
834 | 845 | self._kernel_func = ( |
835 | 846 | cutedsl_fused_sigmoid_gating_delta_rule_update |
|
0 commit comments