diff --git a/python/tvm/contrib/ndk.py b/python/tvm/contrib/ndk.py index 2a1105ed2bbb..14820c0ca8ab 100644 --- a/python/tvm/contrib/ndk.py +++ b/python/tvm/contrib/ndk.py @@ -22,7 +22,10 @@ import os import shutil from typing import Dict +import tempfile +from pathlib import Path +from .._ffi import register_func from .._ffi.base import py_str from . import utils as _utils, tar as _tar, cc as _cc from .cc import get_target_by_dump_machine @@ -152,3 +155,12 @@ def get_global_symbol_section_map(path, *, nm=None) -> Dict[str, str]: base_path = os.path.dirname(compiler) nm = os.path.join(base_path, "llvm-nm") return _cc.get_global_symbol_section_map(path, nm=nm) + + +@register_func("meta_schedule.builder.export_ndk") +def _ndk_export(mod): + tmp_dir = tempfile.mkdtemp() + binary_name = "tmp_binary.so" + binary_path = Path(tmp_dir) / binary_name + mod.export_library(binary_path, fcompile=create_shared) + return str(binary_path) diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h index 60840ca1634e..ceb0356cbcfe 100644 --- a/src/meta_schedule/utils.h +++ b/src/meta_schedule/utils.h @@ -513,7 +513,8 @@ inline void CloneRules(const SpaceGeneratorNode* src, SpaceGeneratorNode* dst) { /*! \brief Returns true if the given target is one of the supported gpu targets. */ inline bool IsGPUTarget(const std::string& target_name) { - static const std::unordered_set gpu_targets{"cuda", "rocm", "vulkan", "metal"}; + static const std::unordered_set gpu_targets{"cuda", "rocm", "vulkan", "metal", + "opencl"}; return gpu_targets.count(target_name); } diff --git a/tests/python/contrib/test_android/__init__.py b/tests/python/contrib/test_android/__init__.py new file mode 100644 index 000000000000..9669578bb7ad --- /dev/null +++ b/tests/python/contrib/test_android/__init__.py @@ -0,0 +1,18 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" Testing infrastructure for Android """ diff --git a/tests/python/contrib/test_android/infrastructure.py b/tests/python/contrib/test_android/infrastructure.py new file mode 100644 index 000000000000..b78d0bb40e21 --- /dev/null +++ b/tests/python/contrib/test_android/infrastructure.py @@ -0,0 +1,57 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name + +""" Android testing infrastructure """ + +import os +import tvm +from tvm.meta_schedule.runner import RPCRunner, RPCConfig, EvaluatorConfig + + +def get_rpc_runner() -> tvm.meta_schedule.runner.RPCRunner: + if ( + "TVM_TRACKER_HOST" in os.environ + and "TVM_TRACKER_PORT" in os.environ + and "RPC_DEVICE_KEY" in os.environ + ): + rpc_host = os.environ["TVM_TRACKER_HOST"] + rpc_port = int(os.environ["TVM_TRACKER_PORT"]) + rpc_key = os.environ["RPC_DEVICE_KEY"] + else: + raise Exception("Please initialize environment variables for using RPC tracker") + + rpc_config = RPCConfig( + tracker_host=rpc_host, + tracker_port=rpc_port, + tracker_key=rpc_key, + session_priority=1, + session_timeout_sec=100, + ) + evaluator_config = EvaluatorConfig( + number=1, + repeat=1, + min_repeat_ms=0, + ) + return RPCRunner(rpc_config, evaluator_config) + + +def get_android_gpu_target() -> tvm.target.Target: + """Creates a Android GPU target""" + target_c = "opencl" + target_h = "llvm -mtriple=arm64-linux-android" + return tvm.target.Target(target_c, host=target_h) diff --git a/tests/python/contrib/test_android/test_meta_schedule.py b/tests/python/contrib/test_android/test_meta_schedule.py new file mode 100644 index 000000000000..eac5fab30357 --- /dev/null +++ b/tests/python/contrib/test_android/test_meta_schedule.py @@ -0,0 +1,71 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" Test rpc based launcher for Android """ +import tempfile + +import numpy as np +import pytest +import tvm.testing +import tvm.topi.testing +from tvm import meta_schedule as ms +from tvm.meta_schedule.builder import LocalBuilder +from tvm.script import tir as T + +from .infrastructure import get_android_gpu_target, get_rpc_runner + + +@T.prim_func +def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [128, 128]) + B = T.match_buffer(b, [128, 128]) + C = T.match_buffer(c, [128, 128]) + for i, j, k in T.grid(128, 128, 128): + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + + +@pytest.mark.skip("Integration test") +def test_tune_tir_on_android(): + """Test tune_tir on Android through RPC.""" + max_workers = 4 + builder = LocalBuilder(f_export="meta_schedule.builder.export_ndk", max_workers=max_workers) + runner = get_rpc_runner() + target = get_android_gpu_target() + with tempfile.TemporaryDirectory() as work_dir: + database = ms.tir_integration.tune_tir( + mod=matmul, + target=target, + work_dir=work_dir, + max_trials_global=32, + num_trials_per_iter=16, + builder=builder, + runner=runner, + ) + sch = ms.tir_integration.compile_tir(database, matmul, target) + if sch is None: + print("No valid schedule found!") + else: + sch.mod.show() + sch.trace.show() + + +if __name__ == "__main__": + tvm.testing.main()