From 25d1944494fc1b9a5cd9e0a94cdf0edefea27e71 Mon Sep 17 00:00:00 2001
From: Eric Lunderberg
Date: Fri, 7 Jun 2024 09:01:12 -0500
Subject: [PATCH 1/2] [Relax][Bugfix] FCallPacked not checked in CodegenVMTIR

Prior to this commit, an operator's `FCallPacked` attribute, used to
specify a 1:1 mapping between a relax operator and a `PackedFunc` that
implements it, was only checked in `CodegenVM`.  Any operator with
`FCallPacked` would raise an error when compiled using `CodegenVMTIR`.

This commit removes the `FCallPacked` handling from `CodegenVM`
altogether, and instead checks for this attribute as part of
`LegalizeOps`.  This provides the same functionality across both
backends.
---
 src/relax/backend/vm/codegen_vm.cc         |  24 +---
 src/relax/backend/vm/codegen_vm_tir.cc     |  15 ---
 src/relax/transform/legalize_ops.cc        |  25 ++--
 tests/python/relax/test_relax_operators.py | 139 ++++++++++++---------
 4 files changed, 100 insertions(+), 103 deletions(-)

diff --git a/src/relax/backend/vm/codegen_vm.cc b/src/relax/backend/vm/codegen_vm.cc
index 1c795594629e..ca2d4d4fdb2e 100644
--- a/src/relax/backend/vm/codegen_vm.cc
+++ b/src/relax/backend/vm/codegen_vm.cc
@@ -45,21 +45,6 @@ using namespace relax;
 using namespace tvm::runtime;
 using namespace tvm::runtime::relax_vm;

-namespace {
-// Helper function to get the function name of the registered packed function implementation of
-// relax operator.
-FCallPacked GetPackedFuncName(const Call& call) {
-  static auto op_map = Op::GetAttrMap<FCallPacked>("FCallPacked");
-  if (call->op.as<OpNode>()) {
-    Op op = Downcast<Op>(call->op);
-    if (op_map.count(op)) {
-      return op_map[op];
-    }
-  }
-  return {};
-}
-}  // namespace
-
 /*!
  * \brief A class to generate VM executable for Relax functions.
  */
@@ -156,14 +141,7 @@ class CodeGenVM : public ExprFunctor<Instruction::Arg(const Expr&)> {
     // allocate dst register.
     RegName dst_reg = HasVoidStructInfo(call) ? Instruction::kVoidRegister : NewRegister();
     if (call->op.as<OpNode>()) {
-      // special case generate for the intrinsics whose attribute fields
-      // cannot be represented by args in the CallNode
-      FCallPacked name = GetPackedFuncName(call);
-      if (!name.empty()) {
-        // If the operator has a registered packed function implementation, emit call to that packed
-        // function.
-        EmitPackedFuncCall(call, name, dst_reg);
-      } else if (call_node->op == call_builtin_with_ctx_op_) {
+      if (call_node->op == call_builtin_with_ctx_op_) {
         // TODO(relax-team) migrate most handling of op to
         // directly map to call_builtin_with_ctx before codegen and simplify vm codegen.
         EmitCallBuiltinWithCtx(call, dst_reg);
diff --git a/src/relax/backend/vm/codegen_vm_tir.cc b/src/relax/backend/vm/codegen_vm_tir.cc
index 5e6a1c3f8442..c29e416395d9 100644
--- a/src/relax/backend/vm/codegen_vm_tir.cc
+++ b/src/relax/backend/vm/codegen_vm_tir.cc
@@ -44,21 +44,6 @@ namespace relax_vm {

 using vm::VMFuncInfo;

-namespace {
-// Helper function to get the function name of the registered packed function implementation of
-// relax operator.
-FCallPacked GetPackedFuncName(const Call& call) {
-  static auto op_map = Op::GetAttrMap<FCallPacked>("FCallPacked");
-  if (call->op.as<OpNode>()) {
-    Op op = Downcast<Op>(call->op);
-    if (op_map.count(op)) {
-      return op_map[op];
-    }
-  }
-  return {};
-}
-}  // namespace
-
 /*!
  * \brief A class to generate VMTIR for Relax functions.
 *
diff --git a/src/relax/transform/legalize_ops.cc b/src/relax/transform/legalize_ops.cc
index 34902fa0f8b6..4a6b44bf2839 100644
--- a/src/relax/transform/legalize_ops.cc
+++ b/src/relax/transform/legalize_ops.cc
@@ -224,6 +224,7 @@ class LegalizeMutator : public ExprMutator {
   Expr VisitExpr_(const CallNode* call) final {
     Call visited_call = Downcast<Call>(this->VisitExprPostOrder_(call));
     static const auto& legalize_map = Op::GetAttrMap<FLegalize>("FLegalize");
+    static const auto& call_packed_map = Op::GetAttrMap<FCallPacked>("FCallPacked");
     static const auto& requires_arg_shapes_map = Op::GetAttrMap<Bool>("RequiresArgumentShapes");
     static const Op& call_pure_packed_op = Op::Get("relax.call_pure_packed");
     static const Op& call_tir_op = Op::Get("relax.call_tir");
@@ -236,7 +237,7 @@ class LegalizeMutator : public ExprMutator {
     }
     auto op = GetRef<Op>(op_node);

-    bool can_legalize = [&]() -> bool {
+    bool shapes_are_known_if_required = [&]() -> bool {
       bool requires_arg_shapes = requires_arg_shapes_map.get(op, Bool(true))->value;
       if (!requires_arg_shapes) {
         // This operator does not require its arguments to have a
@@ -299,23 +300,31 @@ class LegalizeMutator : public ExprMutator {
       return true;
     }();

-    if (!can_legalize) {
-      return visited_call;
-    }
-
     FLegalize legalization_func;

-    if (auto opt_custom_legalize = cmap_.Get(op->name)) {
+    if (auto opt_custom_legalize = cmap_.Get(op->name);
+        opt_custom_legalize && shapes_are_known_if_required) {
       // First choice, use a custom legalization function
       legalization_func = opt_custom_legalize.value();
-    } else if (legalize_map.count(op)) {
+    } else if (legalize_map.count(op) && shapes_are_known_if_required) {
       // Second choice, use a default legalization
       legalization_func = legalize_map[op];
+    } else if (call_packed_map.count(op)) {
+      // Third choice, use an explicit FCallPacked replacement.  This does not require the shapes to be known.
+      String packed_func_name = call_packed_map[op];
+      legalization_func = [packed_func_name](const BlockBuilder& bb, const Call& call) -> Expr {
+        return Call(ExternFunc(packed_func_name), call->args, Attrs(), {GetStructInfo(call)});
+      };
     } else {
       // No legalization.
       if (enable_warning_ && op != call_tir_op && op != call_dps_packed_op &&
           op != call_pure_packed_op) {
-        LOG(WARNING) << "No legalization func for " << op->name << " is found.";
+        if (shapes_are_known_if_required) {
+          LOG(WARNING) << "No legalization func for " << op->name << " is found.";
+        } else {
+          LOG(WARNING) << "Cannot legalize " << visited_call
+                       << ", missing known shapes for arguments and return value";
+        }
       }
       return visited_call;
     }
diff --git a/tests/python/relax/test_relax_operators.py b/tests/python/relax/test_relax_operators.py
index 41618a32cb55..fcb8727d8508 100644
--- a/tests/python/relax/test_relax_operators.py
+++ b/tests/python/relax/test_relax_operators.py
@@ -27,6 +27,8 @@
 from tvm._ffi.base import TVMError
 from tvm.script import ir as I, relax as R, tir as T

+exec_mode = tvm.testing.parameter("bytecode", "compiled")
+

 @tvm.script.ir_module
 class InputModule:
@@ -37,7 +39,7 @@ def foo(x: R.Tensor(("m", "n"), "int64")):
         return y, y_sorted


-def run_cpu(mod, func_name, *args):
+def run_cpu(mod, func_name, *args, exec_mode):
     if isinstance(mod, relax.Function):
         func = mod
         args = [func_name, *args]
@@ -45,17 +47,17 @@ def run_cpu(mod, func_name, *args, exec_mode):
         mod = tvm.IRModule.from_expr(func)

     target = tvm.target.Target("llvm")
-    ex = relax.build(mod, target)
+    ex = relax.build(mod, target, exec_mode=exec_mode)
     vm = relax.VirtualMachine(ex, tvm.cpu())
     return vm[func_name](*args)


-def test_unique():
+def test_unique(exec_mode):
     # TODO(prakalp): also add test for compiling and running on cuda device.
     data_numpy = np.random.randint(0, 16, (16, 16))
     data = tvm.nd.array(data_numpy)

-    result, result_sorted = run_cpu(InputModule, "foo", data)
+    result, result_sorted = run_cpu(InputModule, "foo", data, exec_mode=exec_mode)

     expected_output_sorted, indices = np.unique(data_numpy, return_index=True)
     expected_output = [data_numpy.flatten()[index] for index in sorted(indices, reverse=True)]
@@ -81,12 +83,17 @@ def foo(x: R.Tensor((), "int32")):
         return x


-def test_print():
+def test_print(exec_mode):
     try:
         stdout = sys.stdout
         with tempfile.TemporaryFile(mode="w+") as test_out:
             sys.stdout = test_out
-            run_cpu(PrintTest, "foo", tvm.nd.array(np.array(1).astype("int32")))
+            run_cpu(
+                PrintTest,
+                "foo",
+                tvm.nd.array(np.array(1).astype("int32")),
+                exec_mode=exec_mode,
+            )
             test_out.seek(0)
             printed_text = str(test_out.read())
             expected = "1\nNumber: 1\nTuple: (1, 1)\n1 (1, 1)\nCustom print: 1 1\nAnother print: 1 (1, 1)\n"
@@ -95,65 +102,65 @@ def test_print():
         sys.stdout = stdout


-def test_assert_passes():
+def test_assert_passes(exec_mode):
     @R.function(pure=False)
     def func(x: R.Tensor((), "int32")):
         _ = R.assert_op(relax.const(True))
         return x

-    run_cpu(func, tvm.nd.array(np.array(1).astype("int32")))
+    run_cpu(func, tvm.nd.array(np.array(1).astype("int32")), exec_mode=exec_mode)


-def test_assert_passes_with_format_args():
+def test_assert_passes_with_format_args(exec_mode):
     @R.function(pure=False)
     def func(x: R.Tensor((), "int32")):
         _ = R.assert_op(relax.const(True), x, format="You won't see me")
         return x

-    run_cpu(func, tvm.nd.array(np.array(1).astype("int32")))
+    run_cpu(func, tvm.nd.array(np.array(1).astype("int32")), exec_mode=exec_mode)


-def test_assert_fails():
+def test_assert_fails(exec_mode):
     @R.function(pure=False)
     def func(x: R.Tensor((), "int32")):
         _ = R.assert_op(relax.const(False))
         return x

     with pytest.raises(AssertionError, match="Assertion Failed"):
-        run_cpu(func, tvm.nd.array(np.array(1).astype("int32")))
+        run_cpu(func, tvm.nd.array(np.array(1).astype("int32")), exec_mode=exec_mode)


-def test_assert_fails_with_message():
+def test_assert_fails_with_message(exec_mode):
     @R.function(pure=False)
     def func(x: R.Tensor((), "int32")):
         _ = R.assert_op(relax.const(False), format="I failed...")
         return x

     with pytest.raises(AssertionError, match="I failed..."):
-        run_cpu(func, tvm.nd.array(np.array(1).astype("int32")))
+        run_cpu(func, tvm.nd.array(np.array(1).astype("int32")), exec_mode=exec_mode)


-def test_assert_fails_with_args():
+def test_assert_fails_with_args(exec_mode):
     @R.function(pure=False)
     def func(x: R.Tensor((), "int32")):
         _ = R.assert_op(relax.const(False), [x, x])
         return x

     with pytest.raises(AssertionError, match="5, 5"):
-        run_cpu(func, tvm.nd.array(np.array(5).astype("int32")))
+        run_cpu(func, tvm.nd.array(np.array(5).astype("int32")), exec_mode=exec_mode)


-def test_assert_fails_with_formatted_args():
+def test_assert_fails_with_formatted_args(exec_mode):
     @R.function(pure=False)
     def func(x: R.Tensor((), "int32")):
         _ = R.assert_op(relax.const(False), x, format="Number: {}")
         return x

     with pytest.raises(AssertionError, match="Number: 6"):
-        run_cpu(func, tvm.nd.array(np.array(6).astype("int32")))
+        run_cpu(func, tvm.nd.array(np.array(6).astype("int32")), exec_mode=exec_mode)


-def test_assert_on_argument_passes():
+def test_assert_on_argument_passes(exec_mode):
     @R.function(pure=False)
     def func(condition: R.Tensor((), "bool"), x: R.Tensor((), "int32")):
         _ = R.assert_op(condition)
@@ -161,10 +168,10 @@ def func(condition: R.Tensor((), "bool"), x: R.Tensor((), "int32")):

     condition = tvm.nd.array(np.array(True))
     x = tvm.nd.array(np.array(5).astype("int32"))
-    run_cpu(func, condition, x)
+    run_cpu(func, condition, x, exec_mode=exec_mode)


-def test_assert_on_argument_fails():
+def test_assert_on_argument_fails(exec_mode):
     @R.function(pure=False)
     def func(condition: R.Tensor((), "bool"), x: R.Tensor((), "int32")):
         _ = R.assert_op(condition)
@@ -173,10 +180,10 @@ def func(condition: R.Tensor((), "bool"), x: R.Tensor((), "int32")):
     condition = tvm.nd.array(np.array(False))
     x = tvm.nd.array(np.array(5).astype("int32"))
     with pytest.raises(AssertionError):
-        run_cpu(func, condition, x)
+        run_cpu(func, condition, x, exec_mode=exec_mode)


-def test_assert_on_symbolic_var_passes():
+def test_assert_on_symbolic_var_passes(exec_mode):
     @R.function(pure=False)
     def func(x: R.Tensor(["N"], "int32")):
         N = T.int64()
@@ -184,10 +191,10 @@ def func(x: R.Tensor(["N"], "int32")):
         return x

     x = tvm.nd.array(np.arange(8, dtype="int32"))
-    run_cpu(func, x)
+    run_cpu(func, x, exec_mode=exec_mode)


-def test_assert_on_symbolic_var_fails():
+def test_assert_on_symbolic_var_fails(exec_mode):
     @R.function(pure=False)
     def func(x: R.Tensor(["N"], "int32")):
         N = T.int64()
@@ -196,7 +203,7 @@ def func(x: R.Tensor(["N"], "int32")):

     x = tvm.nd.array(np.arange(10, dtype="int32"))
     with pytest.raises(AssertionError):
-        run_cpu(func, x)
+        run_cpu(func, x, exec_mode=exec_mode)


 @tvm.script.ir_module
@@ -223,23 +230,31 @@ def get_constant_shape() -> R.Shape((2, 2)):
         return R.shape_of(x)


-def test_op_shape_of():
-    unit_shape = run_cpu(ShapeOfTest, "get_scalar_shape")
+def test_op_shape_of(exec_mode):
+    unit_shape = run_cpu(ShapeOfTest, "get_scalar_shape", exec_mode=exec_mode)
     assert unit_shape == tvm.runtime.ShapeTuple([])

-    const_shape = run_cpu(ShapeOfTest, "get_constant_shape")
+    const_shape = run_cpu(ShapeOfTest, "get_constant_shape", exec_mode=exec_mode)
     assert const_shape == tvm.runtime.ShapeTuple([2, 2])

-    scalar_shape = run_cpu(ShapeOfTest, "get_shape", tvm.nd.array(np.array(1, dtype="int32")))
+    scalar_shape = run_cpu(
+        ShapeOfTest, "get_shape", tvm.nd.array(np.array(1, dtype="int32")), exec_mode=exec_mode
+    )
     assert scalar_shape == tvm.runtime.ShapeTuple([])

     tensor_shape = run_cpu(
-        ShapeOfTest, "get_shape", tvm.nd.array(np.zeros((1, 2, 3)).astype("int32"))
+        ShapeOfTest,
+        "get_shape",
+        tvm.nd.array(np.zeros((1, 2, 3)).astype("int32")),
+        exec_mode=exec_mode,
     )
     assert tensor_shape == tvm.runtime.ShapeTuple([1, 2, 3])

     constrained_shape = run_cpu(
-        ShapeOfTest, "get_constrained_shape", tvm.nd.array(np.zeros((1,)).astype("int32"))
+        ShapeOfTest,
+        "get_constrained_shape",
+        tvm.nd.array(np.zeros((1,)).astype("int32")),
+        exec_mode=exec_mode,
     )
     assert constrained_shape == tvm.runtime.ShapeTuple([1])
@@ -257,7 +272,7 @@ def symbolic_shape(shape: R.Shape(("m", "n"))) -> R.Tensor(ndim=-1):
         return R.shape_to_tensor(shape)


-def test_op_shape_to_tensor():
+def test_op_shape_to_tensor(exec_mode):
     # Check struct info
     isinstance(ShapeToTensorTest["const_shape"].body.struct_info, tvm.relax.TensorStructInfo)
     assert ShapeToTensorTest["const_shape"].body.struct_info.ndim == 1
@@ -265,24 +280,32 @@ def test_op_shape_to_tensor():
     assert ShapeToTensorTest["symbolic_shape"].body.struct_info.ndim == 1

     # Check its functionality
-    out2d = run_cpu(ShapeToTensorTest, "const_shape", tvm.runtime.ShapeTuple([3, 2]))
+    out2d = run_cpu(
+        ShapeToTensorTest, "const_shape", tvm.runtime.ShapeTuple([3, 2]), exec_mode=exec_mode
+    )
     assert isinstance(out2d, tvm.runtime.ndarray.NDArray)
     assert np.array_equal(out2d.numpy(), np.array([3, 2]))

-    out3d = run_cpu(ShapeToTensorTest, "const_shape", tvm.runtime.ShapeTuple([3, 3, 2]))
+    out3d = run_cpu(
+        ShapeToTensorTest, "const_shape", tvm.runtime.ShapeTuple([3, 3, 2]), exec_mode=exec_mode
+    )
     assert isinstance(out3d, tvm.runtime.ndarray.NDArray)
     assert np.array_equal(out3d.numpy(), np.array([3, 3, 2]))

-    out4d = run_cpu(ShapeToTensorTest, "const_shape", tvm.runtime.ShapeTuple([3, 3, 2, 2]))
+    out4d = run_cpu(
+        ShapeToTensorTest, "const_shape", tvm.runtime.ShapeTuple([3, 3, 2, 2]), exec_mode=exec_mode
+    )
     assert isinstance(out4d, tvm.runtime.ndarray.NDArray)
     assert np.array_equal(out4d.numpy(), np.array([3, 3, 2, 2]))

-    outs = run_cpu(ShapeToTensorTest, "symbolic_shape", tvm.runtime.ShapeTuple([3, 2]))
+    outs = run_cpu(
+        ShapeToTensorTest, "symbolic_shape", tvm.runtime.ShapeTuple([3, 2]), exec_mode=exec_mode
+    )
     assert isinstance(outs, tvm.runtime.ndarray.NDArray)
     assert np.array_equal(outs.numpy(), np.array([3, 2]))


-def test_op_call_pure_packed():
+def test_op_call_pure_packed(exec_mode):
     @tvm.script.ir_module
     class CallPureTest:
         @R.function
@@ -294,11 +317,11 @@ def pure_copy(x: R.Tensor((3, 4), "float32")):

     np.random.seed(0)  # to avoid flakiness
     arr = np.random.rand(3, 4).astype("float32")
-    copy_found = run_cpu(CallPureTest, "pure_copy", tvm.nd.array(arr))
+    copy_found = run_cpu(CallPureTest, "pure_copy", tvm.nd.array(arr), exec_mode=exec_mode)
     assert (copy_found.numpy() == arr).all()


-def test_op_call_inplace_packed():
+def test_op_call_inplace_packed(exec_mode):
     # in this case we can use the same test as above
     @tvm.script.ir_module
     class CallInplaceTest:
@@ -312,7 +335,7 @@ def pure_copy(x: R.Tensor((3, 4), "float32")):
             )
             return z

-    @tvm.register_func("test.inplace.add")
+    @tvm.register_func("test.inplace.add", override=True)
     def inplace_add(a, b):
         arr_a = a.numpy()
         arr_b = b.numpy()
@@ -340,11 +363,13 @@ def inplace_add(x: R.Tensor((3, 4), "float32"), y: R.Tensor((3, 4), "float32")):
     arr_b = np.random.rand(3, 4).astype("float32")
     sum = arr_a + arr_b
     tvm_arr_a = tvm.nd.array(arr_a)
-    result = run_cpu(CallInplaceAddTest, "inplace_add", tvm_arr_a, tvm.nd.array(arr_b))
+    result = run_cpu(
+        CallInplaceAddTest, "inplace_add", tvm_arr_a, tvm.nd.array(arr_b), exec_mode=exec_mode
+    )
     assert result == tvm_arr_a
     assert (result.numpy() == sum).all()

-    @tvm.register_func("test.inplace.tuple_add")
+    @tvm.register_func("test.inplace.tuple_add", override=True)
     def inplace_tuple_add(a, b):
         arr_a = a.numpy()
         arr_b = b.numpy()
@@ -374,14 +399,14 @@ def inplace_tuple(x: R.Tensor((3, 4), "float32"), y: R.Tensor((3, 4), "float32")):
     sum = arr_a + arr_b
     tvm_arr_a = tvm.nd.array(arr_a)
     tvm_arr_b = tvm.nd.array(arr_b)
-    result = run_cpu(CallInplaceTuple, "inplace_tuple", tvm_arr_a, tvm_arr_b)
+    result = run_cpu(CallInplaceTuple, "inplace_tuple", tvm_arr_a, tvm_arr_b, exec_mode=exec_mode)
     assert result[0] == tvm_arr_a
     assert (result[0].numpy() == sum).all()
     assert result[1] != tvm_arr_a and result[1] != tvm_arr_b
     assert (result[1].numpy() == sum).all()


-def test_op_to_device():
+def test_op_to_device(exec_mode):
     @tvm.script.ir_module
     class CallToDevice:
         @R.function
@@ -397,11 +422,11 @@ def to_dev(x: R.Tensor((3, 4), "float32")):

     np.random.seed(0)  # to avoid flakiness
     arr = np.random.rand(3, 4).astype("float32")
-    copy_found = run_cpu(CallToDevice, "to_dev", tvm.nd.array(arr))
+    copy_found = run_cpu(CallToDevice, "to_dev", tvm.nd.array(arr), exec_mode=exec_mode)
     assert (copy_found.numpy() == arr).all()


-def test_op_to_vdevice():
+def test_op_to_vdevice(exec_mode):
     @tvm.script.ir_module
     class ToVDevice:
         I.module_global_infos({"vdevice": [I.vdevice("llvm")]})
@@ -414,11 +439,11 @@ def to_vdev(x: R.Tensor((3, 4), "float32")):

     np.random.seed(0)
     arr = np.random.rand(3, 4).astype("float32")
-    copy_found = run_cpu(ToVDevice, "to_vdev", tvm.nd.array(arr))
+    copy_found = run_cpu(ToVDevice, "to_vdev", tvm.nd.array(arr), exec_mode=exec_mode)
     assert (copy_found.numpy() == arr).all()


-def test_scalar_tensor_as_branch_condition():
+def test_scalar_tensor_as_branch_condition(exec_mode):
     """The condition of a branch may be a scalar tensor"""

     @R.function
@@ -429,14 +454,14 @@ def func(condition: R.Tensor((), "bool")):
             out = R.prim_value(10)
         return out

-    res = run_cpu(func, tvm.nd.array(np.array(True)))
+    res = run_cpu(func, tvm.nd.array(np.array(True)), exec_mode=exec_mode)
     assert res == 5

-    res = run_cpu(func, tvm.nd.array(np.array(False)))
+    res = run_cpu(func, tvm.nd.array(np.array(False)), exec_mode=exec_mode)
     assert res == 10


-def test_prim_value_as_branch_condition():
+def test_prim_value_as_branch_condition(exec_mode):
     """The condition may be a PrimValue"""

     @R.function
@@ -447,14 +472,14 @@ def func(condition: R.Prim("bool")):
             out = R.prim_value(10)
         return out

-    res = run_cpu(func, True)
+    res = run_cpu(func, True, exec_mode=exec_mode)
     assert res == 5

-    res = run_cpu(func, False)
+    res = run_cpu(func, False, exec_mode=exec_mode)
     assert res == 10


-def test_computed_prim_value_as_branch_condition():
+def test_computed_prim_value_as_branch_condition(exec_mode):
     """The R.Prim condition may be computed within the function"""

     @R.function
@@ -466,10 +491,10 @@ def func(x: R.Tensor(["N"], "int64")):
             out = R.prim_value(10)
         return out

-    res = run_cpu(func, tvm.nd.array(np.arange(16)))
+    res = run_cpu(func, tvm.nd.array(np.arange(16)), exec_mode=exec_mode)
     assert res == 5

-    res = run_cpu(func, tvm.nd.array(np.arange(20)))
+    res = run_cpu(func, tvm.nd.array(np.arange(20)), exec_mode=exec_mode)
     assert res == 10
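
[Illustrative note, not part of the patch] The sketch below shows the effect
of the new LegalizeOps branch.  It assumes an operator such as `relax.unique`,
whose `FCallPacked` attribute names the "relax.run.unique" PackedFunc (an
assumption about the registered attribute, used here only for illustration).
After legalization, the operator call becomes a direct call to the
corresponding ExternFunc, so neither VM code generator needs a special case:

    # Sketch only: assumes relax.unique carries FCallPacked = "relax.run.unique".
    import tvm
    from tvm import relax
    from tvm.script import relax as R


    @tvm.script.ir_module
    class Before:
        @R.function
        def main(x: R.Tensor(("m", "n"), "int64")):
            # Implemented through FCallPacked rather than FLegalize.
            y = R.unique(x, sorted=True)
            return y


    # The new third branch in LegalizeMutator rewrites the call into
    # Call(ExternFunc("relax.run.unique"), args, ...), keeping the struct info.
    After = relax.transform.LegalizeOps()(Before)
    After.show()
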
From b958943877701c6e11cb72d3688fdb11a9b31497 Mon Sep 17 00:00:00 2001
From: Eric Lunderberg
Date: Thu, 5 Sep 2024 12:10:57 -0500
Subject: [PATCH 2/2] Fixup merge conflicts

---
 src/relax/backend/vm/codegen_vm_tir.cc | 9 +--------
 1 file changed, 1 insertion(+), 8 deletions(-)

diff --git a/src/relax/backend/vm/codegen_vm_tir.cc b/src/relax/backend/vm/codegen_vm_tir.cc
index c29e416395d9..a92cf7c749a0 100644
--- a/src/relax/backend/vm/codegen_vm_tir.cc
+++ b/src/relax/backend/vm/codegen_vm_tir.cc
@@ -232,14 +232,7 @@ class CodeGenVMTIR : public ExprFunctor<Optional<PrimExpr>(const Expr&)> {
     }
     int64_t dst_reg = HasVoidStructInfo(call) ? -1 : NewRegister();
     if (call->op.as<OpNode>()) {
-      // special case generate for the intrinsics whose attribute fields
-      // cannot be represented by args in the CallNode
-      FCallPacked name = GetPackedFuncName(call);
-      if (name.size()) {
-        // If the operator has a registered packed function implementation, emit call to that packed
-        // function.
-        EmitCallPacked(name, VisitArray(call->args), dst_reg);
-      } else if (call_node->op == call_builtin_with_ctx_op_) {
+      if (call_node->op == call_builtin_with_ctx_op_) {
         EmitCallBuiltinWithCtx(call, dst_reg);
       } else if (call_node->op == alloc_storage_op_) {
         EmitAllocStorage(call, dst_reg);
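
[Illustrative note, not part of the patch] With both commits applied, an
operator implemented only through `FCallPacked` should compile under either
VM backend; "bytecode" exercises CodegenVM and "compiled" exercises
CodegenVMTIR, mirroring the `exec_mode` test parameter added above.  A minimal
end-to-end sketch, under the same `relax.unique` assumption as the earlier note:

    # Sketch only: before this series, exec_mode="compiled" raised an error here.
    import numpy as np
    import tvm
    from tvm import relax
    from tvm.script import relax as R


    @tvm.script.ir_module
    class Module:
        @R.function
        def main(x: R.Tensor(("m", "n"), "int64")):
            y = R.unique(x, sorted=True)
            return y


    for exec_mode in ["bytecode", "compiled"]:
        # relax.build applies LegalizeOps in its default pipeline, so the
        # FCallPacked rewrite happens before either code generator runs.
        ex = relax.build(Module, target="llvm", exec_mode=exec_mode)
        vm = relax.VirtualMachine(ex, tvm.cpu())
        data = tvm.nd.array(np.random.randint(0, 4, (4, 4)).astype("int64"))
        print(exec_mode, vm["main"](data).numpy())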