1919
2020from functools import reduce
2121from operator import mul
22+ from typing import Dict
2223
2324from tvm import DataType , dlight , relax , topi
2425from tvm .contrib .thrust import can_use_thrust
25- from tvm .ir import Op
26+ from tvm .ir import GlobalVar , Op
2627from tvm .ir .module import IRModule
2728from tvm .ir .transform import PassContext , module_pass
2829from tvm .relax import PyExprMutator , expr_functor
@@ -41,8 +42,11 @@ class SortScanDispatcher(PyExprMutator):
4142
4243 """
4344
45+ calls_to_update : Dict [GlobalVar , Target ]
46+
4447 def __init__ (self , mod ):
4548 super ().__init__ (mod )
49+ self .calls_to_update = {}
4650
4751 def _get_target (self , sinfo : relax .StructInfo ) -> Target :
4852 # Get target information from TensorStructInfo
@@ -64,22 +68,32 @@ def _get_target(self, sinfo: relax.StructInfo) -> Target:
6468 )
6569 return target
6670
67- def _apply_dlight_gpu_fallback (self , target : Target , tir_call : relax .Call ) -> None :
68- # Apply dlight.gpu.Fallback() on GPU
71+ def apply_dlight_gpu_fallback (
72+ self ,
73+ ) -> None :
74+ """ Apply DLight rules for all the calls that need to be updated."""
75+ for gvar , target in self .calls_to_update .items ():
76+ func = self .builder_ .get ()[gvar ]
77+ sch = dlight .base .transform ._apply_rules (
78+ func ,
79+ target ,
80+ rules = [dlight .gpu .Fallback ()],
81+ tunable = False ,
82+ )
83+ if sch is not None :
84+ assert len (sch ) == 1
85+ self .builder_ .update_func (gvar , sch [0 ].mod ["main" ].with_attr ("tir.is_scheduled" , 1 ))
86+
87+ def _append_calls_to_update (self , tir_call : relax .Call , target : Target ) -> None :
6988 gvar = tir_call .args [0 ]
70- assert isinstance (gvar , relax .GlobalVar )
71- scan_prim_func = self .builder_ .get ()[gvar ]
72- sch = dlight .base .transform ._apply_rules (
73- scan_prim_func ,
74- target ,
75- [
76- dlight .gpu .Fallback (),
77- ],
78- False ,
79- )
80- if sch is not None :
81- assert len (sch ) == 1
82- self .builder_ .update_func (gvar , sch [0 ].mod ["main" ].with_attr ("tir.is_scheduled" , 1 ))
89+ assert isinstance (gvar , GlobalVar )
90+ existing_tgt = self .calls_to_update .get (gvar , None )
91+ if existing_tgt is not None and existing_tgt != target :
92+ raise ValueError (
93+ f"Multiple targets detected for function { gvar } . "
94+ f"Existing target: { existing_tgt } , new target: { target } "
95+ )
96+ self .calls_to_update [gvar ] = target
8397
8498 def visit_call_ (self , call : relax .Call ) -> relax .Expr :
8599 if not isinstance (call .op , Op ):
@@ -135,10 +149,7 @@ def visit_call_(self, call: relax.Call) -> relax.Expr:
135149 dtype = call .attrs .dtype ,
136150 ** kwargs ,
137151 )
138- if not is_gpu_target (tgt ):
139- return tir_call
140- # apply dlight gpu fallback
141- self ._apply_dlight_gpu_fallback (tgt , tir_call )
152+ self ._append_calls_to_update (tir_call , tgt )
142153 return tir_call
143154 if call .op .name in ("relax.cumprod" , "relax.cumsum" ):
144155 tgt = self ._get_target (call .struct_info )
@@ -161,10 +172,7 @@ def visit_call_(self, call: relax.Call) -> relax.Expr:
161172 call .attrs .exclusive ,
162173 ** kwargs ,
163174 )
164- if not is_gpu_target (tgt ):
165- return tir_call
166- # apply dlight gpu fallback
167- self ._apply_dlight_gpu_fallback (tgt , tir_call )
175+ self ._append_calls_to_update (tir_call , tgt )
168176 return tir_call
169177 return super ().visit_call_ (call )
170178
@@ -211,4 +219,5 @@ def transform_module(self, mod: IRModule, ctx: PassContext) -> IRModule:
211219 if isinstance (func , relax .Function ):
212220 func = sort_scan_dispater .visit_expr (func )
213221 sort_scan_dispater .builder_ .update_func (gv , func )
222+ sort_scan_dispater .apply_dlight_gpu_fallback ()
214223 return sort_scan_dispater .builder_ .finalize ()
0 commit comments