From 8c86ee8531103a97cc82d36dd0ca2a60bf6b0586 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Sat, 15 Jun 2024 09:23:13 -0500 Subject: [PATCH] [TIR][RPC] Allow RPC calls to compiled PrimFuncs with no arguments The `PackedFunc` interface has arguments `int num_args` and `TVMValue* args`, which contain the number of arguments and a pointer to the array of arguments. Prior to this commit, when implementing the `PackedFunc` interface for TIR `PrimFunc`s, the `MakePackedAPI` pass would always assert that the `args` pointer was not null. However, the `args` pointer is allowed to be null if `num_args` is zero. For example, this occurs when calling an RPC function with no arguments. This commit updates the `MakePackedAPI` transform to only assert that `args` is non-null when `num_args` is greater than zero. --- src/tir/transforms/make_packed_api.cc | 10 ++-- tests/python/runtime/test_runtime_rpc.py | 55 ++++++++++++++++++- .../test_tir_transform_make_packed_api.py | 41 ++++++++++++++ 3 files changed, 99 insertions(+), 7 deletions(-) diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index bf1f3a9e7fd2..d327cdfa8393 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -296,10 +296,12 @@ PrimFunc MakePackedAPI(PrimFunc func) { return error_message.str(); }())); - seq_init.push_back( - MakeAssertNotNull(v_packed_args, name_hint + ": TVMValue* arg pointer was NULL")); - seq_init.push_back( - MakeAssertNotNull(buf_packed_arg_type_ids->data, name_hint + ": int* type_codes was NULL")); + if (num_args > 0) { + seq_init.push_back( + MakeAssertNotNull(v_packed_args, name_hint + ": TVMValue* arg pointer was NULL")); + seq_init.push_back( + MakeAssertNotNull(buf_packed_arg_type_ids->data, name_hint + ": int* type_codes was NULL")); + } seq_init.emplace_back(DeclBuffer(buf_packed_arg_type_ids, nop)); diff --git a/tests/python/runtime/test_runtime_rpc.py b/tests/python/runtime/test_runtime_rpc.py index 4963124b6224..fbdc33928b6e 100644 --- a/tests/python/runtime/test_runtime_rpc.py +++ b/tests/python/runtime/test_runtime_rpc.py @@ -14,22 +14,27 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import tvm -from tvm import te -import tvm.testing + import multiprocessing import os import stat import sys +import tempfile import time import pytest import numpy as np + +import tvm +import tvm.testing + +from tvm import te from tvm import rpc from tvm.relay.backend import Runtime from tvm.contrib import utils, cc from tvm.rpc.tracker import Tracker from tvm.rpc.proxy import Proxy +from tvm.script import ir as I, tir as T if __name__ == "__main__": @@ -685,3 +690,47 @@ def test_rpc_session_timeout_error(with_proxy): if with_proxy: proxy.terminate() tracker.terminate() + + +@pytest.mark.parametrize("call_with_unused_argument", [True, False]) +def test_compiled_function_with_zero_arguments(call_with_unused_argument): + """RPC functions do not require an argument + + This is a regression test. When no arguments are provided, RPC + provides NULL as the `TVMValue* args` argument to a PackedFunc. + However, previous implementations of `MakePackedAPI` + unconditionally asserted that the `args` pointer was non-null. + This assertion is now generated only when the function accepts + a non-zero number of arguments. + + """ + + @I.ir_module + class Module: + @T.prim_func + def func_without_arg() -> T.int64: + return T.int64(42) + + @T.prim_func + def func_with_arg(unused: T.int64) -> T.int64: + return T.int64(42) + + built = tvm.build(Module, target="llvm") + + server = tvm.rpc.Server(key="x1") + client = tvm.rpc.connect("127.0.0.1", server.port, key="x1") + + libname = "libbuilt.so" + with tempfile.TemporaryDirectory(prefix="tvm_rpc_testing_") as temp_dir: + local_path = os.path.join(temp_dir, libname) + built.export_library(local_path) + client.upload(local_path) + + remote_mod = client.load_module(libname) + + if call_with_unused_argument: + res = remote_mod["func_with_arg"](0) + else: + res = remote_mod["func_without_arg"]() + + assert res == 42 diff --git a/tests/python/tir-transform/test_tir_transform_make_packed_api.py b/tests/python/tir-transform/test_tir_transform_make_packed_api.py index bf182654d750..23a51a0817df 100644 --- a/tests/python/tir-transform/test_tir_transform_make_packed_api.py +++ b/tests/python/tir-transform/test_tir_transform_make_packed_api.py @@ -353,5 +353,46 @@ def func(A: T.Buffer([16, 16], "int32"), B: T.Buffer([16, 16], "int32")): built(A, B) +def test_zero_arg_function(): + """Only check non-null args when num_args>0""" + + @I.ir_module + class Before: + @T.prim_func + def func_without_arg() -> T.int64: + T.func_attr({"target": T.target("llvm", host="llvm")}) + return T.int64(42) + + @I.ir_module + class Expected: + @T.prim_func + def func_without_arg( + args: T.handle, + arg_type_ids: T.handle("int32"), + num_args: T.int32, + out_ret_value: T.handle("void"), + out_ret_tcode: T.handle("int32"), + resource_handle: T.handle, + ) -> T.int32: + T.func_attr( + { + "calling_conv": 1, + "target": T.target("llvm"), + } + ) + assert num_args == 0, "func_without_arg: num_args should be 0" + arg_type_ids_1 = T.decl_buffer((0,), "int32", data=arg_type_ids) + with T.attr(0, "compute_scope", "func_without_arg_compute_"): + out_ret_value_1 = T.Buffer((1,), "int64", data=out_ret_value, strides=(1,)) + out_ret_value_1[0] = T.Cast("int64", T.int64(42)) + out_ret_tcode_1 = T.Buffer((1,), "int32", data=out_ret_tcode, strides=(1,)) + out_ret_tcode_1[0] = 0 + return 0 + return 0 + + After = tvm.tir.transform.MakePackedAPI()(Before) + tvm.ir.assert_structural_equal(Expected, After) + + if __name__ == "__main__": tvm.testing.main()