diff --git a/python/tvm/contrib/cc.py b/python/tvm/contrib/cc.py index e678785cbfd5..59b57e08ba49 100644 --- a/python/tvm/contrib/cc.py +++ b/python/tvm/contrib/cc.py @@ -21,6 +21,7 @@ # pylint: disable=invalid-name import sys +from typing import Dict from .._ffi.base import py_str from . import tar as _tar @@ -178,6 +179,51 @@ def create_executable(output, objects, options=None, cc=None, cwd=None, ccache_e raise ValueError("Unsupported platform") +def get_global_symbol_section_map(path, *, nm=None) -> Dict[str, str]: + """Get global symbols from a library via nm -g + + Parameters + ---------- + path : str + The library path + + nm: str + The path to nm command + + Returns + ------- + symbol_section_map: Dict[str, str] + A map from defined global symbol to their sections + """ + if nm is None: + if not _is_linux_like(): + raise ValueError("Unsupported platform") + nm = "nm" + + symbol_section_map = {} + + if not os.path.isfile(path): + raise FileNotFoundError(f"{path} does not exist") + + cmd = [nm, "-gU", path] + proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) + (out, _) = proc.communicate() + + if proc.returncode != 0: + msg = "Runtime error:\n" + msg += py_str(out) + raise RuntimeError(msg) + + for line in py_str(out).split("\n"): + data = line.strip().split() + if len(data) != 3: + continue + symbol = data[-1] + section = data[-2] + symbol_section_map[symbol] = section + return symbol_section_map + + def get_target_by_dump_machine(compiler): """Functor of get_target_triple that can get the target triple using compiler. diff --git a/python/tvm/contrib/ndk.py b/python/tvm/contrib/ndk.py index 335bb2e46437..2a1105ed2bbb 100644 --- a/python/tvm/contrib/ndk.py +++ b/python/tvm/contrib/ndk.py @@ -21,8 +21,10 @@ import subprocess import os import shutil +from typing import Dict + from .._ffi.base import py_str -from . import utils as _utils, tar as _tar +from . import utils as _utils, tar as _tar, cc as _cc from .cc import get_target_by_dump_machine @@ -123,3 +125,30 @@ def create_staticlib(output, inputs): create_staticlib.output_format = "a" + + +def get_global_symbol_section_map(path, *, nm=None) -> Dict[str, str]: + """Get global symbols from a library via nm -gU in NDK + + Parameters + ---------- + path : str + The library path + + nm: str + The path to nm command + + Returns + ------- + symbol_section_map: Dict[str, str] + A map from defined global symbol to their sections + """ + if "TVM_NDK_CC" not in os.environ: + raise RuntimeError( + "Require environment variable TVM_NDK_CC" " to be the NDK standalone compiler" + ) + if nm is None: + compiler = os.environ["TVM_NDK_CC"] + base_path = os.path.dirname(compiler) + nm = os.path.join(base_path, "llvm-nm") + return _cc.get_global_symbol_section_map(path, nm=nm)