Skip to content

Commit 40765de

Browse files
committed
AOT with LLVM Codegen on Hexagon
1 parent a945586 commit 40765de

10 files changed

Lines changed: 117 additions & 37 deletions

File tree

python/tvm/contrib/hexagon/build.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ def upload(self, local_path: Union[str, pathlib.Path], remote_filename: str):
182182
assert self._workspace
183183
self._copy_to_remote(local_path, os.path.join(str(self._workspace), remote_filename))
184184

185-
def start_session(self) -> Session:
185+
def start_session(self, name="hexagon-rpc") -> Session:
186186
"""Connect to the RPC server.
187187
188188
Returns
@@ -197,7 +197,7 @@ def start_session(self) -> Session:
197197
"timeout": 0,
198198
"key": self._device_key,
199199
}
200-
return Session(self, hexagon_remote_kw)
200+
return Session(self, hexagon_remote_kw, session_name=name)
201201

202202
def load_module(self, module: Union[str, pathlib.Path, tvm.runtime.Module], session: Session):
203203
"""Load TVM module.

python/tvm/contrib/hexagon/session.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,12 @@ def __enter__(self):
8686
self._rpc_receive_buffer_size_bytes,
8787
],
8888
)
89-
self.device = self._rpc.hexagon(0)
89+
if self._session_name == "cpu-rpc":
90+
self.device = self._rpc.cpu(0)
91+
elif self._session_name == "hexagon-rpc":
92+
self.device = self._rpc.hexagon(0)
93+
else:
94+
raise RuntimeError("Incorrect session name: %s", self._session_name)
9095
return self
9196

