diff --git a/include/tvm/runtime/module.h b/include/tvm/runtime/module.h index 60e35353194c..cb88bc53f9a6 100644 --- a/include/tvm/runtime/module.h +++ b/include/tvm/runtime/module.h @@ -232,6 +232,11 @@ class TVM_DLL ModuleNode : public Object { return (GetPropertyMask() & ModulePropertyMask::kDSOExportable) != 0; } + /*! \brief Returns true if this module is 'Binary Serializable'. */ + bool IsBinarySerializable() const { + return (GetPropertyMask() & ModulePropertyMask::kBinarySerializable) != 0; + } + /*! * \brief Returns true if this module has a definition for a function of \p name. If * \p query_imports is true, also search in any imported modules. diff --git a/include/tvm/target/codegen.h b/include/tvm/target/codegen.h index 46a19ad71b60..0490e14e22a6 100644 --- a/include/tvm/target/codegen.h +++ b/include/tvm/target/codegen.h @@ -47,6 +47,21 @@ using runtime::TVMRetValue; */ runtime::Module Build(IRModule mod, Target target); +/*! + * \brief Serialize runtime module including its submodules + * \param mod The runtime module to serialize including its import tree. + * \param export_dso By default, include the info of DSOExportable modules. If disabled, an error + * will be raised when encountering DSO modules. + */ +std::string SerializeModuleToBytes(const runtime::Module& mod, bool export_dso = true); + +/*! + * \brief Deserialize runtime module including its submodules + * \param blob byte stream, which are generated by `SerializeModuleToBytes`. + * \return runtime::Module runtime module constructed from the given stream + */ +runtime::Module DeserializeModuleFromBytes(std::string blob); + /*! * \brief Pack imported device library to a C file. * Compile the C file and link with the host library @@ -77,6 +92,7 @@ std::string PackImportsToC(const runtime::Module& m, bool system_lib, runtime::Module PackImportsToLLVM(const runtime::Module& m, bool system_lib, const std::string& target_triple, const std::string& c_symbol_prefix = ""); + } // namespace codegen } // namespace tvm #endif // TVM_TARGET_CODEGEN_H_ diff --git a/python/tvm/contrib/torch/optimize_torch.py b/python/tvm/contrib/torch/optimize_torch.py index cbba590e85dc..dfe35f2aae19 100644 --- a/python/tvm/contrib/torch/optimize_torch.py +++ b/python/tvm/contrib/torch/optimize_torch.py @@ -25,17 +25,19 @@ which is used to optimize the `torch.nn.module` by TVM metaSchedule, and returns a custom TorchScript operator """ -import base64 + import contextlib import tempfile from typing import Optional, Tuple, Union - +import base64 import torch import torch.utils.dlpack import tvm +import tvm._ffi +from tvm._ffi import register_func from tvm import meta_schedule as ms from tvm import relay -from tvm._ffi import get_global_func, register_func +from tvm._ffi import get_global_func from tvm.target import Target @@ -51,14 +53,6 @@ def forward(self, *torch_inputs: Tuple[torch.Tensor]): return ret -@register_func("script_torch.save_to_base64") -def save_to_base64(obj) -> bytes: - with tempfile.NamedTemporaryFile(suffix=".so") as tmpfile: - obj.export_library(tmpfile.name) - with open(tmpfile.name, "rb") as temp_file: - return base64.b64encode(temp_file.read()) - - def optimize_torch( func, example_inputs, @@ -173,3 +167,11 @@ def optimize_torch( save_runtime_mod(executor_factory.module) return GraphExecutorFactoryWrapper(torch.classes.tvm_torch.GraphExecutorFactoryWrapper()) + + +@register_func("export_runtime_module") +def save_to_base64(obj) -> bytes: + with tempfile.NamedTemporaryFile(suffix=".so") as tmpfile: + obj.export_library(tmpfile.name) + with open(tmpfile.name, "rb") as temp_file: + return base64.b64encode(temp_file.read()) diff --git a/python/tvm/runtime/module.py b/python/tvm/runtime/module.py index c9e3eb6add75..2a1db2cbb2ec 100644 --- a/python/tvm/runtime/module.py +++ b/python/tvm/runtime/module.py @@ -23,7 +23,6 @@ from typing import Sequence import numpy as np -import tvm._ffi from tvm._ffi.base import _LIB, check_call, c_str, string_types, _RUNTIME_ONLY from tvm._ffi.libinfo import find_include_path from .packed_func import PackedFunc, PackedFuncHandle, _set_class_module diff --git a/src/contrib/torch/base64.h b/src/contrib/torch/base64.h deleted file mode 100644 index d7dac4b86cc8..000000000000 --- a/src/contrib/torch/base64.h +++ /dev/null @@ -1,75 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file base64.h - * \brief Util functions for converting plain bytes back to plain bytes - */ - -#ifndef TVM_CONTRIB_TORCH_BASE64_H_ -#define TVM_CONTRIB_TORCH_BASE64_H_ - -#include - -#include -#include -#include - -#include "../../support/base64.h" - -namespace tvm { -namespace support { - -inline size_t b64strlen(const std::string b64str) { - ICHECK(b64str.size() % 4 == 0) << "invalid base64 encoding"; - size_t length = b64str.size() / 4 * 3; - if (b64str[b64str.size() - 2] == '=') { - length -= 2; - } else if (b64str[b64str.size() - 1] == '=') { - length -= 1; - } - return length; -} - -inline void b64decode(const std::string b64str, u_char* ret) { - size_t index = 0; - const auto length = b64str.size(); - for (size_t i = 0; i < length; i += 4) { - int8_t ch0 = base64::DecodeTable[(int32_t)b64str[i]]; - int8_t ch1 = base64::DecodeTable[(int32_t)b64str[i + 1]]; - int8_t ch2 = base64::DecodeTable[(int32_t)b64str[i + 2]]; - int8_t ch3 = base64::DecodeTable[(int32_t)b64str[i + 3]]; - u_char st1 = (ch0 << 2) + (ch1 >> 4); - ret[index++] = st1; - if (b64str[i + 2] != '=') { - u_char st2 = ((ch1 & 0b1111) << 4) + (ch2 >> 2); - ret[index++] = st2; - if (b64str[i + 3] != '=') { - u_char st3 = ((ch2 & 0b11) << 6) + ch3; - ret[index++] = st3; - } - } - } - ICHECK(b64strlen(b64str) == index) << "base64 decoding fails"; -} - -} // namespace support -} // namespace tvm - -#endif // TVM_CONTRIB_TORCH_BASE64_H_ diff --git a/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc b/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc index fb570c163feb..c77996cf67b6 100644 --- a/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc +++ b/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc @@ -29,7 +29,7 @@ #include #include "../../../runtime/graph_executor/graph_executor_factory.h" -#include "../base64.h" +#include "../../support/base64.h" #include "runtime_bridge.h" namespace tvm { @@ -46,54 +46,6 @@ struct ThreadLocalStore { } }; -/* - * Encode TVM runtime module to base64 stream - */ -std::string serialize(tvm::runtime::Module module) { - static const runtime::PackedFunc* f_to_str = - runtime::Registry::Get("script_torch.save_to_base64"); - ICHECK(f_to_str) << "IndexError: Cannot find the packed function " - "`script_torch.save_to_base64` in the global registry"; - return (*f_to_str)(module); -} - -struct Deleter { // deleter - explicit Deleter(std::string file_name) { this->file_name = file_name; } - void operator()(FILE* p) const { - fclose(p); - ICHECK(remove(file_name.c_str()) == 0) - << "remove temporary file (" << file_name << ") unsuccessfully"; - } - std::string file_name; -}; - -/* - * Decode TVM runtime module from base64 stream - */ -tvm::runtime::Module deserialize(std::string state) { - auto length = tvm::support::b64strlen(state); - - std::vector bytes(length); // bytes stream - tvm::support::b64decode(state, bytes.data()); - - const std::string name = tmpnam(NULL); - auto file_name = name + ".so"; - std::unique_ptr pFile(fopen(file_name.c_str(), "wb"), Deleter(file_name)); - fwrite(bytes.data(), sizeof(u_char), length, pFile.get()); - fflush(pFile.get()); - - std::string load_f_name = "runtime.module.loadfile_so"; - const PackedFunc* f = runtime::Registry::Get(load_f_name); - ICHECK(f != nullptr) << "Loader for `.so` files is not registered," - << " resolved to (" << load_f_name << ") in the global registry." - << "Ensure that you have loaded the correct runtime code, and" - << "that you are on the correct hardware architecture."; - - tvm::runtime::Module ret = (*f)(file_name, ""); - - return ret; -} - TVM_REGISTER_GLOBAL("tvmtorch.save_runtime_mod").set_body_typed([](tvm::runtime::Module mod) { ThreadLocalStore::ThreadLocal()->mod = mod; }); @@ -242,15 +194,104 @@ size_t tvm_contrib_torch_graph_executor_module_forward(TVMContribTorchRuntimeMod return output_length; } +inline size_t b64strlen(const std::string b64str) { + ICHECK(b64str.size() % 4 == 0) << "invalid base64 encoding"; + size_t length = b64str.size() / 4 * 3; + if (b64str[b64str.size() - 2] == '=') { + length -= 2; + } else if (b64str[b64str.size() - 1] == '=') { + length -= 1; + } + return length; +} + +inline void b64decode(const std::string b64str, uint8_t* ret) { + size_t index = 0; + const auto length = b64str.size(); + for (size_t i = 0; i < length; i += 4) { + int8_t ch0 = base64::DecodeTable[(int32_t)b64str[i]]; + int8_t ch1 = base64::DecodeTable[(int32_t)b64str[i + 1]]; + int8_t ch2 = base64::DecodeTable[(int32_t)b64str[i + 2]]; + int8_t ch3 = base64::DecodeTable[(int32_t)b64str[i + 3]]; + uint8_t st1 = (ch0 << 2) + (ch1 >> 4); + ret[index++] = st1; + if (b64str[i + 2] != '=') { + uint8_t st2 = ((ch1 & 0b1111) << 4) + (ch2 >> 2); + ret[index++] = st2; + if (b64str[i + 3] != '=') { + uint8_t st3 = ((ch2 & 0b11) << 6) + ch3; + ret[index++] = st3; + } + } + } + ICHECK(b64strlen(b64str) == index) << "base64 decoding fails"; +} + +/*! + * \brief Export TVM runtime module to base64 stream including its submodules. + * Note that this targets modules that are binary serializable and DSOExportable. + * \param module The runtime module to export + * \return std::string The content of exported file + */ +std::string ExportModuleToBase64(tvm::runtime::Module module) { + static const tvm::runtime::PackedFunc* f_to_str = + tvm::runtime::Registry::Get("export_runtime_module"); + ICHECK(f_to_str) << "IndexError: Cannot find the packed function " + "`export_runtime_module` in the global registry"; + return (*f_to_str)(module); +} + +struct Deleter { // deleter + explicit Deleter(std::string file_name) { this->file_name = file_name; } + void operator()(FILE* p) const { + fclose(p); + ICHECK(remove(file_name.c_str()) == 0) + << "remove temporary file (" << file_name << ") unsuccessfully"; + } + std::string file_name; +}; + +/*! + * \brief Import TVM runtime module from base64 stream + * Note that this targets modules that are binary serializable and DSOExportable. + * \param base64str base64 stream, which are generated by `ExportModuleToBase64`. + * \return runtime::Module runtime module constructed from the given stream + */ +tvm::runtime::Module ImportModuleFromBase64(std::string base64str) { + auto length = b64strlen(base64str); + + std::vector bytes(length); // bytes stream + b64decode(base64str, bytes.data()); + + auto now = std::chrono::system_clock::now(); + auto in_time_t = std::chrono::system_clock::to_time_t(now); + std::stringstream datetime; + datetime << std::put_time(std::localtime(&in_time_t), "%Y-%m-%d-%X"); + const std::string file_name = "tmp-module-" + datetime.str() + ".so"; + LOG(INFO) << file_name; + std::unique_ptr pFile(fopen(file_name.c_str(), "wb"), Deleter(file_name)); + fwrite(bytes.data(), sizeof(uint8_t), length, pFile.get()); + fflush(pFile.get()); + + std::string load_f_name = "runtime.module.loadfile_so"; + const tvm::runtime::PackedFunc* f = tvm::runtime::Registry::Get(load_f_name); + ICHECK(f != nullptr) << "Loader for `.so` files is not registered," + << " resolved to (" << load_f_name << ") in the global registry." + << "Ensure that you have loaded the correct runtime code, and" + << "that you are on the correct hardware architecture."; + tvm::runtime::Module ret = (*f)(file_name, ""); + return ret; +} + char* tvm_contrib_torch_encode(TVMContribTorchRuntimeModule* runtime_module) { - std::string std = tvm::contrib::serialize(runtime_module->mod); + std::string std = ExportModuleToBase64(runtime_module->mod); char* ret = new char[std.length() + 1]; snprintf(ret, std.length() + 1, "%s", std.c_str()); return ret; } TVMContribTorchRuntimeModule* tvm_contrib_torch_decode(const char* state) { - tvm::runtime::Module ret = tvm::contrib::deserialize(state); + tvm::runtime::Module ret = ImportModuleFromBase64(state); return new TVMContribTorchRuntimeModule(ret); } diff --git a/src/node/structural_hash.cc b/src/node/structural_hash.cc index 6cf796d34447..9643480292bc 100644 --- a/src/node/structural_hash.cc +++ b/src/node/structural_hash.cc @@ -28,6 +28,7 @@ #include #include #include +#include #include #include @@ -360,6 +361,22 @@ struct ADTObjTrait { TVM_REGISTER_REFLECTION_VTABLE(runtime::ADTObj, ADTObjTrait); +struct ModuleNodeTrait { + static constexpr const std::nullptr_t VisitAttrs = nullptr; + static constexpr const std::nullptr_t SHashReduce = nullptr; + static constexpr const std::nullptr_t SEqualReduce = nullptr; +}; + +TVM_REGISTER_REFLECTION_VTABLE(runtime::ModuleNode, ModuleNodeTrait) + .set_creator([](const std::string& blob) { + runtime::Module rtmod = codegen::DeserializeModuleFromBytes(blob); + return RefToObjectPtr::Get(rtmod); + }) + .set_repr_bytes([](const Object* n) -> std::string { + const auto* rtmod = static_cast(n); + return codegen::SerializeModuleToBytes(GetRef(rtmod), /*export_dso*/ false); + }); + void NDArrayHash(const runtime::NDArray::Container* arr, SHashReducer* hash_reduce, bool hash_data) { ICHECK_EQ(arr->dl_tensor.device.device_type, kDLCPU) << "can only compare CPU tensor"; diff --git a/src/runtime/library_module.cc b/src/runtime/library_module.cc index eb5e85beb5d3..a1a86d03886f 100644 --- a/src/runtime/library_module.cc +++ b/src/runtime/library_module.cc @@ -67,15 +67,6 @@ class LibraryModuleNode final : public ModuleNode { PackedFuncWrapper packed_func_wrapper_; }; -/*! - * \brief Helper classes to get into internal of a module. - */ -class ModuleInternal { - public: - // Get mutable reference of imports. - static std::vector* GetImportsAddr(ModuleNode* node) { return &(node->imports_); } -}; - PackedFunc WrapPackedFunc(TVMBackendPackedCFunc faddr, const ObjectPtr& sptr_to_self) { return PackedFunc([faddr, sptr_to_self](TVMArgs args, TVMRetValue* rv) { TVMValue ret_value; diff --git a/src/runtime/library_module.h b/src/runtime/library_module.h index 167e819601fa..d4d32abe2110 100644 --- a/src/runtime/library_module.h +++ b/src/runtime/library_module.h @@ -30,6 +30,7 @@ #include #include +#include namespace tvm { namespace runtime { @@ -78,6 +79,15 @@ PackedFunc WrapPackedFunc(TVMBackendPackedCFunc faddr, const ObjectPtr& */ void InitContextFunctions(std::function fgetsymbol); +/*! + * \brief Helper classes to get into internal of a module. + */ +class ModuleInternal { + public: + // Get mutable reference of imports. + static std::vector* GetImportsAddr(ModuleNode* node) { return &(node->imports_); } +}; + /*! * \brief Type alias for function to wrap a TVMBackendPackedCFunc. * \param The function address imported from a module. diff --git a/src/target/codegen.cc b/src/target/codegen.cc index bbb2c15a647f..f6ebd843d1c0 100644 --- a/src/target/codegen.cc +++ b/src/target/codegen.cc @@ -37,6 +37,9 @@ #include #include +#include "../runtime/library_module.h" +#include "../support/base64.h" + namespace tvm { namespace codegen { @@ -63,13 +66,16 @@ class ModuleSerializer { public: explicit ModuleSerializer(runtime::Module mod) : mod_(mod) { Init(); } - void SerializeModule(dmlc::Stream* stream) { + void SerializeModuleToBytes(dmlc::Stream* stream, bool export_dso) { // Only have one DSO module and it is in the root, then // we will not produce import_tree_. bool has_import_tree = true; - if (mod_->IsDSOExportable() && mod_->imports().empty()) { - has_import_tree = false; + + if (mod_->IsDSOExportable()) { + ICHECK(export_dso) << "`export_dso` should be enabled for DSOExportable modules"; + has_import_tree = !mod_->imports().empty(); } + uint64_t sz = 0; if (has_import_tree) { // we will append one key for _import_tree @@ -83,17 +89,26 @@ class ModuleSerializer { for (const auto& group : mod_group_vec_) { ICHECK_NE(group.size(), 0) << "Every allocated group must have at least one module"; - if (!group[0]->IsDSOExportable()) { + // we prioritize export dso when a module is both serializable and exportable + if (export_dso) { + if (group[0]->IsDSOExportable()) { + if (has_import_tree) { + std::string mod_type_key = "_lib"; + stream->Write(mod_type_key); + } + } else if (group[0]->IsBinarySerializable()) { + ICHECK_EQ(group.size(), 1U) << "Non DSO module is never merged"; + std::string mod_type_key = group[0]->type_key(); + stream->Write(mod_type_key); + group[0]->SaveToBinary(stream); + } + } else { + ICHECK(group[0]->IsBinarySerializable()) + << group[0]->type_key() << " is not binary serializable."; ICHECK_EQ(group.size(), 1U) << "Non DSO module is never merged"; std::string mod_type_key = group[0]->type_key(); stream->Write(mod_type_key); group[0]->SaveToBinary(stream); - } else { - // DSOExportable: do not need binary - if (has_import_tree) { - std::string mod_type_key = "_lib"; - stream->Write(mod_type_key); - } } } @@ -227,22 +242,60 @@ class ModuleSerializer { std::vector import_tree_child_indices_; }; -namespace { -std::string SerializeModule(const runtime::Module& mod) { +std::string SerializeModuleToBytes(const runtime::Module& mod, bool export_dso) { std::string bin; dmlc::MemoryStringStream ms(&bin); dmlc::Stream* stream = &ms; ModuleSerializer module_serializer(mod); - module_serializer.SerializeModule(stream); - + module_serializer.SerializeModuleToBytes(stream, export_dso); return bin; } -} // namespace + +runtime::Module DeserializeModuleFromBytes(std::string blob) { + dmlc::MemoryStringStream ms(&blob); + dmlc::Stream* stream = &ms; + + uint64_t size; + ICHECK(stream->Read(&size)); + std::vector modules; + std::vector import_tree_row_ptr; + std::vector import_tree_child_indices; + + for (uint64_t i = 0; i < size; ++i) { + std::string tkey; + ICHECK(stream->Read(&tkey)); + // "_lib" serves as a placeholder in the module import tree to indicate where + // to place the DSOModule + ICHECK(tkey != "_lib") << "Should not contain any placeholder for DSOModule."; + if (tkey == "_import_tree") { + ICHECK(stream->Read(&import_tree_row_ptr)); + ICHECK(stream->Read(&import_tree_child_indices)); + } else { + auto m = runtime::LoadModuleFromBinary(tkey, stream); + modules.emplace_back(m); + } + } + + for (size_t i = 0; i < modules.size(); ++i) { + for (size_t j = import_tree_row_ptr[i]; j < import_tree_row_ptr[i + 1]; ++j) { + auto module_import_addr = runtime::ModuleInternal::GetImportsAddr(modules[i].operator->()); + auto child_index = import_tree_child_indices[j]; + ICHECK(child_index < modules.size()); + module_import_addr->emplace_back(modules[child_index]); + } + } + + ICHECK(!modules.empty()) << "modules cannot be empty when import tree is present"; + // invariance: root module is always at location 0. + // The module order is collected via DFS + runtime::Module root_mod = modules[0]; + return root_mod; +} std::string PackImportsToC(const runtime::Module& mod, bool system_lib, const std::string& c_symbol_prefix) { - std::string bin = SerializeModule(mod); + std::string bin = SerializeModuleToBytes(mod); std::string mdev_blob_name = c_symbol_prefix + runtime::symbol::tvm_dev_mblob; if (c_symbol_prefix.length() != 0) { @@ -304,7 +357,7 @@ runtime::Module PackImportsToLLVM(const runtime::Module& mod, bool system_lib, << "c_symbol_prefix advanced option should be used in conjuction with system-lib"; } - std::string bin = SerializeModule(mod); + std::string bin = SerializeModuleToBytes(mod); uint64_t nbytes = bin.length(); std::string header; diff --git a/tests/python/unittest/test_roundtrip_runtime_module.py b/tests/python/unittest/test_roundtrip_runtime_module.py new file mode 100644 index 000000000000..6a1abeedd914 --- /dev/null +++ b/tests/python/unittest/test_roundtrip_runtime_module.py @@ -0,0 +1,121 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" Test roundtrip of runtime modules """ +# pylint: disable=missing-docstring + +import pytest +import tvm +import tvm.testing +from tvm import TVMError +from tvm import relay + + +def test_csource_module(): + mod = tvm.runtime._ffi_api.CSourceModuleCreate("", "cc", [], None) + # source module that is not binary serializable. + # Thus, it would raise an error. + assert not mod.is_binary_serializable + with pytest.raises(TVMError): + tvm.ir.load_json(tvm.ir.save_json(mod)) + + +def test_aot_module(): + mod = tvm.get_global_func("relay.build_module._AOTExecutorCodegen")() + # aot module that is not binary serializable. + # Thus, it would raise an error. + assert not mod.is_binary_serializable + with pytest.raises(TVMError): + tvm.ir.load_json(tvm.ir.save_json(mod)) + + +def get_test_mod(): + x = relay.var("x", shape=(1, 10), dtype="float32") + y = relay.var("y", shape=(1, 10), dtype="float32") + z = relay.add(x, y) + func = relay.Function([x, y], z) + return relay.build_module._build_module_no_factory(func, target="cuda") + + +def get_cuda_mod(): + # Get Cuda module which is binary serializable + return get_test_mod().imported_modules[0].imported_modules[0] + + +@tvm.testing.requires_cuda +def test_cuda_module(): + mod = get_cuda_mod() + assert mod.type_key == "cuda" + assert mod.is_binary_serializable + new_mod = tvm.ir.load_json(tvm.ir.save_json(mod)) + assert new_mod.type_key == "cuda" + assert new_mod.is_binary_serializable + + +@tvm.testing.requires_cuda +def test_valid_submodules(): + mod, mod2, mod3, mod4 = get_cuda_mod(), get_cuda_mod(), get_cuda_mod(), get_cuda_mod() + + # Create the nested cuda module + mod.import_module(mod2) + mod2.import_module(mod3) + mod2.import_module(mod4) + + # Root module and all submodules should be binary serializable since they are cuda module + assert mod.type_key == "cuda" + assert mod.is_binary_serializable + assert mod.imported_modules[0].type_key == "cuda" + assert mod.imported_modules[0].is_binary_serializable + assert mod.imported_modules[0].imported_modules[0].type_key == "cuda" + assert mod.imported_modules[0].imported_modules[1].type_key == "cuda" + assert mod.imported_modules[0].imported_modules[0].is_binary_serializable + assert mod.imported_modules[0].imported_modules[1].is_binary_serializable + + # The roundtripped mod should have the same structure + new_mod = tvm.ir.load_json(tvm.ir.save_json(mod)) + assert new_mod.type_key == "cuda" + assert new_mod.is_binary_serializable + assert new_mod.imported_modules[0].type_key == "cuda" + assert new_mod.imported_modules[0].is_binary_serializable + assert new_mod.imported_modules[0].imported_modules[0].type_key == "cuda" + assert new_mod.imported_modules[0].imported_modules[1].type_key == "cuda" + assert new_mod.imported_modules[0].imported_modules[0].is_binary_serializable + assert new_mod.imported_modules[0].imported_modules[1].is_binary_serializable + + +@tvm.testing.requires_cuda +def test_invalid_submodules(): + mod, mod2, mod3 = get_cuda_mod(), get_cuda_mod(), get_cuda_mod() + mod4 = tvm.get_global_func("relay.build_module._AOTExecutorCodegen")() + + # Create the nested cuda module + mod.import_module(mod2) + mod2.import_module(mod3) + mod2.import_module(mod4) + + # One of submodules is not binary serializable. + assert mod.is_binary_serializable + assert mod.imported_modules[0].is_binary_serializable + assert mod.imported_modules[0].imported_modules[0].is_binary_serializable + assert not mod.imported_modules[0].imported_modules[1].is_binary_serializable + + # Therefore, we cannot roundtrip. + with pytest.raises(TVMError): + tvm.ir.load_json(tvm.ir.save_json(mod)) + + +if __name__ == "__main__": + tvm.testing.main()