Skip to content

Commit e8a380c

Browse files
author
Siyuan Feng
committed
[Relax] Prevent generating duplicate PrimFuncs in dispatch_sort_scan
The current pass could generate multiple PrimFuncs even when they are structurally equal, because `bb.update_func` does not check whether the new func already exists in the module. This PR applies dlight once at the end of dispatching instead of after every function.
1 parent de91c5c commit e8a380c

2 files changed

Lines changed: 71 additions & 24 deletions

File tree

python/tvm/relax/backend/dispatch_sort_scan.py

Lines changed: 33 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,11 @@
1919

2020
from functools import reduce
2121
from operator import mul
22+
from typing import Dict
2223

2324
from tvm import DataType, dlight, relax, topi
2425
from tvm.contrib.thrust import can_use_thrust
25-
from tvm.ir import Op
26+
from tvm.ir import GlobalVar, Op
2627
from tvm.ir.module import IRModule
2728
from tvm.ir.transform import PassContext, module_pass
2829
from tvm.relax import PyExprMutator, expr_functor
@@ -41,8 +42,11 @@ class SortScanDispatcher(PyExprMutator):
4142
4243
"""
4344

45+
calls_to_update: Dict[GlobalVar, Target]
46+
4447
def __init__(self, mod):
    """Initialize the dispatcher over *mod*.

    Parameters
    ----------
    mod
        The IRModule whose Relax functions are scanned for sort/scan ops.
    """
    super().__init__(mod)
    # GlobalVar -> Target pairs collected during visitation; DLight
    # scheduling is deferred to apply_dlight_gpu_fallback() so that each
    # PrimFunc is scheduled once even when it is called multiple times.
    self.calls_to_update = {}
4650

4751
def _get_target(self, sinfo: relax.StructInfo) -> Target:
4852
# Get target information from TensorStructInfo
@@ -64,22 +68,32 @@ def _get_target(self, sinfo: relax.StructInfo) -> Target:
6468
)
6569
return target
6670

def apply_dlight_gpu_fallback(self) -> None:
    """Apply DLight rules for all the calls that need to be updated."""
    for gvar, target in self.calls_to_update.items():
        prim_func = self.builder_.get()[gvar]
        # _apply_rules returns None when no rule applies; otherwise a
        # list of scheduled candidates (exactly one for Fallback).
        schedules = dlight.base.transform._apply_rules(
            prim_func,
            target,
            rules=[dlight.gpu.Fallback()],
            tunable=False,
        )
        if schedules is None:
            continue
        assert len(schedules) == 1
        scheduled = schedules[0].mod["main"].with_attr("tir.is_scheduled", 1)
        self.builder_.update_func(gvar, scheduled)
def _append_calls_to_update(self, tir_call: relax.Call, target: Target) -> None:
    """Record the PrimFunc referenced by *tir_call* for later DLight scheduling.

    Raises
    ------
    ValueError
        If the same GlobalVar was previously recorded with a different target.
    """
    gvar = tir_call.args[0]
    assert isinstance(gvar, GlobalVar)
    previous = self.calls_to_update.get(gvar)
    if previous is not None and previous != target:
        raise ValueError(
            f"Multiple targets detected for function {gvar}. "
            f"Existing target: {previous}, new target: {target}"
        )
    self.calls_to_update[gvar] = target
8397

8498
def visit_call_(self, call: relax.Call) -> relax.Expr:
8599
if not isinstance(call.op, Op):
@@ -135,10 +149,7 @@ def visit_call_(self, call: relax.Call) -> relax.Expr:
135149
dtype=call.attrs.dtype,
136150
**kwargs,
137151
)
138-
if not is_gpu_target(tgt):
139-
return tir_call
140-
# apply dlight gpu fallback
141-
self._apply_dlight_gpu_fallback(tgt, tir_call)
152+
self._append_calls_to_update(tir_call, tgt)
142153
return tir_call
143154
if call.op.name in ("relax.cumprod", "relax.cumsum"):
144155
tgt = self._get_target(call.struct_info)
@@ -161,10 +172,7 @@ def visit_call_(self, call: relax.Call) -> relax.Expr:
161172
call.attrs.exclusive,
162173
**kwargs,
163174
)
164-
if not is_gpu_target(tgt):
165-
return tir_call
166-
# apply dlight gpu fallback
167-
self._apply_dlight_gpu_fallback(tgt, tir_call)
175+
self._append_calls_to_update(tir_call, tgt)
168176
return tir_call
169177
return super().visit_call_(call)
170178

@@ -211,4 +219,5 @@ def transform_module(self, mod: IRModule, ctx: PassContext) -> IRModule:
211219
if isinstance(func, relax.Function):
212220
func = sort_scan_dispater.visit_expr(func)
213221
sort_scan_dispater.builder_.update_func(gv, func)
222+
sort_scan_dispater.apply_dlight_gpu_fallback()
214223
return sort_scan_dispater.builder_.finalize()

tests/python/relax/test_backend_dispatch_sort_scan.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -361,5 +361,43 @@ def foo(x: R.Tensor((2, 3), "float32", "cuda")):
361361
assert_structural_equal(mod, expected_mod)
362362

363363

364+
def test_dispatch_topk_gpu():
    """Duplicate, structurally-equal topk calls must map to one PrimFunc."""

    @I.ir_module
    class Before:
        I.module_global_infos({"vdevice": [I.vdevice("vulkan")]})

        @R.function
        def foo(x: R.Tensor((2, 3), "float32", "vulkan")):
            with R.dataflow():
                # Two same calls should have only one PrimFunc
                lv0 = R.topk(x, k=2, axis=1, largest=True)
                lv1 = R.topk(x, k=2, axis=1, largest=True)
                gv = (lv0, lv1)
                R.output(gv)
            return gv

    target = tvm.target.Target("vulkan", host="llvm")

    # Build the expected module by hand: both topk calls lower to the same
    # TE kernel, so emit_te should deduplicate into a single PrimFunc.
    vdevices = [I.vdevice("vulkan", 0)]
    x = relax.Var("x", R.Tensor((2, 3), "float32", vdevices[0]))
    bb = relax.BlockBuilder()
    with target:
        with bb.function("foo", (x,), {"global_symbol": "foo"}):
            with bb.dataflow():
                lv0 = bb.emit_te(topi.cuda.topk, x, k=2, axis=1, is_ascend=False, dtype="int32")
                lv1 = bb.emit_te(topi.cuda.topk, x, k=2, axis=1, is_ascend=False, dtype="int32")
                out = (lv0, lv1)
                out = bb.emit_output(out)
            bb.emit_func_output(out)
    expected_mod = bb.finalize()
    expected_mod.update_global_info("vdevice", vdevices)

    with target:
        mod = DispatchSortScan()(Before)
        # The pass now applies DLight at the end; mirror that on the
        # expected module so both sides are scheduled identically.
        expected_mod = dlight.ApplyDefaultSchedule(dlight.gpu.Fallback())(expected_mod)

    assert_structural_equal(mod, expected_mod)
400+
401+
364402
if __name__ == "__main__":
365403
tvm.testing.main()

0 commit comments

Comments
 (0)