Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 12 additions & 7 deletions python/tvm/relax/backend/dispatch_sort_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@
from tvm.target import Target


def is_gpu_target(target: Target) -> bool:
    """Return True when ``target`` describes a GPU device.

    TVM targets carry a list of keys (e.g. ``["cuda", "gpu"]``); a target
    is considered a GPU target whenever the generic ``"gpu"`` key is present.
    """
    return any(key == "gpu" for key in target.keys)


@expr_functor.mutator
class SortScanDispatcher(PyExprMutator):
"""
Expand Down Expand Up @@ -88,7 +93,7 @@ def visit_call_(self, call: relax.Call) -> relax.Expr:
if can_use_thrust(tgt, "tvm.contrib.thrust.sort"):
te_func = topi.cuda.sort_thrust
kwargs["workspace"] = self.allocate_workspace(call)
elif tgt.kind.name == "cuda":
elif is_gpu_target(tgt):
te_func = topi.cuda.sort
return self.builder_.call_te(
te_func, call.args[0], call.attrs.axis, not call.attrs.descending, **kwargs
Expand All @@ -101,7 +106,7 @@ def visit_call_(self, call: relax.Call) -> relax.Expr:
if can_use_thrust(tgt, "tvm.contrib.thrust.sort"):
te_func = topi.cuda.argsort_thrust
kwargs["workspace"] = self.allocate_workspace(call)
elif tgt.kind.name == "cuda":
elif is_gpu_target(tgt):
te_func = topi.cuda.argsort
return self.builder_.call_te(
te_func,
Expand All @@ -118,7 +123,7 @@ def visit_call_(self, call: relax.Call) -> relax.Expr:
if can_use_thrust(tgt, "tvm.contrib.thrust.sort"):
te_func = topi.cuda.topk_thrust
kwargs["workspace"] = self.allocate_workspace(call)
elif tgt.kind.name == "cuda":
elif is_gpu_target(tgt):
te_func = topi.cuda.topk
tir_call = self.builder_.call_te(
te_func,
Expand All @@ -130,7 +135,7 @@ def visit_call_(self, call: relax.Call) -> relax.Expr:
dtype=call.attrs.dtype,
**kwargs,
)
if tgt.kind.name != "cuda":
if not is_gpu_target(tgt):
return tir_call
# apply dlight gpu fallback
self._apply_dlight_gpu_fallback(tgt, tir_call)
Expand All @@ -141,11 +146,11 @@ def visit_call_(self, call: relax.Call) -> relax.Expr:
kwargs = {}
with tgt:
if call.op.name == "relax.cumsum":
te_func = topi.cuda.cumsum if tgt.kind.name == "cuda" else topi.cumsum
te_func = topi.cuda.cumsum if is_gpu_target(tgt) else topi.cumsum
if can_use_thrust(tgt, "tvm.contrib.thrust.sum_scan"):
kwargs["workspace"] = self.allocate_workspace(call)
elif call.op.name == "relax.cumprod":
te_func = topi.cuda.cumprod if tgt.kind.name == "cuda" else topi.cumprod
te_func = topi.cuda.cumprod if is_gpu_target(tgt) else topi.cumprod
else:
raise ValueError(f"Unsupported op: {call.op.name}")
tir_call = self.builder_.call_te(
Expand All @@ -156,7 +161,7 @@ def visit_call_(self, call: relax.Call) -> relax.Expr:
call.attrs.exclusive,
**kwargs,
)
if tgt.kind.name != "cuda":
if not is_gpu_target(tgt):
return tir_call
# apply dlight gpu fallback
self._apply_dlight_gpu_fallback(tgt, tir_call)
Expand Down