diff --git a/python/tvm/relax/backend/dispatch_sort_scan.py b/python/tvm/relax/backend/dispatch_sort_scan.py
index 480420c31373..064d3abf2581 100644
--- a/python/tvm/relax/backend/dispatch_sort_scan.py
+++ b/python/tvm/relax/backend/dispatch_sort_scan.py
@@ -29,6 +29,11 @@
 from tvm.target import Target
 
 
+def is_gpu_target(target: Target) -> bool:
+    """Check if the target is a GPU target."""
+    return "gpu" in target.keys
+
+
 @expr_functor.mutator
 class SortScanDispatcher(PyExprMutator):
     """
@@ -88,7 +93,7 @@ def visit_call_(self, call: relax.Call) -> relax.Expr:
                 if can_use_thrust(tgt, "tvm.contrib.thrust.sort"):
                     te_func = topi.cuda.sort_thrust
                     kwargs["workspace"] = self.allocate_workspace(call)
-                elif tgt.kind.name == "cuda":
+                elif is_gpu_target(tgt):
                     te_func = topi.cuda.sort
             return self.builder_.call_te(
                 te_func, call.args[0], call.attrs.axis, not call.attrs.descending, **kwargs
@@ -101,7 +106,7 @@ def visit_call_(self, call: relax.Call) -> relax.Expr:
                 if can_use_thrust(tgt, "tvm.contrib.thrust.sort"):
                     te_func = topi.cuda.argsort_thrust
                     kwargs["workspace"] = self.allocate_workspace(call)
-                elif tgt.kind.name == "cuda":
+                elif is_gpu_target(tgt):
                     te_func = topi.cuda.argsort
             return self.builder_.call_te(
                 te_func,
@@ -118,7 +123,7 @@ def visit_call_(self, call: relax.Call) -> relax.Expr:
                 if can_use_thrust(tgt, "tvm.contrib.thrust.sort"):
                     te_func = topi.cuda.topk_thrust
                     kwargs["workspace"] = self.allocate_workspace(call)
-                elif tgt.kind.name == "cuda":
+                elif is_gpu_target(tgt):
                     te_func = topi.cuda.topk
             tir_call = self.builder_.call_te(
                 te_func,
@@ -130,7 +135,7 @@ def visit_call_(self, call: relax.Call) -> relax.Expr:
                 dtype=call.attrs.dtype,
                 **kwargs,
             )
-            if tgt.kind.name != "cuda":
+            if not is_gpu_target(tgt):
                 return tir_call
             # apply dlight gpu fallback
             self._apply_dlight_gpu_fallback(tgt, tir_call)
@@ -141,11 +146,11 @@ def visit_call_(self, call: relax.Call) -> relax.Expr:
             kwargs = {}
             with tgt:
                 if call.op.name == "relax.cumsum":
-                    te_func = topi.cuda.cumsum if tgt.kind.name == "cuda" else topi.cumsum
+                    te_func = topi.cuda.cumsum if is_gpu_target(tgt) else topi.cumsum
                     if can_use_thrust(tgt, "tvm.contrib.thrust.sum_scan"):
                         kwargs["workspace"] = self.allocate_workspace(call)
                 elif call.op.name == "relax.cumprod":
-                    te_func = topi.cuda.cumprod if tgt.kind.name == "cuda" else topi.cumprod
+                    te_func = topi.cuda.cumprod if is_gpu_target(tgt) else topi.cumprod
                 else:
                     raise ValueError(f"Unsupported op: {call.op.name}")
             tir_call = self.builder_.call_te(
@@ -156,7 +161,7 @@ def visit_call_(self, call: relax.Call) -> relax.Expr:
                 call.attrs.exclusive,
                 **kwargs,
             )
-        if tgt.kind.name != "cuda":
+        if not is_gpu_target(tgt):
             return tir_call
         # apply dlight gpu fallback
         self._apply_dlight_gpu_fallback(tgt, tir_call)