From def40f6cbdc7fbc22916724edf31ed4ce783d6be Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Sun, 5 Nov 2023 14:18:38 -0800 Subject: [PATCH] [Unity] Loading NDArrayCache by parameter names This PR adds support for loading parameters from NDArrayCache ordered by their names. --- src/runtime/relax_vm/ndarray_cache_support.cc | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/runtime/relax_vm/ndarray_cache_support.cc b/src/runtime/relax_vm/ndarray_cache_support.cc index b2f53bfe1e9a..ea90255fbaf2 100644 --- a/src/runtime/relax_vm/ndarray_cache_support.cc +++ b/src/runtime/relax_vm/ndarray_cache_support.cc @@ -308,6 +308,19 @@ class ParamModuleNode : public runtime::ModuleNode { return params; } + static Array GetParamByName(const Array& names) { + Array result; + result.reserve(names.size()); + for (const String& name : names) { + if (Optional opt = NDArrayCache::Get(name)) { + result.push_back(opt.value()); + } else { + LOG(FATAL) << "ValueError: Cannot find parameter in cache: " << name; + } + } + return result; + } + static Module Create(const std::string& prefix, int num_params) { auto n = make_object(); n->params_ = GetParams(prefix, num_params); @@ -320,6 +333,8 @@ class ParamModuleNode : public runtime::ModuleNode { TVM_REGISTER_GLOBAL("vm.builtin.param_module_from_cache").set_body_typed(ParamModuleNode::Create); TVM_REGISTER_GLOBAL("vm.builtin.param_array_from_cache").set_body_typed(ParamModuleNode::GetParams); +TVM_REGISTER_GLOBAL("vm.builtin.param_array_from_cache_by_name") + .set_body_typed(ParamModuleNode::GetParamByName); } // namespace relax_vm } // namespace runtime