From 236f92c9240e1c8c9d3a248c9689698a094af708 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 10 Jan 2024 20:16:08 +0000 Subject: [PATCH] [Unity][VM] Recursively visit match bindings in VMShapeLowerMutator Prior to this commit, the `MatchBinding` visitor in `VMShapeLowerMutator` emitted the binding as-is, without recursively visiting its contents. If the RHS of the `MatchBinding` is a `ShapeExpr` that uses symbolic variables, that RHS must be visited in order to have the symbolic variables updated. --- src/relax/backend/vm/vm_shape_lower.cc | 2 +- .../test_backend_transform_shape_lower.py | 78 +++++++++++++++++++ 2 files changed, 79 insertions(+), 1 deletion(-) diff --git a/src/relax/backend/vm/vm_shape_lower.cc b/src/relax/backend/vm/vm_shape_lower.cc index 41b27ea6252b..5875ad55628c 100644 --- a/src/relax/backend/vm/vm_shape_lower.cc +++ b/src/relax/backend/vm/vm_shape_lower.cc @@ -419,7 +419,7 @@ class VMShapeLowerMutator // These checks are emitted as extra, in codegen // match-cast is simply ignored and treated as a normal binding. - builder_->EmitNormalized(GetRef(binding)); + ExprMutator::VisitBinding_(binding); } // Do not override shape in struct info fields diff --git a/tests/python/relax/test_backend_transform_shape_lower.py b/tests/python/relax/test_backend_transform_shape_lower.py index b9a353763032..31eb4b26bee0 100644 --- a/tests/python/relax/test_backend_transform_shape_lower.py +++ b/tests/python/relax/test_backend_transform_shape_lower.py @@ -731,5 +731,83 @@ def main( assert_structural_equal(after, expected) +def test_update_symbolic_vars_in_match_cast_rhs(): + """Symbolic variables may be used on the RHS of match_cast""" + + @I.ir_module + class Before: + @R.function + def main( + arg_prim_value: R.Prim(value="n"), + ): + R.func_attr({"relax.force_pure": 1}) + n = T.int64() + shape = R.shape([n]) + m = T.int64() + _ = R.match_cast(shape, R.Shape([m])) + return R.prim_value(m) + + @I.ir_module + class Expected: + @R.function + def main(arg_prim_value: R.Prim(value="n")) -> R.Prim("int64"): + 
R.func_attr({"relax.force_pure": 1}) + n = T.int64() + + shape_heap = R.call_builtin_with_ctx( + "vm.builtin.alloc_shape_heap", + [2], + sinfo_args=(R.Tensor(dtype="int64", ndim=1),), + ) + _ = R.call_packed( + "vm.builtin.check_prim_value_info", + arg_prim_value, + R.dtype("int64"), + "", + sinfo_args=[R.Tuple], + ) + _ = R.call_packed( + "vm.builtin.match_prim_value", + arg_prim_value, + shape_heap, + MatchShapeCode.STORE_TO_HEAP, + 0, + "", + sinfo_args=[R.Tuple], + ) + shape = R.call_packed( + "vm.builtin.make_shape", + shape_heap, + 1, + MakeShapeCode.LOAD_SHAPE, + 0, + sinfo_args=[R.Shape(ndim=1)], + ) + _ = R.call_packed( + "vm.builtin.match_shape", + shape, + shape_heap, + 1, + MatchShapeCode.STORE_TO_HEAP, + 1, + "", + sinfo_args=[R.Tuple], + ) + + m = T.int64() + _ = R.match_cast(shape, R.Shape([m])) + gv = R.call_packed( + "vm.builtin.make_prim_value", + shape_heap, + MakeShapeCode.LOAD_SHAPE, + 1, + sinfo_args=[R.Prim(value=m)], + ) + return gv + + After = relax.transform.VMShapeLower(emit_err_ctx=False)(Before) + assert_structural_equal(Expected, After) + + if __name__ == "__main__": tvm.testing.main()