9297
except RuntimeError as exception:
@@ -286,6 +291,11 @@ def _aot_executor_from_factory(
286291
for target in module.target.values()
287292
if "hexagon" in target.keys
288293
)
294+
assert len(module.target.values()) == 1
295+
296+
for target in module.target.values():
297+
target_kind = str(target).split()[0]
298+
289299
assert hexagon_arch, "No hexagon target architecture found"
290300
assert len(hexagon_arch) == 1, f"Inconsistent hexagon architecture found, {hexagon_arch}"
291301
hexagon_arch = hexagon_arch.pop()
@@ -295,11 +305,19 @@ def _aot_executor_from_factory(
295305
binary_name = "test_binary.so"
296306
binary_path = temp_dir / binary_name
297307

298-
module.export_library(
299-
str(binary_path),
300-
fcompile=hexagon.create_aot_shared,
301-
hexagon_arch=hexagon_arch,
302-
)
308+
if target_kind == "hexagon":
309+
module.export_library(
310+
str(binary_path),
311+
fcompile=hexagon.create_aot_shared,
312+
hexagon_arch=hexagon_arch,
313+
)
314+
elif target_kind == "llvm":
315+
module.export_library(
316+
str(binary_path),
317+
cc=hexagon.hexagon_clang_plus(),
318+
)
319+
else:
320+
raise ValueError("Incorrect Target kind.")
303321

304322
self.upload(binary_path, binary_name)
305323

src/relay/backend/aot_executor_codegen.cc

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1234,12 +1234,10 @@ class AOTExecutorCodegenModule : public runtime::ModuleNode {
12341234
Target target_host;
12351235
for (const auto& it : tmp) {
12361236
auto dev_type = it.first.as<tir::IntImmNode>();
1237-
if (!target_host.defined() && it.second->kind->device_type == kDLCPU) {
1237+
if (!target_host.defined() && ((it.second->kind->device_type == kDLCPU) ||
1238+
(it.second->kind->device_type == kDLHexagon))) {
12381239
target_host = it.second;
12391240
}
1240-
if (!target_host.defined() && it.second->kind->device_type == kDLHexagon) {
1241-
target_host = *(new Target("c"));
1242-
}
12431241
ICHECK(dev_type);
12441242
targets[static_cast<DLDeviceType>(dev_type->value)] = it.second;
12451243
}

src/runtime/hexagon/hexagon/hexagon_device_api_v2.cc

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,12 +107,16 @@ struct HexagonWorkspacePool : public WorkspacePool {
107107
};
108108

109109
void* HexagonDeviceAPIv2::AllocWorkspace(Device dev, size_t size, DLDataType type_hint) {
110-
CHECK(TVMDeviceExtType(dev.device_type) == kDLHexagon) << "dev.device_type: " << dev.device_type;
110+
bool is_valid_device = (TVMDeviceExtType(dev.device_type) == kDLHexagon) ||
111+
(DLDeviceType(dev.device_type) == kDLCPU);
112+
CHECK(is_valid_device) << "dev.device_type: " << dev.device_type;
111113
return dmlc::ThreadLocalStore<HexagonWorkspacePool>::Get()->AllocWorkspace(dev, size);
112114
}
113115

114116
void HexagonDeviceAPIv2::FreeWorkspace(Device dev, void* data) {
115-
CHECK(TVMDeviceExtType(dev.device_type) == kDLHexagon) << "dev.device_type: " << dev.device_type;
117+
bool is_valid_device = (TVMDeviceExtType(dev.device_type) == kDLHexagon) ||
118+
(DLDeviceType(dev.device_type) == kDLCPU);
119+
CHECK(is_valid_device) << "dev.device_type: " << dev.device_type;
116120
CHECK(hexagon_buffer_map_.count(data) != 0)
117121
<< "Attempt made to free unknown or already freed workspace allocation";
118122
dmlc::ThreadLocalStore<HexagonWorkspacePool>::Get()->FreeWorkspace(dev, data);

src/runtime/hexagon/rpc/hexagon/rpc_server.cc

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -58,11 +58,10 @@ class HexagonIOHandler {
5858
read_buffer_size_bytes_{read_buffer_size_bytes},
5959
write_buffer_available_length_{0} {}
6060

61-
void MessageStart(size_t message_size_bytes) {}
61+
void MessageStart(size_t message_size_bytes) { LOG(INFO) << "MessageStart called."; }
6262

6363
ssize_t PosixWrite(const uint8_t* buf, size_t write_len_bytes) {
64-
LOG(INFO) << "INFO: HexagonIOHandler PosixWrite called, write_len_bytes(" << write_len_bytes
65-
<< ")";
64+
LOG(INFO) << "HexagonIOHandler PosixWrite called, write_len_bytes(" << write_len_bytes << ")";
6665
int32_t written_size = write_buffer_.sputn(reinterpret_cast<const char*>(buf), write_len_bytes);
6766
if (written_size != write_len_bytes) {
6867
LOG(ERROR) << "written_size(" << written_size << ") != write_len_bytes(" << write_len_bytes
@@ -72,10 +71,10 @@ class HexagonIOHandler {
7271
return (ssize_t)written_size;
7372
}
7473

75-
void MessageDone() { LOG(INFO) << "INFO: Message Done."; }
74+
void MessageDone() { LOG(INFO) << "Message Done."; }
7675

7776
ssize_t PosixRead(uint8_t* buf, size_t read_len_bytes) {
78-
LOG(INFO) << "INFO: HexagonIOHandler PosixRead called, read_len_bytes(" << read_len_bytes
77+
LOG(INFO) << "HexagonIOHandler PosixRead called, read_len_bytes(" << read_len_bytes
7978
<< "), read_buffer_index_(" << read_buffer_index_ << ")";
8079

8180
uint32_t bytes_to_read = 0;
@@ -99,7 +98,7 @@ class HexagonIOHandler {
9998
* \return The status
10099
*/
101100
AEEResult SetReadBuffer(const uint8_t* data, size_t data_size_bytes) {
102-
LOG(INFO) << "INFO: HexagonIOHandler SetReadBuffer: data_size_bytes(" << data_size_bytes
101+
LOG(INFO) << "HexagonIOHandler SetReadBuffer: data_size_bytes(" << data_size_bytes
103102
<< "), read_buffer_index_(" << read_buffer_index_ << "), read_buffer_size_bytes_("
104103
<< read_buffer_size_bytes_ << ")";
105104
if (data_size_bytes > read_buffer_size_bytes_) {
@@ -121,7 +120,7 @@ class HexagonIOHandler {
121120
* \return The size of data that is read in bytes.
122121
*/
123122
int64_t ReadFromWriteBuffer(uint8_t* buf, size_t read_size_bytes) {
124-
LOG(INFO) << "INFO: HexagonIOHandler ReadFromWriteBuffer called, read_size_bytes: "
123+
LOG(INFO) << "HexagonIOHandler ReadFromWriteBuffer called, read_size_bytes: "
125124
<< read_size_bytes;
126125
int64_t size = (int64_t)write_buffer_.sgetn(reinterpret_cast<char*>(buf), read_size_bytes);
127126
write_buffer_available_length_ -= size;
@@ -133,7 +132,7 @@ class HexagonIOHandler {
133132
return size;
134133
}
135134

136-
void Close() { LOG(INFO) << "INFO: HexagonIOHandler Close called"; }
135+
void Close() { LOG(INFO) << "HexagonIOHandler Close called"; }
137136

138137
void Exit(int code) { exit(code); }
139138

@@ -156,13 +155,19 @@ class HexagonRPCServer {
156155
* \param data The data pointer
157156
* \param data_size_bytes The data size in bytes.
158157
*
159-
* \return The size of data written to IOHandler.
158+
* \return The size of data written to IOHandler if no error.
159+
* Otherwise, returns -1;
160160
*/
161161
int64_t Write(const uint8_t* data, size_t data_size_bytes) {
162162
if (io_.SetReadBuffer(data, data_size_bytes) != AEE_SUCCESS) {
163+
LOG(ERROR) << "ERROR: SetReadBuffer failed";
164+
return -1;
165+
}
166+
167+
if (!rpc_server_.ProcessOnePacket()) {
168+
LOG(ERROR) << "ERROR: ProcessOnePacket failed";
163169
return -1;
164170
}
165-
rpc_server_.ProcessOnePacket();
166171
return (int64_t)data_size_bytes;
167172
}
168173

@@ -211,6 +216,7 @@ const tvm::runtime::PackedFunc get_runtime_func(const std::string& name) {
211216
void reset_device_api() {
212217
const tvm::runtime::PackedFunc api = get_runtime_func("device_api.hexagon.v2");
213218
tvm::runtime::Registry::Register("device_api.hexagon", true).set_body(api);
219+
tvm::runtime::Registry::Register("device_api.cpu", true).set_body(api);
214220
}
215221

216222
int __QAIC_HEADER(hexagon_rpc_open)(const char* uri, remote_handle64* handle) {

src/runtime/hexagon/rpc/simulator/rpc_server.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,7 @@ int main() {
292292
const auto* api_v2 = tvm::runtime::Registry::Get("device_api.hexagon.v2");
293293
ICHECK(api_v2 != nullptr);
294294
tvm::runtime::Registry::Register("device_api.hexagon", true).set_body(*api_v2);
295+
tvm::runtime::Registry::Register("device_api.cpu", true).set_body(*api_v2);
295296

296297
tvm::runtime::hexagon::SimulatorRPCServer server;
297298

src/runtime/hexagon/rpc/simulator/session.cc

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,11 @@ class SimulatorRPCChannel final : public RPCChannel {
214214
std::string runmain; // Path to run_main_on_hexagon.
215215
};
216216

217+
struct Message_ {
218+
Message msg;
219+
std::string str() const;
220+
};
221+
217222
Message SendMsg(Message msg);
218223
Message SendMsg(uint32_t code, uint32_t len, uint32_t va);
219224
void ReadFromProcess(void* host_dst, HEX_VA_t src, size_t len);
@@ -461,6 +466,27 @@ std::string SimulatorRPCChannel::Cpu_::str() const {
461466
return default_cpu_;
462467
}
463468

469+
std::string SimulatorRPCChannel::Message_::str() const {
470+
switch (msg.code) {
471+
case Message::kNone:
472+
return "kNone";
473+
case Message::kAck:
474+
return "kAck";
475+
case Message::kTerminate:
476+
return "kTerminate";
477+
case Message::kReceiveStart:
478+
return "kReceiveStart";
479+
case Message::kReceiveEnd:
480+
return "kReceiveEnd";
481+
case Message::kSendStart:
482+
return "kSendStart";
483+
case Message::kSendEnd:
484+
return "kSendEnd";
485+
default:
486+
break;
487+
}
488+
}
489+
464490
SimulatorRPCChannel::SDKInfo_::SDKInfo_(const std::string& sdk_root, const std::string& cpu)
465491
: root(sdk_root) {
466492
// For v69 chips, still look for v68 in the directory names.
@@ -524,6 +550,7 @@ SimulatorRPCChannel::SimulatorRPCChannel(int stack_size, std::string args) {
524550
const auto* api_v2 = tvm::runtime::Registry::Get("device_api.hexagon.v2");
525551
ICHECK(api_v2 != nullptr);
526552
tvm::runtime::Registry::Register("device_api.hexagon", true).set_body(*api_v2);
553+
tvm::runtime::Registry::Register("device_api.cpu", true).set_body(*api_v2);
527554

528555
const char* sdk_root_env = std::getenv("HEXAGON_SDK_ROOT");
529556
ICHECK(sdk_root_env != nullptr) << "Please set HEXAGON_SDK_ROOT";
@@ -651,9 +678,14 @@ Message SimulatorRPCChannel::SendMsg(Message msg) {
651678
HEX_4u_t result;
652679

653680
core = sim_->Run(&result);
654-
ICHECK_EQ(core, HEX_CORE_BREAKPOINT);
681+
Core_ core_ = {core};
682+
ICHECK_EQ(core, HEX_CORE_BREAKPOINT)
683+
<< "Expecting HEX_CORE_BREAKPOINT, received: " << core_.str();
655684
};
656685

686+
Message_ msg_ = {msg};
687+
LOG(INFO) << "Sending message: " << msg_.str();
688+
657689
WriteToProcess(message_buffer_v_, &msg, sizeof msg);
658690
run();
659691

src/runtime/library_module.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,8 +115,8 @@ Module LoadModuleFromBinary(const std::string& type_key, dmlc::Stream* stream) {
115115
loaders += name.substr(loadkey.size());
116116
}
117117
}
118-
LOG(FATAL) << "Binary was created using " << type_key
119-
<< " but a loader of that name is not registered. Available loaders are " << loaders
118+
LOG(FATAL) << "Binary was created using {" << type_key
119+
<< "} but a loader of that name is not registered. Available loaders are " << loaders
120120
<< ". Perhaps you need to recompile with this runtime enabled.";
121121
}
122122

src/target/llvm/codegen_hexagon.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -475,6 +475,12 @@ runtime::Module BuildHexagon(IRModule mod, Target target) {
475475

476476
TVM_REGISTER_GLOBAL("target.build.hexagon").set_body_typed(BuildHexagon);
477477

478+
TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_hexagon")
479+
.set_body([](const TVMArgs& targs, TVMRetValue* rv) {
480+
CodeGenLLVM* cg = new CodeGenHexagon();
481+
*rv = static_cast<void*>(cg);
482+
});
483+
478484
} // namespace codegen
479485
} // namespace tvm
480486

tests/python/contrib/test_hexagon/test_launcher.py

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,22 +16,23 @@
1616
# under the License.
1717

1818
import os
19-
import pathlib
2019
import sys
2120
import pytest
2221
import numpy as np
23-
import logging
2422

2523
import tvm.testing
2624
from tvm import te
2725
from tvm import relay
2826
from tvm.relay.backend import Executor, Runtime
29-
from tvm.contrib import utils, ndk
30-
from tvm.contrib.hexagon.build import HexagonLauncher
3127
import tvm.contrib.hexagon as hexagon
3228

3329
from .conftest import requires_hexagon_toolchain
3430

31+
aot_target_kind = tvm.testing.parameter(
32+
"c",
33+
"llvm -keys=hexagon -link-params=0 -mattr=+hvxv68,+hvx-length128b,+hvx-qfloat,-hvx-ieee-fp -mcpu=hexagonv68 -mtriple=hexagon",
34+
)
35+
3536

3637
@requires_hexagon_toolchain
3738
def test_add(hexagon_session):
@@ -270,8 +271,18 @@ def _workaround_create_aot_shared():
270271
)
271272

272273

274+
def get_target_and_session(target_kind: str):
275+
if target_kind == "c":
276+
target_hexagon = tvm.target.hexagon("v68")
277+
session_key = "hexagon-rpc"
278+
elif target_kind.startswith("llvm"):
279+
target_hexagon = target_kind
280+
session_key = "cpu-rpc"
281+
return target_hexagon, session_key
282+
283+
273284
@requires_hexagon_toolchain
274-
def test_aot_executor(hexagon_session):
285+
def test_aot_executor(hexagon_launcher, aot_target_kind):
275286
dtype = "float32"
276287
input_shape = (1, 128, 128, 3)
277288
w_shape = (5, 5, 3, 8)
@@ -290,7 +301,7 @@ def test_aot_executor(hexagon_session):
290301
relay_mod = tvm.IRModule.from_expr(f)
291302
relay_mod = relay.transform.InferType()(relay_mod)
292303

293-
target_hexagon = tvm.target.hexagon("v68")
304+
target_hexagon, session_key = get_target_and_session(aot_target_kind)
294305

295306
weight_data = np.random.rand(w_shape[0], w_shape[1], w_shape[2], w_shape[3]).astype(dtype=dtype)
296307
input_data = np.random.rand(
@@ -304,11 +315,13 @@ def test_aot_executor(hexagon_session):
304315
lowered = tvm.relay.build(
305316
relay_mod,
306317
params=params,
307-
target=tvm.target.Target(target_hexagon, host="c"),
318+
target=tvm.target.Target(target_hexagon, host=aot_target_kind),
308319
runtime=Runtime("cpp"),
309320
executor=Executor("aot", {"unpacked-api": False, "interface-api": "packed"}),
310321
)
311322

323+
hexagon_session = hexagon_launcher.start_session(name=session_key)
324+
hexagon_session.__enter__()
312325
aot_mod = hexagon_session.get_executor_from_factory(lowered)
313326
aot_mod.set_input(**inputs)
314327
aot_mod.run()
@@ -332,7 +345,7 @@ def test_aot_executor(hexagon_session):
332345

333346

334347
@requires_hexagon_toolchain
335-
def test_aot_executor_multiple_conv2d(hexagon_session):
348+
def test_aot_executor_multiple_conv2d(hexagon_launcher, aot_target_kind):
336349
dtype = "float32"
337350
input_shape = (1, 8, 8, 3)
338351
w1_shape = (5, 5, 3, 1)
@@ -362,7 +375,7 @@ def test_aot_executor_multiple_conv2d(hexagon_session):
362375
relay_mod = tvm.IRModule.from_expr(f)
363376
relay_mod = relay.transform.InferType()(relay_mod)
364377

365-
target_hexagon = tvm.target.hexagon("v68")
378+
target_hexagon, session_key = get_target_and_session(aot_target_kind)
366379

367380
weight1_data = np.random.rand(w1_shape[0], w1_shape[1], w1_shape[2], w1_shape[3]).astype(
368381
dtype=dtype
@@ -381,11 +394,13 @@ def test_aot_executor_multiple_conv2d(hexagon_session):
381394
lowered = tvm.relay.build(
382395
relay_mod,
383396
params=params,
384-
target=tvm.target.Target(target_hexagon, host="c"),
397+
target=tvm.target.Target(target_hexagon, host=aot_target_kind),
385398
runtime=Runtime("cpp"),
386399
executor=Executor("aot", {"unpacked-api": False, "interface-api": "packed"}),
387400
)
388401

402+
hexagon_session = hexagon_launcher.start_session(name=session_key)
403+
hexagon_session.__enter__()
389404
aot_mod = hexagon_session.get_executor_from_factory(lowered)
390405
aot_mod.set_input(**inputs)
391406
aot_mod.run()

0 commit comments

Comments
 (0)