From cbf0143e963d6ccd2da4f78dd0272b6422a1e913 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 10 Jan 2024 20:04:56 +0000 Subject: [PATCH] [Unity] Check for symbolic vars in PrimValue in when lowering to TIR Prior to this commit, a fused relax function could accept a `R.Prim` value, but wouldn't use it to provide symbolic variables to the fused function. --- python/tvm/relax/utils.py | 7 ++- .../python/relax/test_blockbuilder_emit_te.py | 54 +++++++++++++++++++ 2 files changed, 59 insertions(+), 2 deletions(-) diff --git a/python/tvm/relax/utils.py b/python/tvm/relax/utils.py index b720a727f639..a58b65477cee 100644 --- a/python/tvm/relax/utils.py +++ b/python/tvm/relax/utils.py @@ -370,8 +370,11 @@ def _convert_te_arg_helper(arg): arg, ShapeExpr ), "For Expr having ShapeStructInfo, emit_te now only supports ShapeExpr" return [_convert_te_arg_helper(val) for val in arg.values] - if isinstance(arg.struct_info, PrimStructInfo): - return arg.value + if ( + isinstance(arg.struct_info, PrimStructInfo) + and arg.struct_info.value is not None + ): + return _convert_te_arg_helper(arg.struct_info.value) elif isinstance(arg, (list, Array)): return [_convert_te_arg_helper(x) for x in arg] elif isinstance(arg, tuple): diff --git a/tests/python/relax/test_blockbuilder_emit_te.py b/tests/python/relax/test_blockbuilder_emit_te.py index 3724c1a4b884..ea89832ea1ec 100644 --- a/tests/python/relax/test_blockbuilder_emit_te.py +++ b/tests/python/relax/test_blockbuilder_emit_te.py @@ -16,6 +16,7 @@ # under the License. """ This file tests advanced emit_te features with help of TVMScript assertion""" # The tests here depend on tvmscript +import tvm from tvm import te, tir from tvm import relax as rx from tvm.ir.base import assert_structural_equal @@ -69,3 +70,56 @@ def main( return gv assert_structural_equal(after, Expected) + + +def test_symbolic_shape_in_prim_value(): + """Symbolic vars may be provided to TE in R.Prim""" + + def te_slice(tensor, i): + return tvm.te.compute([tensor.shape[1]], lambda j: tensor[i, j], name="slice") + + def from_builder(): + bb = rx.BlockBuilder() + A = rx.Var("A", R.Tensor([16, 16], "float32")) + tir_i = tvm.tir.Var("tir_i", "int64") + relax_i = rx.Var("relax_i", R.Prim(value=tir_i)) + + with bb.function("main", params=[A, relax_i]): + A_sliced = bb.emit_te(te_slice, A, relax_i) + bb.emit_func_output(A_sliced) + + return bb.get() + + @I.ir_module + class Expected: + @T.prim_func(private=True) + def te_slice( + A: T.Buffer([T.int64(16), T.int64(16)], "float32"), + Output: T.Buffer(T.int64(16), "float32"), + row_index: T.int64, + ): + T.func_attr({"tir.noalias": T.bool(True)}) + + for i in range(A.shape[1]): + with T.block("slice"): + vi = T.axis.remap("S", [i]) + Output[vi] = A[row_index, vi] + + @R.function + def main( + A: R.Tensor([16, 16], "float32"), + arg_row_index: R.Prim(value="row_index"), + ): + cls = Expected + + row_index = T.int64() + + gv = R.call_tir( + cls.te_slice, + A, + tir_vars=[row_index], + out_sinfo=R.Tensor([16], "float32"), + ) + return gv + + tvm.ir.assert_structural_equal(from_builder(), Expected)