Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions python/tvm/contrib/ndk.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@
import os
import shutil
from typing import Dict
import tempfile
from pathlib import Path

from .._ffi import register_func
from .._ffi.base import py_str
from . import utils as _utils, tar as _tar, cc as _cc
from .cc import get_target_by_dump_machine
Expand Down Expand Up @@ -152,3 +155,12 @@ def get_global_symbol_section_map(path, *, nm=None) -> Dict[str, str]:
base_path = os.path.dirname(compiler)
nm = os.path.join(base_path, "llvm-nm")
return _cc.get_global_symbol_section_map(path, nm=nm)


@register_func("meta_schedule.builder.export_ndk")
def _ndk_export(mod):
tmp_dir = tempfile.mkdtemp()
binary_name = "tmp_binary.so"
binary_path = Path(tmp_dir) / binary_name
mod.export_library(binary_path, fcompile=create_shared)
return str(binary_path)
3 changes: 2 additions & 1 deletion src/meta_schedule/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -513,7 +513,8 @@ inline void CloneRules(const SpaceGeneratorNode* src, SpaceGeneratorNode* dst) {

/*! \brief Returns true if the given target is one of the supported gpu targets. */
inline bool IsGPUTarget(const std::string& target_name) {
static const std::unordered_set<std::string> gpu_targets{"cuda", "rocm", "vulkan", "metal"};
static const std::unordered_set<std::string> gpu_targets{"cuda", "rocm", "vulkan", "metal",
"opencl"};
return gpu_targets.count(target_name);
}

Expand Down
18 changes: 18 additions & 0 deletions tests/python/contrib/test_android/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

""" Testing infrastructure for Android """
57 changes: 57 additions & 0 deletions tests/python/contrib/test_android/infrastructure.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name

""" Android testing infrastructure """

import os
import tvm
from tvm.meta_schedule.runner import RPCRunner, RPCConfig, EvaluatorConfig


def get_rpc_runner() -> tvm.meta_schedule.runner.RPCRunner:
if (
"TVM_TRACKER_HOST" in os.environ
and "TVM_TRACKER_PORT" in os.environ
and "RPC_DEVICE_KEY" in os.environ
):
rpc_host = os.environ["TVM_TRACKER_HOST"]
rpc_port = int(os.environ["TVM_TRACKER_PORT"])
rpc_key = os.environ["RPC_DEVICE_KEY"]
else:
raise Exception("Please initialize environment variables for using RPC tracker")

rpc_config = RPCConfig(
tracker_host=rpc_host,
tracker_port=rpc_port,
tracker_key=rpc_key,
session_priority=1,
session_timeout_sec=100,
)
evaluator_config = EvaluatorConfig(
number=1,
repeat=1,
min_repeat_ms=0,
)
return RPCRunner(rpc_config, evaluator_config)


def get_android_gpu_target() -> tvm.target.Target:
"""Creates a Android GPU target"""
target_c = "opencl"
target_h = "llvm -mtriple=arm64-linux-android"
return tvm.target.Target(target_c, host=target_h)
71 changes: 71 additions & 0 deletions tests/python/contrib/test_android/test_meta_schedule.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

""" Test rpc based launcher for Android """
import tempfile

import numpy as np
import pytest
import tvm.testing
import tvm.topi.testing
from tvm import meta_schedule as ms
from tvm.meta_schedule.builder import LocalBuilder
from tvm.script import tir as T

from .infrastructure import get_android_gpu_target, get_rpc_runner


@T.prim_func
def matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, [128, 128])
B = T.match_buffer(b, [128, 128])
C = T.match_buffer(c, [128, 128])
for i, j, k in T.grid(128, 128, 128):
with T.block("update"):
vi, vj, vk = T.axis.remap("SSR", [i, j, k])
with T.init():
C[vi, vj] = 0.0
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]


@pytest.mark.skip("Integration test")
def test_tune_tir_on_android():
"""Test tune_tir on Android through RPC."""
max_workers = 4
builder = LocalBuilder(f_export="meta_schedule.builder.export_ndk", max_workers=max_workers)
runner = get_rpc_runner()
target = get_android_gpu_target()
with tempfile.TemporaryDirectory() as work_dir:
database = ms.tir_integration.tune_tir(
mod=matmul,
target=target,
work_dir=work_dir,
max_trials_global=32,
num_trials_per_iter=16,
builder=builder,
runner=runner,
)
sch = ms.tir_integration.compile_tir(database, matmul, target)
if sch is None:
print("No valid schedule found!")
else:
sch.mod.show()
sch.trace.show()


if __name__ == "__main__":
tvm.testing.main()