diff --git a/README.md b/README.md index b93b7be..5d41768 100644 --- a/README.md +++ b/README.md @@ -66,25 +66,42 @@ cmake --install build #include int main() { - infini::rt::SetDevice({infini::rt::Device::Type::kCpu, 0}); + infini::rt::runtime::SetDevice(0); constexpr std::size_t size = 1024; void* ptr = nullptr; - infini::rt::Malloc(&ptr, size); - infini::rt::Memset(ptr, 0, size); - infini::rt::Free(ptr); + infini::rt::runtime::Malloc(&ptr, size); + infini::rt::runtime::Memset(ptr, 0, size); + infini::rt::runtime::Free(ptr); return 0; } ``` -For NVIDIA: +The CUDA Runtime API-aligned layer lives under `infini::rt::runtime`. The +top-level `infini::rt::set_runtime_device_type` and +`infini::rt::runtime_device_type` APIs select which enabled backend receives +those runtime calls. A GPU backend is selected initially when one is enabled; +otherwise CPU is selected. ```cpp -infini::rt::SetDevice({infini::rt::Device::Type::kNvidia, 0}); +constexpr std::size_t size = 1024; +void* ptr = nullptr; + +infini::rt::set_runtime_device_type(infini::rt::Device::Type::kCpu); +infini::rt::runtime::Malloc(&ptr, size); +infini::rt::runtime::Free(ptr); + +infini::rt::set_runtime_device_type(infini::rt::Device::Type::kNvidia); +infini::rt::runtime::Malloc(&ptr, size); +infini::rt::runtime::Free(ptr); ``` +Use `infini::rt::runtime::Runtime` when CPU +runtime calls are needed explicitly in a build that also enables an accelerator +backend. + ## Using Installed InfiniRT From Another Project Downstream projects should consume the installed headers and libraries instead of depending on the InfiniRT source tree. diff --git a/scripts/generate_public_headers.py b/scripts/generate_public_headers.py index 2ab58cc..3363741 100644 --- a/scripts/generate_public_headers.py +++ b/scripts/generate_public_headers.py @@ -53,21 +53,23 @@ "cpu": "Device::Type::kCpu", "nvidia": "Device::Type::kNvidia", "iluvatar": "Device::Type::kIluvatar", + "hygon": "Device::Type::kHygon", "metax": "Device::Type::kMetax", "moore": "Device::Type::kMoore", "cambricon": "Device::Type::kCambricon", "ascend": "Device::Type::kAscend", } -_RUNTIME_HEADERS = { - "cpu": "native/cpu/runtime_.h", - "nvidia": "native/cuda/nvidia/runtime_.h", - "iluvatar": "native/cuda/iluvatar/runtime_.h", - "metax": "native/cuda/metax/runtime_.h", - "moore": "native/cuda/moore/runtime_.h", - "cambricon": "native/cambricon/runtime_.h", - "ascend": "native/ascend/runtime_.h", -} +_DEFAULT_DEVICE_PRIORITY = ( + "nvidia", + "iluvatar", + "hygon", + "metax", + "moore", + "cambricon", + "ascend", + "cpu", +) def _guard(path): @@ -101,7 +103,7 @@ def _rewrite_detail_include(match): _DETAIL_INCLUDE_PATTERN = re.compile( - r'#include "((?:common|native)/[^"]+|data_type\.h|device\.h|dispatcher\.h|graph\.h|hash\.h|runtime\.h|tensor_view\.h)"' + r'#include "((?:common|native)/[^"]+|data_type\.h|device\.h|dispatcher\.h|hash\.h|runtime\.h|tensor_view\.h)"' ) @@ -133,7 +135,6 @@ def _write_detail_headers(include_root, source_root, devices): "data_type.h", "device.h", "dispatcher.h", - "graph.h", "hash.h", "runtime.h", "tensor_view.h", @@ -155,11 +156,37 @@ def _write_detail_headers(include_root, source_root, devices): _write_detail_header(include_root, source_root, relative_path) -def _write_generated_header(include_root, devices): +def _write_generated_header(include_root, source_root, devices): + default_device = _default_device(devices) + default_device_type = _DEVICE_TYPES[default_device] + public_runtime_functions = _public_runtime_functions_for_devices( + devices, source_root + ) + has_graph_api = any( + function.name in {"StreamBeginCapture", "GraphLaunch"} + for function in public_runtime_functions + ) + graph_declarations = ( + """ +using Graph = void*; + +using GraphExec = void*; + +enum class StreamCaptureMode { + kStreamCaptureModeGlobal = 0, + kStreamCaptureModeThreadLocal = 1, + kStreamCaptureModeRelaxed = 2, +}; +""" + if has_graph_api + else "" + ) includes = [ + "#include ", + "#include ", + "#include ", f"#include {_detail_include('data_type.h')}", f"#include {_detail_include('device.h')}", - f"#include {_detail_include('graph.h')}", f"#include {_detail_include('hash.h')}", f"#include {_detail_include('runtime.h')}", f"#include {_detail_include('tensor_view.h')}", @@ -168,6 +195,13 @@ def _write_generated_header(include_root, devices): for device in devices: includes.append(f"#include ") + for device in devices: + includes.append(f"#include ") + + runtime_declarations = "\n\n".join( + f"{function.signature()};" for function in public_runtime_functions + ) + path = include_root / "infini" / "rt" / "generated.h" path.parent.mkdir(parents=True, exist_ok=True) path.write_text( @@ -176,6 +210,50 @@ def _write_generated_header(include_root, devices): {chr(10).join(includes)} +namespace infini::rt {{ + +void set_runtime_device_type(Device::Type device_type); + +Device::Type runtime_device_type(); + +namespace runtime {{ +namespace generated_detail {{ + +using DefaultErrorRuntime = Runtime<{default_device_type}>; + +inline constexpr Device::Type kDefaultDeviceType = {default_device_type}; + +}} // namespace generated_detail + +using Error = typename generated_detail::DefaultErrorRuntime::Error; + +using Stream = typename generated_detail::DefaultErrorRuntime::Stream; + +using Event = void*; + +{graph_declarations} +using MemcpyKind = std::remove_cv_t< + decltype(generated_detail::DefaultErrorRuntime::kMemcpyHostToHost)>; + +inline constexpr Error kSuccess = generated_detail::DefaultErrorRuntime::kSuccess; + +inline constexpr MemcpyKind kMemcpyHostToHost = + generated_detail::DefaultErrorRuntime::kMemcpyHostToHost; + +inline constexpr MemcpyKind kMemcpyHostToDevice = + generated_detail::DefaultErrorRuntime::kMemcpyHostToDevice; + +inline constexpr MemcpyKind kMemcpyDeviceToHost = + generated_detail::DefaultErrorRuntime::kMemcpyDeviceToHost; + +inline constexpr MemcpyKind kMemcpyDeviceToDevice = + generated_detail::DefaultErrorRuntime::kMemcpyDeviceToDevice; + +{runtime_declarations} + +}} // namespace runtime +}} // namespace infini::rt + #endif """ ) @@ -200,327 +278,369 @@ def params_decl(self): return ", ".join(f"{param.type} {param.name}" for param in self.params) -def _parse_param(param): - param_type, param_name = param.strip().rsplit(" ", 1) - - return _Param(param_type, param_name) - - -def _parse_runtime_functions(runtime_header): - text = pathlib.Path(runtime_header).read_text() - return tuple( - _Function( - return_type, - name, - tuple( - _parse_param(param) - for param in re.split(r",\s*", params.strip()) - if param - ), - ) - for return_type, name, params in re.findall( - r"^(void) ([A-Z]\w*)\(([^()]*)\);$", text, re.MULTILINE - ) - ) - - -def _abort_statement(message): - return f""" assert(false && "{message}"); - std::abort();""" - - -def _dispatch_cases(devices, statements): - return "\n".join( - f""" case {_DEVICE_TYPES[device]}: {{ -{statements.replace("__DEVICE_TYPE__", _DEVICE_TYPES[device])} - return; - }}""" - for device in devices - ) +_PUBLIC_RUNTIME_FUNCTIONS = ( + _Function("Error", "SetDevice", (_Param("int", "device"),)), + _Function("Error", "GetDevice", (_Param("int*", "device"),)), + _Function("Error", "GetDeviceCount", (_Param("int*", "count"),)), + _Function("Error", "DeviceSynchronize", ()), + _Function( + "Error", + "Malloc", + (_Param("void**", "ptr"), _Param("std::size_t", "size")), + ), + _Function( + "Error", + "MallocHost", + (_Param("void**", "ptr"), _Param("std::size_t", "size")), + ), + _Function( + "Error", + "MallocAsync", + ( + _Param("void**", "ptr"), + _Param("std::size_t", "size"), + _Param("Stream", "stream"), + ), + ), + _Function("Error", "Free", (_Param("void*", "ptr"),)), + _Function("Error", "FreeHost", (_Param("void*", "ptr"),)), + _Function( + "Error", + "FreeAsync", + (_Param("void*", "ptr"), _Param("Stream", "stream")), + ), + _Function( + "Error", + "MemGetInfo", + (_Param("std::size_t*", "free"), _Param("std::size_t*", "total")), + ), + _Function( + "Error", + "Memcpy", + ( + _Param("void*", "dst"), + _Param("const void*", "src"), + _Param("std::size_t", "count"), + _Param("MemcpyKind", "kind"), + ), + ), + _Function( + "Error", + "MemcpyAsync", + ( + _Param("void*", "dst"), + _Param("const void*", "src"), + _Param("std::size_t", "count"), + _Param("MemcpyKind", "kind"), + _Param("Stream", "stream"), + ), + ), + _Function( + "Error", + "Memset", + ( + _Param("void*", "ptr"), + _Param("int", "value"), + _Param("std::size_t", "count"), + ), + ), + _Function( + "Error", + "MemsetAsync", + ( + _Param("void*", "ptr"), + _Param("int", "value"), + _Param("std::size_t", "count"), + _Param("Stream", "stream"), + ), + ), + _Function("Error", "StreamCreate", (_Param("Stream*", "stream"),)), + _Function("Error", "StreamDestroy", (_Param("Stream", "stream"),)), + _Function("Error", "StreamSynchronize", (_Param("Stream", "stream"),)), + _Function( + "Error", + "StreamWaitEvent", + ( + _Param("Stream", "stream"), + _Param("Event", "event"), + _Param("unsigned int", "flags"), + ), + ), + _Function("Error", "EventCreate", (_Param("Event*", "event"),)), + _Function( + "Error", + "EventCreateWithFlags", + (_Param("Event*", "event"), _Param("unsigned int", "flags")), + ), + _Function( + "Error", + "EventRecord", + (_Param("Event", "event"), _Param("Stream", "stream")), + ), + _Function("Error", "EventQuery", (_Param("Event", "event"),)), + _Function("Error", "EventSynchronize", (_Param("Event", "event"),)), + _Function("Error", "EventDestroy", (_Param("Event", "event"),)), + _Function( + "Error", + "EventElapsedTime", + ( + _Param("float*", "ms"), + _Param("Event", "start"), + _Param("Event", "end"), + ), + ), + _Function( + "Error", + "StreamBeginCapture", + (_Param("Stream", "stream"), _Param("StreamCaptureMode", "mode")), + ), + _Function( + "Error", + "StreamEndCapture", + (_Param("Stream", "stream"), _Param("Graph*", "graph")), + ), + _Function("Error", "GraphDestroy", (_Param("Graph", "graph"),)), + _Function( + "Error", + "GraphInstantiate", + (_Param("GraphExec*", "graph_exec"), _Param("Graph", "graph")), + ), + _Function( + "Error", + "GraphExecDestroy", + (_Param("GraphExec", "graph_exec"),), + ), + _Function( + "Error", + "GraphLaunch", + (_Param("GraphExec", "graph_exec"), _Param("Stream", "stream")), + ), +) -def _selector(function): - for param in function.params: - if param.type == "Device": - return f"{param.name}.type()" - if param.type == "Device::Type": - return param.name - if param.type in {"Stream", "Graph", "GraphExec"}: - return f"{param.name}.device_type()" +def _default_device(devices): + for device in _DEFAULT_DEVICE_PRIORITY: + if device in devices: + return device - return "current_device.type()" + raise ValueError("at least one device is required") -def _runtime_arg(param): - if param.type == "Device": - return f"{param.name}.index()" - if param.type == "Device::Type": - return None +def _runtime_arg(param, device): + device_type = _DEVICE_TYPES[device] if param.type == "MemcpyKind": - return f"RuntimeMemcpyKind<__DEVICE_TYPE__>({param.name})" - if param.type == "StreamCaptureMode": - return f"RuntimeStreamCaptureMode<__DEVICE_TYPE__>({param.name})" - if param.type in {"Stream", "Graph", "GraphExec"}: + return f"RuntimeMemcpyKind<{device_type}>({param.name})" + if param.type == "Stream": return ( - f"static_cast::{param.type}>" - f"({param.name}.raw())" + f"reinterpret_cast::Stream>({param.name})" ) + if param.type == "Stream*": + return ( + f"reinterpret_cast::Stream*>({param.name})" + ) + if param.type == "Event": + return f"reinterpret_cast::Event>({param.name})" + if param.type == "Event*": + return ( + f"reinterpret_cast::Event*>({param.name})" + ) + if param.type == "Graph": + return f"reinterpret_cast::Graph>({param.name})" + if param.type == "Graph*": + return ( + f"reinterpret_cast::Graph*>({param.name})" + ) + if param.type == "GraphExec": + return ( + f"reinterpret_cast::GraphExec>" + f"({param.name})" + ) + if param.type == "GraphExec*": + return ( + f"reinterpret_cast::GraphExec*>" + f"({param.name})" + ) + if param.type == "StreamCaptureMode": + return f"RuntimeStreamCaptureMode<{device_type}>({param.name})" return param.name -def _runtime_args(function): - args = (_runtime_arg(param) for param in function.params) +def _runtime_args(function, device): + args = (_runtime_arg(param, device) for param in function.params) return ", ".join(arg for arg in args if arg is not None) -def _preconditions(function): - required_pointer_names = { - "GetDevice": {"device"}, - "GetDeviceCount": {"count"}, - "StreamCreate": {"stream"}, - "StreamEndCapture": {"graph"}, - "GraphInstantiate": {"graph_exec"}, - } - checks = [] - for param in function.params: - if param.type.endswith("**") or param.name in required_pointer_names.get( - function.name, set() - ): - checks.append(f" assert({param.name} != nullptr);") - - return "\n".join(checks) - - -def _post_dispatch(function): - if function.name == "SetDevice": - return "\n current_device = device;" - - return "" - - -def _runtime_call(function): - args = _runtime_args(function) +def _runtime_call(function, device): + device_type = _DEVICE_TYPES[device] + args = _runtime_args(function, device) if args: - return f"Runtime<__DEVICE_TYPE__>::{function.name}({args})" + return f"Runtime<{device_type}>::{function.name}({args})" - return f"Runtime<__DEVICE_TYPE__>::{function.name}()" + return f"Runtime<{device_type}>::{function.name}()" -def _write_stream_create(function, devices): - stream_param = function.params[0].name - cases = _dispatch_cases( - devices, - f""" typename Runtime<__DEVICE_TYPE__>::Stream raw_stream = {{}}; - CheckCall([&] {{ return Runtime<__DEVICE_TYPE__>::StreamCreate(&raw_stream); }}); - *{stream_param} = Stream{{__DEVICE_TYPE__, static_cast(raw_stream)}};""", +def _dispatch_cases(devices, function): + return "\n".join( + f""" case {_DEVICE_TYPES[device]}: + return CheckCall([&] {{ return {_runtime_call(function, device)}; }});""" + for device in devices ) - return f"""void StreamCreate(Stream* {stream_param}) {{ - assert({stream_param} != nullptr); - switch (current_device.type()) {{ -{cases} - default: -{_abort_statement("runtime device is not enabled")} +def _write_runtime_dispatch_function(function, devices): + return f"""{function.signature()} {{ + switch (infini::rt::runtime_device_type()) {{ +{_dispatch_cases(devices, function)} }} -}} -""" - - -def _write_get_device(function, devices): - device_param = function.params[0].name - cases = _dispatch_cases( - devices, - f""" int index = current_device.index(); - CheckCall([&] {{ return Runtime<__DEVICE_TYPE__>::GetDevice(&index); }}); - current_device = Device{{current_device.type(), index}}; - *{device_param} = current_device;""", - ) - - return f"""void GetDevice(Device* {device_param}) {{ - assert({device_param} != nullptr); - switch (current_device.type()) {{ -{cases} - default: -{_abort_statement("runtime device is not enabled")} - }} + assert(false && "unsupported runtime device type"); + return InvalidValueError(); }} """ -def _write_stream_end_capture(function, devices): - stream_param = function.params[0].name - graph_param = function.params[1].name - cases = _dispatch_cases( - devices, - f""" typename Runtime<__DEVICE_TYPE__>::Graph raw_graph = {{}}; - CheckCall([&] {{ return Runtime<__DEVICE_TYPE__>::StreamEndCapture(static_cast::Stream>({stream_param}.raw()), &raw_graph); }}); - *{graph_param} = Graph{{__DEVICE_TYPE__, static_cast(raw_graph)}};""", - ) +def _runtime_header_for_device(source_root, device): + for _, header_name, target in _DEVICE_HEADERS[device]: + if header_name == "runtime_.h": + return source_root / target - return f"""void StreamEndCapture(Stream {stream_param}, Graph* {graph_param}) {{ - assert({graph_param} != nullptr); + raise ValueError(f"device {device!r} does not have a runtime header") - switch ({stream_param}.device_type()) {{ -{cases} - default: -{_abort_statement("runtime device is not enabled")} - }} -}} -""" +def _devices_for_function(function, devices, source_root): + pattern = re.compile(r"\b" + re.escape(function.name) + r"\b") -def _write_graph_instantiate(function, devices): - graph_exec_param = function.params[0].name - graph_param = function.params[1].name - cases = _dispatch_cases( - devices, - f""" typename Runtime<__DEVICE_TYPE__>::GraphExec raw_graph_exec = {{}}; - CheckCall([&] {{ return Runtime<__DEVICE_TYPE__>::GraphInstantiate(&raw_graph_exec, static_cast::Graph>({graph_param}.raw())); }}); - *{graph_exec_param} = GraphExec{{__DEVICE_TYPE__, static_cast(raw_graph_exec)}};""", + return tuple( + device + for device in devices + if pattern.search(_runtime_header_for_device(source_root, device).read_text()) ) - return f"""void GraphInstantiate(GraphExec* {graph_exec_param}, Graph {graph_param}) {{ - assert({graph_exec_param} != nullptr); - switch ({graph_param}.device_type()) {{ -{cases} - default: -{_abort_statement("runtime device is not enabled")} - }} -}} -""" - - -def _write_graph_launch(function, devices): - graph_exec_param = function.params[0].name - stream_param = function.params[1].name - cases = _dispatch_cases( - devices, - f""" CheckCall([&] {{ return Runtime<__DEVICE_TYPE__>::GraphLaunch(static_cast::GraphExec>({graph_exec_param}.raw()), static_cast::Stream>({stream_param}.raw())); }});""", +def _public_runtime_functions_for_devices(devices, source_root): + return tuple( + function + for function in _PUBLIC_RUNTIME_FUNCTIONS + if _devices_for_function(function, devices, source_root) ) - return f"""void GraphLaunch(GraphExec {graph_exec_param}, Stream {stream_param}) {{ - assert({graph_exec_param}.device_type() == {stream_param}.device_type()); - - switch ({graph_exec_param}.device_type()) {{ -{cases} - default: -{_abort_statement("runtime device is not enabled")} - }} -}} -""" +def _write_runtime_dispatch(source_path, source_root, devices): + functions = _public_runtime_functions_for_devices(devices, source_root) + stream_capture_mode_helper = ( + """ +template +auto RuntimeStreamCaptureMode(StreamCaptureMode mode) { + using DeviceRuntime = Runtime; -def _write_dispatch_function(function, devices): - if function.name == "GetDevice": - return _write_get_device(function, devices) - if function.name == "StreamCreate": - return _write_stream_create(function, devices) - if function.name == "StreamEndCapture": - return _write_stream_end_capture(function, devices) - if function.name == "GraphInstantiate": - return _write_graph_instantiate(function, devices) - if function.name == "GraphLaunch": - return _write_graph_launch(function, devices) - - cases = _dispatch_cases( - devices, - f""" CheckCall([&] {{ return {_runtime_call(function)}; }});{_post_dispatch(function)}""", - ) - preconditions = _preconditions(function) - if preconditions: - preconditions = f"{preconditions}\n\n" + switch (mode) { + case StreamCaptureMode::kStreamCaptureModeGlobal: + return DeviceRuntime::kStreamCaptureModeGlobal; + case StreamCaptureMode::kStreamCaptureModeThreadLocal: + return DeviceRuntime::kStreamCaptureModeThreadLocal; + case StreamCaptureMode::kStreamCaptureModeRelaxed: + return DeviceRuntime::kStreamCaptureModeRelaxed; + } - return f"""{function.signature()} {{ -{preconditions} switch ({_selector(function)}) {{ -{cases} - default: -{_abort_statement("runtime device is not enabled")} - }} -}} + assert(false && "unsupported stream capture mode"); + return DeviceRuntime::kStreamCaptureModeRelaxed; +} """ - - -def _write_runtime_dispatch(source_path, runtime_header, devices): - first_device_type = _DEVICE_TYPES[devices[0]] - includes = ['#include "runtime.h"'] - includes.extend(f'#include "{_RUNTIME_HEADERS[device]}"' for device in devices) - functions = _parse_runtime_functions(runtime_header) + if any( + param.type == "StreamCaptureMode" + for function in functions + for param in function.params + ) + else "" + ) dispatch_functions = "\n".join( - _write_dispatch_function(function, devices) for function in functions + _write_runtime_dispatch_function( + function, + devices=_devices_for_function(function, devices, source_root), + ) + for function in functions + ) + set_device_type_cases = "\n".join( + f""" case {_DEVICE_TYPES[device]}: + runtime_device_type_ = device_type; + return;""" + for device in devices ) source_path.parent.mkdir(parents=True, exist_ok=True) source_path.write_text( f"""#include -#include +#include #include #include -{chr(10).join(includes)} +#include namespace infini::rt {{ namespace {{ -thread_local Device current_device{{{first_device_type}, 0}}; +thread_local Device::Type runtime_device_type_ = + runtime::generated_detail::kDefaultDeviceType; + +}} // namespace + +void set_runtime_device_type(Device::Type device_type) {{ + switch (device_type) {{ +{set_device_type_cases} + }} + + assert(false && "unsupported runtime device type"); +}} + +Device::Type runtime_device_type() {{ + return runtime_device_type_; +}} + +}} // namespace infini::rt + +namespace infini::rt::runtime {{ +namespace {{ template -void CheckCall(Func&& func) {{ +Error CheckCall(Func&& func) {{ using ReturnType = decltype(std::forward(func)()); if constexpr (std::is_void_v) {{ std::forward(func)(); + return kSuccess; }} else {{ - ReturnType status = std::forward(func)(); - if (status != ReturnType{{}}) {{ - assert(false && "runtime call failed"); - std::abort(); - }} + return static_cast(std::forward(func)()); }} }} -template +Error InvalidValueError() {{ return static_cast(1); }} + +template auto RuntimeMemcpyKind(MemcpyKind kind) {{ + using DeviceRuntime = Runtime; + switch (kind) {{ - case MemcpyKind::kHostToHost: - return Runtime::MemcpyHostToHost; - case MemcpyKind::kHostToDevice: - return Runtime::MemcpyHostToDevice; - case MemcpyKind::kDeviceToHost: - return Runtime::MemcpyDeviceToHost; - case MemcpyKind::kDeviceToDevice: - return Runtime::MemcpyDeviceToDevice; + case kMemcpyHostToHost: + return DeviceRuntime::kMemcpyHostToHost; + case kMemcpyHostToDevice: + return DeviceRuntime::kMemcpyHostToDevice; + case kMemcpyDeviceToHost: + return DeviceRuntime::kMemcpyDeviceToHost; + case kMemcpyDeviceToDevice: + return DeviceRuntime::kMemcpyDeviceToDevice; }} assert(false && "unsupported memcpy kind"); - std::abort(); - return Runtime::MemcpyHostToHost; -}} - -template -auto RuntimeStreamCaptureMode(StreamCaptureMode mode) {{ - switch (mode) {{ - case StreamCaptureMode::kGlobal: - return Runtime::StreamCaptureModeGlobal; - case StreamCaptureMode::kThreadLocal: - return Runtime::StreamCaptureModeThreadLocal; - case StreamCaptureMode::kRelaxed: - return Runtime::StreamCaptureModeRelaxed; - }} - - assert(false && "unsupported stream capture mode"); - std::abort(); - return Runtime::StreamCaptureModeGlobal; + return DeviceRuntime::kMemcpyHostToHost; }} +{stream_capture_mode_helper} }} // namespace {dispatch_functions} -}} // namespace infini::rt +}} // namespace infini::rt::runtime """ ) @@ -548,10 +668,8 @@ def main(): for wrapper_device, header_name, target in _DEVICE_HEADERS[device]: _write_wrapper(include_root, wrapper_device, header_name, target) - _write_generated_header(include_root, devices) - _write_runtime_dispatch( - pathlib.Path(args.source_output), args.runtime_header, devices - ) + _write_generated_header(include_root, source_root, devices) + _write_runtime_dispatch(pathlib.Path(args.source_output), source_root, devices) if __name__ == "__main__": diff --git a/src/c_api.cc b/src/c_api.cc index 178514a..61b40b3 100644 --- a/src/c_api.cc +++ b/src/c_api.cc @@ -13,32 +13,26 @@ namespace { using infini::rt::Device; -using infini::rt::Graph; -using infini::rt::GraphExec; -using infini::rt::Runtime; -using infini::rt::Stream; +using infini::rt::runtime::Runtime; struct CStream { - Stream stream; + Device::Type device_type; + void* raw; }; struct CGraph { - Graph graph; + Device::Type device_type; + void* raw; }; struct CGraphExec { - GraphExec graph_exec; + Device::Type device_type; + void* raw; }; template infiniRtStatus_t Guard(Func&& func) { - try { - return std::forward(func)(); - } catch (const std::bad_alloc&) { - return INFINI_RT_STATUS_RUNTIME_ERROR; - } catch (...) { - return INFINI_RT_STATUS_RUNTIME_ERROR; - } + return std::forward(func)(); } template @@ -84,28 +78,28 @@ Device::Type ToCppDeviceType(infiniRtDeviceType_t type) { auto ToNvidiaCaptureMode(infiniRtStreamCaptureMode_t mode) { switch (mode) { case INFINI_RT_STREAM_CAPTURE_MODE_GLOBAL: - return Runtime::StreamCaptureModeGlobal; + return Runtime::kStreamCaptureModeGlobal; case INFINI_RT_STREAM_CAPTURE_MODE_THREAD_LOCAL: - return Runtime::StreamCaptureModeThreadLocal; + return Runtime::kStreamCaptureModeThreadLocal; case INFINI_RT_STREAM_CAPTURE_MODE_RELAXED: - return Runtime::StreamCaptureModeRelaxed; + return Runtime::kStreamCaptureModeRelaxed; } - return Runtime::StreamCaptureModeRelaxed; + return Runtime::kStreamCaptureModeRelaxed; } -auto RawNvidiaStream(Stream stream) { +auto RawNvidiaStream(const CStream* stream) { return static_cast::Stream>( - stream.raw()); + stream->raw); } -auto RawNvidiaGraph(Graph graph) { +auto RawNvidiaGraph(const CGraph* graph) { return static_cast::Graph>( - graph.raw()); + graph->raw); } -auto RawNvidiaGraphExec(GraphExec graph_exec) { +auto RawNvidiaGraphExec(const CGraphExec* graph_exec) { return static_cast::GraphExec>( - graph_exec.raw()); + graph_exec->raw); } #endif @@ -134,7 +128,12 @@ infiniRtStatus_t infiniRtStreamWrap(infiniRtDevice_t device, if (device_type == Device::Type::kCount) { return INFINI_RT_STATUS_UNSUPPORTED_DEVICE; } - *stream = new CStream{Stream{device_type, native_stream}}; + auto* wrapped = new (std::nothrow) CStream{device_type, native_stream}; + if (wrapped == nullptr) { + return INFINI_RT_STATUS_RUNTIME_ERROR; + } + *stream = wrapped; + return INFINI_RT_STATUS_SUCCESS; }); } @@ -154,12 +153,12 @@ infiniRtStatus_t infiniRtStreamBeginCapture(infiniRtStream_t stream, } return Guard([&] { auto* wrapped = AsStream(stream); - switch (wrapped->stream.device_type()) { + switch (wrapped->device_type) { #if defined(WITH_NVIDIA) case Device::Type::kNvidia: return CheckBackendCall([&] { return Runtime::StreamBeginCapture( - RawNvidiaStream(wrapped->stream), ToNvidiaCaptureMode(mode)); + RawNvidiaStream(wrapped), ToNvidiaCaptureMode(mode)); }); #endif default: @@ -175,19 +174,24 @@ infiniRtStatus_t infiniRtStreamEndCapture(infiniRtStream_t stream, } return Guard([&] { auto* wrapped = AsStream(stream); - switch (wrapped->stream.device_type()) { + switch (wrapped->device_type) { #if defined(WITH_NVIDIA) case Device::Type::kNvidia: { typename Runtime::Graph raw_graph = {}; const auto status = CheckBackendCall([&] { return Runtime::StreamEndCapture( - RawNvidiaStream(wrapped->stream), &raw_graph); + RawNvidiaStream(wrapped), &raw_graph); }); if (status != INFINI_RT_STATUS_SUCCESS) { return status; } - *graph = new CGraph{ - Graph{Device::Type::kNvidia, static_cast(raw_graph)}}; + auto* wrapped_graph = new (std::nothrow) + CGraph{Device::Type::kNvidia, static_cast(raw_graph)}; + if (wrapped_graph == nullptr) { + return INFINI_RT_STATUS_RUNTIME_ERROR; + } + *graph = wrapped_graph; + return INFINI_RT_STATUS_SUCCESS; } #endif @@ -203,12 +207,12 @@ infiniRtStatus_t infiniRtGraphDestroy(infiniRtGraph_t graph) { } return Guard([&] { auto* wrapped = AsGraph(graph); - switch (wrapped->graph.device_type()) { + switch (wrapped->device_type) { #if defined(WITH_NVIDIA) case Device::Type::kNvidia: { const auto status = CheckBackendCall([&] { return Runtime::GraphDestroy( - RawNvidiaGraph(wrapped->graph)); + RawNvidiaGraph(wrapped)); }); // The C wrapper owns only the wrapper object. The backend destroy call // above owns the native graph handle. @@ -230,19 +234,24 @@ infiniRtStatus_t infiniRtGraphInstantiate(infiniRtGraphExec_t* graph_exec, } return Guard([&] { auto* wrapped = AsGraph(graph); - switch (wrapped->graph.device_type()) { + switch (wrapped->device_type) { #if defined(WITH_NVIDIA) case Device::Type::kNvidia: { typename Runtime::GraphExec raw_exec = {}; const auto status = CheckBackendCall([&] { return Runtime::GraphInstantiate( - &raw_exec, RawNvidiaGraph(wrapped->graph)); + &raw_exec, RawNvidiaGraph(wrapped)); }); if (status != INFINI_RT_STATUS_SUCCESS) { return status; } - *graph_exec = new CGraphExec{ - GraphExec{Device::Type::kNvidia, static_cast(raw_exec)}}; + auto* wrapped_exec = new (std::nothrow) + CGraphExec{Device::Type::kNvidia, static_cast(raw_exec)}; + if (wrapped_exec == nullptr) { + return INFINI_RT_STATUS_RUNTIME_ERROR; + } + *graph_exec = wrapped_exec; + return INFINI_RT_STATUS_SUCCESS; } #endif @@ -258,12 +267,12 @@ infiniRtStatus_t infiniRtGraphExecDestroy(infiniRtGraphExec_t graph_exec) { } return Guard([&] { auto* wrapped = AsGraphExec(graph_exec); - switch (wrapped->graph_exec.device_type()) { + switch (wrapped->device_type) { #if defined(WITH_NVIDIA) case Device::Type::kNvidia: { const auto status = CheckBackendCall([&] { return Runtime::GraphExecDestroy( - RawNvidiaGraphExec(wrapped->graph_exec)); + RawNvidiaGraphExec(wrapped)); }); // The C wrapper owns only the wrapper object. The backend destroy call // above owns the native executable graph handle. @@ -286,17 +295,15 @@ infiniRtStatus_t infiniRtGraphLaunch(infiniRtGraphExec_t graph_exec, return Guard([&] { auto* exec = AsGraphExec(graph_exec); auto* wrapped_stream = AsStream(stream); - if (exec->graph_exec.device_type() != - wrapped_stream->stream.device_type()) { + if (exec->device_type != wrapped_stream->device_type) { return INFINI_RT_STATUS_INVALID_ARGUMENT; } - switch (exec->graph_exec.device_type()) { + switch (exec->device_type) { #if defined(WITH_NVIDIA) case Device::Type::kNvidia: return CheckBackendCall([&] { return Runtime::GraphLaunch( - RawNvidiaGraphExec(exec->graph_exec), - RawNvidiaStream(wrapped_stream->stream)); + RawNvidiaGraphExec(exec), RawNvidiaStream(wrapped_stream)); }); #endif default: diff --git a/src/graph.h b/src/graph.h deleted file mode 100644 index 4a972e3..0000000 --- a/src/graph.h +++ /dev/null @@ -1,74 +0,0 @@ -#ifndef INFINI_RT_GRAPH_H_ -#define INFINI_RT_GRAPH_H_ - -#include - -#include "device.h" - -namespace infini::rt { - -enum class StreamCaptureMode { - kGlobal = 0, - kThreadLocal = 1, - kRelaxed = 2, -}; - -// Public dispatch wrappers keep backend handles opaque while preserving -// enough device identity for cross-device graph dispatch. -class Stream { - public: - Stream() = default; - - Stream(Device::Type device_type, void* raw) - : device_type_{device_type}, raw_{raw} {} - - Device::Type device_type() const { return device_type_; } - - void* raw() const { return raw_; } - - explicit operator bool() const { return raw_ != nullptr; } - - private: - Device::Type device_type_{Device::Type::kCpu}; - void* raw_{nullptr}; -}; - -class Graph { - public: - Graph() = default; - - Graph(Device::Type device_type, void* raw) - : device_type_{device_type}, raw_{raw} {} - - Device::Type device_type() const { return device_type_; } - - void* raw() const { return raw_; } - - explicit operator bool() const { return raw_ != nullptr; } - - private: - Device::Type device_type_{Device::Type::kCpu}; - void* raw_{nullptr}; -}; - -class GraphExec { - public: - GraphExec() = default; - - GraphExec(Device::Type device_type, void* raw) - : device_type_{device_type}, raw_{raw} {} - - Device::Type device_type() const { return device_type_; } - - void* raw() const { return raw_; } - - explicit operator bool() const { return raw_ != nullptr; } - - private: - Device::Type device_type_{Device::Type::kCpu}; - void* raw_{nullptr}; -}; - -} // namespace infini::rt - -#endif diff --git a/src/native/ascend/runtime_.h b/src/native/ascend/runtime_.h index c183714..0642027 100644 --- a/src/native/ascend/runtime_.h +++ b/src/native/ascend/runtime_.h @@ -2,6 +2,7 @@ #define INFINI_RT_ASCEND_RUNTIME__H_ #include +#include #include // clang-format off @@ -11,19 +12,21 @@ #include "native/ascend/device_.h" #include "runtime.h" -namespace infini::rt { +namespace infini::rt::runtime { template <> struct Runtime : DeviceRuntime> { - using Stream = aclrtStream; + using Error = aclError; - using Graph = void*; + using Stream = aclrtStream; - using GraphExec = void*; + using Event = void*; static constexpr Device::Type kDeviceType = Device::Type::kAscend; + static constexpr Error kSuccess = ACL_SUCCESS; + static constexpr auto SetDevice = aclrtSetDevice; static constexpr auto GetDevice = aclrtGetDevice; @@ -42,59 +45,79 @@ struct Runtime return aclrtMalloc(ptr, size, ACL_MEM_MALLOC_HUGE_FIRST); }; + static Error MallocHost(void**, std::size_t) { return Unsupported(); } + + static Error MallocAsync(void**, std::size_t, Stream) { + return Unsupported(); + } + static constexpr auto Free = aclrtFree; + static Error FreeHost(void*) { return Unsupported(); } + + static Error FreeAsync(void*, Stream) { return Unsupported(); } + + static Error MemGetInfo(std::size_t*, std::size_t*) { return Unsupported(); } + static constexpr auto Memcpy = [](void* dst, const void* src, size_t count, aclrtMemcpyKind kind) { return aclrtMemcpy(dst, count, src, count, kind); }; - static auto MemcpyAsync(void* dst, const void* src, size_t count, - aclrtMemcpyKind kind, Stream stream) { + static constexpr auto MemcpyAsync = [](void* dst, const void* src, + size_t count, aclrtMemcpyKind kind, + Stream stream) { return aclrtMemcpyAsync(dst, count, src, count, kind, stream); - } + }; - static constexpr auto MemcpyHostToHost = ACL_MEMCPY_HOST_TO_HOST; + static constexpr auto kMemcpyHostToHost = ACL_MEMCPY_HOST_TO_HOST; - static constexpr auto MemcpyHostToDevice = ACL_MEMCPY_HOST_TO_DEVICE; + static constexpr auto kMemcpyHostToDevice = ACL_MEMCPY_HOST_TO_DEVICE; - static constexpr auto MemcpyDeviceToHost = ACL_MEMCPY_DEVICE_TO_HOST; + static constexpr auto kMemcpyDeviceToHost = ACL_MEMCPY_DEVICE_TO_HOST; - static constexpr auto MemcpyDeviceToDevice = ACL_MEMCPY_DEVICE_TO_DEVICE; + static constexpr auto kMemcpyDeviceToDevice = ACL_MEMCPY_DEVICE_TO_DEVICE; - static auto Memset(void* ptr, int value, size_t count) { + static constexpr auto Memset = [](void* ptr, int value, size_t count) { return aclrtMemset(ptr, count, value, count); - } + }; - static auto StreamCreate(Stream* stream) { - return aclrtCreateStreamWithConfig(stream, 0, ACL_STREAM_FAST_LAUNCH); + static Error MemsetAsync(void*, int, std::size_t, Stream) { + return Unsupported(); } + static constexpr auto StreamCreate = aclrtCreateStream; + static constexpr auto StreamDestroy = aclrtDestroyStream; static constexpr auto StreamSynchronize = aclrtSynchronizeStream; - static constexpr int StreamCaptureModeGlobal = 0; + static Error StreamWaitEvent(Stream, Event, unsigned int) { + return Unsupported(); + } - static constexpr int StreamCaptureModeThreadLocal = 1; + static Error EventCreate(Event*) { return Unsupported(); } - static constexpr int StreamCaptureModeRelaxed = 2; + static Error EventCreateWithFlags(Event*, unsigned int) { + return Unsupported(); + } - static int StreamBeginCapture(Stream, int) { return 1; } + static Error EventRecord(Event, Stream) { return Unsupported(); } - static int StreamEndCapture(Stream, Graph*) { return 1; } + static Error EventQuery(Event) { return Unsupported(); } - static int GraphDestroy(Graph) { return 1; } + static Error EventSynchronize(Event) { return Unsupported(); } - static int GraphInstantiate(GraphExec*, Graph) { return 1; } + static Error EventDestroy(Event) { return Unsupported(); } - static int GraphExecDestroy(GraphExec) { return 1; } + static Error EventElapsedTime(float*, Event, Event) { return Unsupported(); } - static int GraphLaunch(GraphExec, Stream) { return 1; } + private: + static Error Unsupported() { return static_cast(1); } }; static_assert(Runtime::Validate()); -} // namespace infini::rt +} // namespace infini::rt::runtime #endif diff --git a/src/native/cambricon/runtime_.h b/src/native/cambricon/runtime_.h index 76e6a2d..c4f4112 100644 --- a/src/native/cambricon/runtime_.h +++ b/src/native/cambricon/runtime_.h @@ -5,23 +5,30 @@ #include #include +#include #include "native/cambricon/device_.h" #include "runtime.h" -namespace infini::rt { +namespace infini::rt::runtime { template <> struct Runtime : DeviceRuntime> { - using Stream = cnrtQueue_t; + using Error = cnrtRet_t; - using Graph = void*; + using Stream = cnrtQueue_t; - using GraphExec = void*; + using Event = void*; static constexpr Device::Type kDeviceType = Device::Type::kCambricon; +#ifdef CNRT_RET_SUCCESS + static constexpr Error kSuccess = CNRT_RET_SUCCESS; +#else + static constexpr Error kSuccess = cnrtSuccess; +#endif + static constexpr auto SetDevice = cnrtSetDevice; static constexpr auto GetDevice = cnrtGetDevice; @@ -38,56 +45,82 @@ struct Runtime static constexpr auto Malloc = cnrtMalloc; + static Error MallocHost(void**, std::size_t) { return Unsupported(); } + + static Error MallocAsync(void**, std::size_t, Stream) { + return Unsupported(); + } + static constexpr auto Free = cnrtFree; + static Error FreeHost(void*) { return Unsupported(); } + + static Error FreeAsync(void*, Stream) { return Unsupported(); } + + static Error MemGetInfo(std::size_t*, std::size_t*) { return Unsupported(); } + static constexpr auto Memcpy = [](void* dst, const void* src, std::size_t size, auto kind) { return cnrtMemcpy(dst, const_cast(src), size, kind); }; - static auto MemcpyAsync(void* dst, const void* src, std::size_t size, - cnrtMemTransDir_t kind, Stream stream) { - return cnrtMemcpyAsync_V2(dst, const_cast(src), size, stream, kind); - } + static constexpr auto MemcpyAsync = [](void* dst, const void* src, + std::size_t size, auto kind, + Stream stream) { + if constexpr (std::is_invocable_v) { + return cnrtMemcpyAsync(dst, const_cast(src), size, kind, stream); + } else { + return cnrtMemcpyAsync(dst, const_cast(src), size, stream, kind); + } + }; - static constexpr auto MemcpyHostToHost = cnrtMemcpyHostToHost; + static constexpr auto kMemcpyHostToHost = cnrtMemcpyHostToHost; - static constexpr auto MemcpyHostToDevice = cnrtMemcpyHostToDev; + static constexpr auto kMemcpyHostToDevice = cnrtMemcpyHostToDev; - static constexpr auto MemcpyDeviceToHost = cnrtMemcpyDevToHost; + static constexpr auto kMemcpyDeviceToHost = cnrtMemcpyDevToHost; - static constexpr auto MemcpyDeviceToDevice = cnrtMemcpyDevToDev; + static constexpr auto kMemcpyDeviceToDevice = cnrtMemcpyDevToDev; static constexpr auto Memset = cnrtMemset; + static Error MemsetAsync(void*, int, std::size_t, Stream) { + return Unsupported(); + } + static constexpr auto StreamCreate = cnrtQueueCreate; static constexpr auto StreamDestroy = cnrtQueueDestroy; static constexpr auto StreamSynchronize = cnrtQueueSync; - static constexpr auto StreamCaptureModeGlobal = cnrtQueueCaptureModeGlobal; + static Error StreamWaitEvent(Stream, Event, unsigned int) { + return Unsupported(); + } - static constexpr auto StreamCaptureModeThreadLocal = - cnrtQueueCaptureModeThreadLocal; + static Error EventCreate(Event*) { return Unsupported(); } - static constexpr auto StreamCaptureModeRelaxed = cnrtQueueCaptureModeRelaxed; + static Error EventCreateWithFlags(Event*, unsigned int) { + return Unsupported(); + } - static int StreamBeginCapture(Stream, cnrtQueueCaptureMode_t) { return 1; } + static Error EventRecord(Event, Stream) { return Unsupported(); } - static int StreamEndCapture(Stream, Graph*) { return 1; } + static Error EventQuery(Event) { return Unsupported(); } - static int GraphDestroy(Graph) { return 1; } + static Error EventSynchronize(Event) { return Unsupported(); } - static int GraphInstantiate(GraphExec*, Graph) { return 1; } + static Error EventDestroy(Event) { return Unsupported(); } - static int GraphExecDestroy(GraphExec) { return 1; } + static Error EventElapsedTime(float*, Event, Event) { return Unsupported(); } - static int GraphLaunch(GraphExec, Stream) { return 1; } + private: + static Error Unsupported() { return static_cast(1); } }; static_assert(Runtime::Validate()); -} // namespace infini::rt +} // namespace infini::rt::runtime #endif diff --git a/src/native/cpu/runtime_.h b/src/native/cpu/runtime_.h index 3946d95..c928e1f 100644 --- a/src/native/cpu/runtime_.h +++ b/src/native/cpu/runtime_.h @@ -1,94 +1,233 @@ #ifndef INFINI_RT_CPU_RUNTIME__H_ #define INFINI_RT_CPU_RUNTIME__H_ -#include +#include +#include #include #include +#include #include "runtime.h" -namespace infini::rt { +namespace infini::rt::runtime { template <> struct Runtime : RuntimeBase> { + static constexpr Device::Type kDeviceType = Device::Type::kCpu; + + using Error = int; + using Stream = void*; - using Graph = void*; + using Event = void*; - using GraphExec = void*; + static constexpr Error kSuccess = 0; - static constexpr Device::Type kDeviceType = Device::Type::kCpu; + static constexpr Error kErrorInvalidValue = 1; + + static constexpr Error kErrorMemoryAllocation = 2; - static void SetDevice(int index) { - if (index != 0) { - assert(false && "CPU device index must be 0"); - std::abort(); + static Error SetDevice(int device) { + if (device != 0) { + return kErrorInvalidValue; } + + return kSuccess; } - static void GetDevice(int* index) { - assert(index != nullptr); - *index = 0; + static Error GetDevice(int* device) { + if (device == nullptr) { + return kErrorInvalidValue; + } + + *device = 0; + + return kSuccess; } - static void GetDeviceCount(int* count) { - assert(count != nullptr); + static Error GetDeviceCount(int* count) { + if (count == nullptr) { + return kErrorInvalidValue; + } + *count = 1; + + return kSuccess; } - static void DeviceSynchronize() {} + static Error DeviceSynchronize() { return kSuccess; } + + static Error Malloc(void** ptr, std::size_t size) { + if (ptr == nullptr) { + return kErrorInvalidValue; + } - static void Malloc(void** ptr, std::size_t size) { *ptr = std::malloc(size); } + *ptr = std::malloc(size); - static void Free(void* ptr) { std::free(ptr); } + if (size != 0 && *ptr == nullptr) { + return kErrorMemoryAllocation; + } - static void Memcpy(void* dst, const void* src, std::size_t size, int) { - std::memcpy(dst, src, size); + return kSuccess; + } + + static Error MallocHost(void** ptr, std::size_t size) { + return Malloc(ptr, size); } - static int MemcpyAsync(void*, const void*, std::size_t, int, Stream) { - return 1; + static Error MallocAsync(void** ptr, std::size_t size, Stream) { + return kErrorInvalidValue; } - static constexpr int MemcpyHostToHost = 0; + static Error Free(void* ptr) { + std::free(ptr); + + return kSuccess; + } - static constexpr int MemcpyHostToDevice = 0; + static Error FreeHost(void* ptr) { return Free(ptr); } - static constexpr int MemcpyDeviceToHost = 1; + static Error FreeAsync(void* ptr, Stream) { return kErrorInvalidValue; } + + static Error MemGetInfo(std::size_t* free, std::size_t* total) { + if (free == nullptr || total == nullptr) { + return kErrorInvalidValue; + } - static constexpr int MemcpyDeviceToDevice = 0; + *free = 0; + *total = 0; + +#ifndef _WIN32 + FILE* fp = std::fopen("/proc/meminfo", "r"); + if (fp == nullptr) { + return kErrorInvalidValue; + } + + char label[64]; + std::size_t value = 0; + while (std::fscanf(fp, "%63s %zu %*s", label, &value) == 2) { + if (std::strcmp(label, "MemTotal:") == 0) { + *total = value * 1024; + } else if (std::strcmp(label, "MemAvailable:") == 0) { + *free = value * 1024; + } + } + std::fclose(fp); +#endif + + return *total == 0 ? kErrorInvalidValue : kSuccess; + } + + static Error Memcpy(void* dst, const void* src, std::size_t size, int) { + if ((dst == nullptr || src == nullptr) && size != 0) { + return kErrorInvalidValue; + } + + std::memcpy(dst, src, size); + + return kSuccess; + } + + static Error MemcpyAsync(void* dst, const void* src, std::size_t size, + int kind, Stream) { + return kErrorInvalidValue; + } + + static Error Memset(void* ptr, int value, std::size_t count) { + if (ptr == nullptr && count != 0) { + return kErrorInvalidValue; + } - static void Memset(void* ptr, int value, std::size_t count) { std::memset(ptr, value, count); + + return kSuccess; } - static int StreamCreate(Stream*) { return 1; } + static Error MemsetAsync(void* ptr, int value, std::size_t count, Stream) { + return kErrorInvalidValue; + } - static int StreamDestroy(Stream) { return 1; } + static Error StreamCreate(Stream* stream) { + if (stream == nullptr) { + return kErrorInvalidValue; + } - static int StreamSynchronize(Stream) { return 1; } + *stream = nullptr; - static constexpr int StreamCaptureModeGlobal = 0; + return kSuccess; + } - static constexpr int StreamCaptureModeThreadLocal = 1; + static Error StreamDestroy(Stream) { return kSuccess; } - static constexpr int StreamCaptureModeRelaxed = 2; + static Error StreamSynchronize(Stream) { return kSuccess; } - static int StreamBeginCapture(Stream, int) { return 1; } + static Error StreamWaitEvent(Stream, Event, unsigned int) { return kSuccess; } + + using CpuEvent = std::chrono::steady_clock::time_point; + + static Error EventCreate(Event* event) { + if (event == nullptr) { + return kErrorInvalidValue; + } + + *event = new (std::nothrow) CpuEvent(std::chrono::steady_clock::now()); + + return *event == nullptr ? kErrorMemoryAllocation : kSuccess; + } + + static Error EventCreateWithFlags(Event* event, unsigned int) { + return EventCreate(event); + } - static int StreamEndCapture(Stream, Graph*) { return 1; } + static Error EventRecord(Event event, Stream) { + if (event == nullptr) { + return kErrorInvalidValue; + } + + *static_cast(event) = std::chrono::steady_clock::now(); + + return kSuccess; + } + + static Error EventQuery(Event event) { + return event == nullptr ? kErrorInvalidValue : kSuccess; + } + + static Error EventSynchronize(Event event) { + return event == nullptr ? kErrorInvalidValue : kSuccess; + } + + static Error EventDestroy(Event event) { + delete static_cast(event); + + return kSuccess; + } + + static Error EventElapsedTime(float* ms, Event start, Event end) { + if (ms == nullptr || start == nullptr || end == nullptr) { + return kErrorInvalidValue; + } + + const auto* start_time = static_cast(start); + const auto* end_time = static_cast(end); + const auto duration = std::chrono::duration_cast( + *end_time - *start_time); + *ms = static_cast(duration.count()) / 1000.0f; + + return kSuccess; + } - static int GraphDestroy(Graph) { return 1; } + static constexpr int kMemcpyHostToHost = 0; - static int GraphInstantiate(GraphExec*, Graph) { return 1; } + static constexpr int kMemcpyHostToDevice = 1; - static int GraphExecDestroy(GraphExec) { return 1; } + static constexpr int kMemcpyDeviceToHost = 2; - static int GraphLaunch(GraphExec, Stream) { return 1; } + static constexpr int kMemcpyDeviceToDevice = 3; }; static_assert(Runtime::Validate()); -} // namespace infini::rt +} // namespace infini::rt::runtime #endif diff --git a/src/native/cuda/hygon/runtime_.h b/src/native/cuda/hygon/runtime_.h index a9f41ec..5051a87 100644 --- a/src/native/cuda/hygon/runtime_.h +++ b/src/native/cuda/hygon/runtime_.h @@ -10,34 +10,122 @@ #include "native/cuda/hygon/device_.h" #include "native/cuda/runtime_.h" -namespace infini::rt { +namespace infini::rt::runtime { template <> struct Runtime : CudaRuntime> { + using Error = cudaError_t; + using Stream = cudaStream_t; + using Event = cudaEvent_t; + static constexpr Device::Type kDeviceType = Device::Type::kHygon; + static constexpr Error kSuccess = cudaSuccess; + + static constexpr auto SetDevice = cudaSetDevice; + + static constexpr auto GetDevice = cudaGetDevice; + + static constexpr auto GetDeviceCount = cudaGetDeviceCount; + + static constexpr auto DeviceSynchronize = cudaDeviceSynchronize; + static constexpr auto Malloc = [](auto&&... args) { return cudaMalloc(std::forward(args)...); }; - static constexpr auto Memcpy = cudaMemcpy; + static constexpr auto MallocHost = [](auto&&... args) { + return cudaMallocHost(std::forward(args)...); + }; + + static constexpr auto MallocAsync = [](auto&&... args) { + return cudaMallocAsync(std::forward(args)...); + }; static constexpr auto Free = [](auto&&... args) { return cudaFree(std::forward(args)...); }; - static constexpr auto MemcpyHostToDevice = cudaMemcpyHostToDevice; + static constexpr auto FreeHost = [](auto&&... args) { + return cudaFreeHost(std::forward(args)...); + }; - static constexpr auto MemcpyDeviceToHost = cudaMemcpyDeviceToHost; + static constexpr auto FreeAsync = [](auto&&... args) { + return cudaFreeAsync(std::forward(args)...); + }; + + static constexpr auto MemGetInfo = [](auto&&... args) { + return cudaMemGetInfo(std::forward(args)...); + }; + + static constexpr auto Memcpy = cudaMemcpy; + + static constexpr auto MemcpyAsync = cudaMemcpyAsync; + + static constexpr auto kMemcpyHostToHost = cudaMemcpyHostToHost; + + static constexpr auto kMemcpyHostToDevice = cudaMemcpyHostToDevice; + + static constexpr auto kMemcpyDeviceToHost = cudaMemcpyDeviceToHost; + + static constexpr auto kMemcpyDeviceToDevice = cudaMemcpyDeviceToDevice; static constexpr auto Memset = cudaMemset; + + static constexpr auto MemsetAsync = [](auto&&... args) { + return cudaMemsetAsync(std::forward(args)...); + }; + + static constexpr auto StreamCreate = [](auto&&... args) { + return cudaStreamCreate(std::forward(args)...); + }; + + static constexpr auto StreamDestroy = [](auto&&... args) { + return cudaStreamDestroy(std::forward(args)...); + }; + + static constexpr auto StreamSynchronize = [](auto&&... args) { + return cudaStreamSynchronize(std::forward(args)...); + }; + + static constexpr auto StreamWaitEvent = [](auto&&... args) { + return cudaStreamWaitEvent(std::forward(args)...); + }; + + static constexpr auto EventCreate = [](auto&&... args) { + return cudaEventCreate(std::forward(args)...); + }; + + static constexpr auto EventCreateWithFlags = [](auto&&... args) { + return cudaEventCreateWithFlags(std::forward(args)...); + }; + + static constexpr auto EventRecord = [](auto&&... args) { + return cudaEventRecord(std::forward(args)...); + }; + + static constexpr auto EventQuery = [](auto&&... args) { + return cudaEventQuery(std::forward(args)...); + }; + + static constexpr auto EventSynchronize = [](auto&&... args) { + return cudaEventSynchronize(std::forward(args)...); + }; + + static constexpr auto EventDestroy = [](auto&&... args) { + return cudaEventDestroy(std::forward(args)...); + }; + + static constexpr auto EventElapsedTime = [](auto&&... args) { + return cudaEventElapsedTime(std::forward(args)...); + }; }; static_assert(Runtime::Validate()); -} // namespace infini::rt +} // namespace infini::rt::runtime #endif diff --git a/src/native/cuda/iluvatar/runtime_.h b/src/native/cuda/iluvatar/runtime_.h index 81a5cd2..09feb65 100644 --- a/src/native/cuda/iluvatar/runtime_.h +++ b/src/native/cuda/iluvatar/runtime_.h @@ -10,19 +10,25 @@ #include "native/cuda/iluvatar/device_.h" #include "native/cuda/runtime_.h" -namespace infini::rt { +namespace infini::rt::runtime { template <> struct Runtime : CudaRuntime> { + using Error = cudaError_t; + using Stream = cudaStream_t; using Graph = cudaGraph_t; using GraphExec = cudaGraphExec_t; + using Event = cudaEvent_t; + static constexpr Device::Type kDeviceType = Device::Type::kIluvatar; + static constexpr Error kSuccess = cudaSuccess; + static constexpr auto SetDevice = cudaSetDevice; static constexpr auto GetDevice = cudaGetDevice; @@ -35,20 +41,46 @@ struct Runtime return cudaMalloc(std::forward(args)...); }; - static constexpr auto Memcpy = cudaMemcpy; + static constexpr auto MallocHost = [](auto&&... args) { + return cudaMallocHost(std::forward(args)...); + }; + + static constexpr auto MallocAsync = [](auto&&... args) { + return cudaMallocAsync(std::forward(args)...); + }; static constexpr auto Free = cudaFree; - static constexpr auto MemcpyHostToHost = cudaMemcpyHostToHost; + static constexpr auto FreeHost = [](auto&&... args) { + return cudaFreeHost(std::forward(args)...); + }; + + static constexpr auto FreeAsync = [](auto&&... args) { + return cudaFreeAsync(std::forward(args)...); + }; + + static constexpr auto MemGetInfo = [](auto&&... args) { + return cudaMemGetInfo(std::forward(args)...); + }; + + static constexpr auto Memcpy = cudaMemcpy; + + static constexpr auto MemcpyAsync = cudaMemcpyAsync; + + static constexpr auto kMemcpyHostToHost = cudaMemcpyHostToHost; - static constexpr auto MemcpyHostToDevice = cudaMemcpyHostToDevice; + static constexpr auto kMemcpyHostToDevice = cudaMemcpyHostToDevice; - static constexpr auto MemcpyDeviceToHost = cudaMemcpyDeviceToHost; + static constexpr auto kMemcpyDeviceToHost = cudaMemcpyDeviceToHost; - static constexpr auto MemcpyDeviceToDevice = cudaMemcpyDeviceToDevice; + static constexpr auto kMemcpyDeviceToDevice = cudaMemcpyDeviceToDevice; static constexpr auto Memset = cudaMemset; + static constexpr auto MemsetAsync = [](auto&&... args) { + return cudaMemsetAsync(std::forward(args)...); + }; + static constexpr auto StreamCreate = [](auto&&... args) { return cudaStreamCreateWithFlags(std::forward(args)..., cudaStreamNonBlocking); @@ -62,16 +94,45 @@ struct Runtime return cudaStreamSynchronize(std::forward(args)...); }; - static constexpr auto MemcpyAsync = [](auto&&... args) { - return cudaMemcpyAsync(std::forward(args)...); + static constexpr auto StreamWaitEvent = [](auto&&... args) { + return cudaStreamWaitEvent(std::forward(args)...); + }; + + static constexpr auto EventCreate = [](auto&&... args) { + return cudaEventCreate(std::forward(args)...); + }; + + static constexpr auto EventCreateWithFlags = [](auto&&... args) { + return cudaEventCreateWithFlags(std::forward(args)...); + }; + + static constexpr auto EventRecord = [](auto&&... args) { + return cudaEventRecord(std::forward(args)...); + }; + + static constexpr auto EventQuery = [](auto&&... args) { + return cudaEventQuery(std::forward(args)...); + }; + + static constexpr auto EventSynchronize = [](auto&&... args) { + return cudaEventSynchronize(std::forward(args)...); + }; + + static constexpr auto EventDestroy = [](auto&&... args) { + return cudaEventDestroy(std::forward(args)...); + }; + + static constexpr auto EventElapsedTime = [](auto&&... args) { + return cudaEventElapsedTime(std::forward(args)...); }; - static constexpr auto StreamCaptureModeGlobal = cudaStreamCaptureModeGlobal; + static constexpr auto kStreamCaptureModeGlobal = cudaStreamCaptureModeGlobal; - static constexpr auto StreamCaptureModeThreadLocal = + static constexpr auto kStreamCaptureModeThreadLocal = cudaStreamCaptureModeThreadLocal; - static constexpr auto StreamCaptureModeRelaxed = cudaStreamCaptureModeRelaxed; + static constexpr auto kStreamCaptureModeRelaxed = + cudaStreamCaptureModeRelaxed; static constexpr auto StreamBeginCapture = [](auto&&... args) { return cudaStreamBeginCapture(std::forward(args)...); @@ -100,6 +161,6 @@ struct Runtime static_assert(Runtime::Validate()); -} // namespace infini::rt +} // namespace infini::rt::runtime #endif diff --git a/src/native/cuda/metax/runtime_.h b/src/native/cuda/metax/runtime_.h index 3e7fb5c..201f9b0 100644 --- a/src/native/cuda/metax/runtime_.h +++ b/src/native/cuda/metax/runtime_.h @@ -9,19 +9,25 @@ #include "native/cuda/metax/device_.h" #include "native/cuda/runtime_.h" -namespace infini::rt { +namespace infini::rt::runtime { template <> struct Runtime : CudaRuntime> { + using Error = mcError_t; + using Stream = mcStream_t; using Graph = void*; using GraphExec = void*; + using Event = mcEvent_t; + static constexpr Device::Type kDeviceType = Device::Type::kMetax; + static constexpr Error kSuccess = mcSuccess; + static constexpr auto SetDevice = mcSetDevice; static constexpr auto GetDevice = mcGetDevice; @@ -34,56 +40,121 @@ struct Runtime return mcMalloc(std::forward(args)...); }; - static constexpr auto Memcpy = [](auto&&... args) { - return mcMemcpy(std::forward(args)...); + static constexpr auto MallocHost = [](auto&&... args) { + return mcMallocHost(std::forward(args)...); + }; + + static constexpr auto MallocAsync = [](auto&&... args) { + return mcMallocAsync(std::forward(args)...); }; static constexpr auto Free = [](auto&&... args) { return mcFree(std::forward(args)...); }; - static constexpr auto MemcpyHostToHost = mcMemcpyHostToHost; + static constexpr auto FreeHost = [](auto&&... args) { + return mcFreeHost(std::forward(args)...); + }; - static constexpr auto MemcpyHostToDevice = mcMemcpyHostToDevice; + static constexpr auto FreeAsync = [](auto&&... args) { + return mcFreeAsync(std::forward(args)...); + }; - static constexpr auto MemcpyDeviceToHost = mcMemcpyDeviceToHost; + static constexpr auto MemGetInfo = [](auto&&... args) { + return mcMemGetInfo(std::forward(args)...); + }; - static constexpr auto MemcpyDeviceToDevice = mcMemcpyDeviceToDevice; + static constexpr auto Memcpy = [](auto&&... args) { + return mcMemcpy(std::forward(args)...); + }; + + static constexpr auto MemcpyAsync = [](auto&&... args) { + return mcMemcpyAsync(std::forward(args)...); + }; + + static constexpr auto kMemcpyHostToHost = mcMemcpyHostToHost; + + static constexpr auto kMemcpyHostToDevice = mcMemcpyHostToDevice; + + static constexpr auto kMemcpyDeviceToHost = mcMemcpyDeviceToHost; + + static constexpr auto kMemcpyDeviceToDevice = mcMemcpyDeviceToDevice; static constexpr auto Memset = mcMemset; - static int StreamCreate(Stream*) { return 1; } + static constexpr auto MemsetAsync = [](auto&&... args) { + return mcMemsetAsync(std::forward(args)...); + }; + + static constexpr auto StreamCreate = [](auto&&... args) { + return mcStreamCreate(std::forward(args)...); + }; - static int StreamDestroy(Stream) { return 1; } + static constexpr auto StreamDestroy = [](auto&&... args) { + return mcStreamDestroy(std::forward(args)...); + }; - static int StreamSynchronize(Stream) { return 1; } + static constexpr auto StreamSynchronize = [](auto&&... args) { + return mcStreamSynchronize(std::forward(args)...); + }; - static int MemcpyAsync(void*, const void*, std::size_t, - decltype(MemcpyHostToDevice), Stream) { - return 1; - } + static constexpr auto StreamWaitEvent = [](auto&&... args) { + return mcStreamWaitEvent(std::forward(args)...); + }; - static constexpr int StreamCaptureModeGlobal = 0; + static constexpr auto EventCreate = [](auto&&... args) { + return mcEventCreate(std::forward(args)...); + }; - static constexpr int StreamCaptureModeThreadLocal = 1; + static constexpr auto EventCreateWithFlags = [](auto&&... args) { + return mcEventCreateWithFlags(std::forward(args)...); + }; - static constexpr int StreamCaptureModeRelaxed = 2; + static constexpr auto EventRecord = [](auto&&... args) { + return mcEventRecord(std::forward(args)...); + }; - static int StreamBeginCapture(Stream, int) { return 1; } + static constexpr auto EventQuery = [](auto&&... args) { + return mcEventQuery(std::forward(args)...); + }; - static int StreamEndCapture(Stream, Graph*) { return 1; } + static constexpr auto EventSynchronize = [](auto&&... args) { + return mcEventSynchronize(std::forward(args)...); + }; - static int GraphDestroy(Graph) { return 1; } + static constexpr auto EventDestroy = [](auto&&... args) { + return mcEventDestroy(std::forward(args)...); + }; - static int GraphInstantiate(GraphExec*, Graph) { return 1; } + static constexpr auto EventElapsedTime = [](auto&&... args) { + return mcEventElapsedTime(std::forward(args)...); + }; + + static constexpr int kStreamCaptureModeGlobal = 0; + + static constexpr int kStreamCaptureModeThreadLocal = 1; + + static constexpr int kStreamCaptureModeRelaxed = 2; + + static Error StreamBeginCapture(Stream, int) { return static_cast(1); } + + static Error StreamEndCapture(Stream, Graph*) { + return static_cast(1); + } + + static Error GraphDestroy(Graph) { return static_cast(1); } + + static Error GraphInstantiate(GraphExec*, Graph) { + return static_cast(1); + } - static int GraphExecDestroy(GraphExec) { return 1; } + static Error GraphExecDestroy(GraphExec) { return static_cast(1); } - static int GraphLaunch(GraphExec, Stream) { return 1; } + static Error GraphLaunch(GraphExec, Stream) { return static_cast(1); } }; static_assert(Runtime::Validate()); -} // namespace infini::rt +} // namespace infini::rt::runtime #endif diff --git a/src/native/cuda/moore/runtime_.h b/src/native/cuda/moore/runtime_.h index 88702ff..37eab28 100644 --- a/src/native/cuda/moore/runtime_.h +++ b/src/native/cuda/moore/runtime_.h @@ -9,19 +9,25 @@ #include "native/cuda/moore/device_.h" #include "native/cuda/runtime_.h" -namespace infini::rt { +namespace infini::rt::runtime { template <> struct Runtime : CudaRuntime> { + using Error = musaError_t; + using Stream = musaStream_t; using Graph = void*; using GraphExec = void*; + using Event = musaEvent_t; + static constexpr Device::Type kDeviceType = Device::Type::kMoore; + static constexpr Error kSuccess = musaSuccess; + static constexpr auto SetDevice = musaSetDevice; static constexpr auto GetDevice = [](auto&&... args) { @@ -40,56 +46,121 @@ struct Runtime return musaMalloc(std::forward(args)...); }; - static constexpr auto Memcpy = [](auto&&... args) { - return musaMemcpy(std::forward(args)...); + static constexpr auto MallocHost = [](auto&&... args) { + return musaMallocHost(std::forward(args)...); + }; + + static constexpr auto MallocAsync = [](void**, std::size_t, Stream) { + return static_cast(1); }; static constexpr auto Free = [](auto&&... args) { return musaFree(std::forward(args)...); }; - static constexpr auto MemcpyHostToHost = musaMemcpyHostToHost; + static constexpr auto FreeHost = [](auto&&... args) { + return musaFreeHost(std::forward(args)...); + }; - static constexpr auto MemcpyHostToDevice = musaMemcpyHostToDevice; + static constexpr auto FreeAsync = [](void*, Stream) { + return static_cast(1); + }; - static constexpr auto MemcpyDeviceToHost = musaMemcpyDeviceToHost; + static constexpr auto MemGetInfo = [](auto&&... args) { + return musaMemGetInfo(std::forward(args)...); + }; - static constexpr auto MemcpyDeviceToDevice = musaMemcpyDeviceToDevice; + static constexpr auto Memcpy = [](auto&&... args) { + return musaMemcpy(std::forward(args)...); + }; + + static constexpr auto MemcpyAsync = [](auto&&... args) { + return musaMemcpyAsync(std::forward(args)...); + }; + + static constexpr auto kMemcpyHostToHost = musaMemcpyHostToHost; + + static constexpr auto kMemcpyHostToDevice = musaMemcpyHostToDevice; + + static constexpr auto kMemcpyDeviceToHost = musaMemcpyDeviceToHost; + + static constexpr auto kMemcpyDeviceToDevice = musaMemcpyDeviceToDevice; static constexpr auto Memset = musaMemset; - static int StreamCreate(Stream*) { return 1; } + static constexpr auto MemsetAsync = [](auto&&... args) { + return musaMemsetAsync(std::forward(args)...); + }; + + static constexpr auto StreamCreate = [](auto&&... args) { + return musaStreamCreate(std::forward(args)...); + }; - static int StreamDestroy(Stream) { return 1; } + static constexpr auto StreamDestroy = [](auto&&... args) { + return musaStreamDestroy(std::forward(args)...); + }; - static int StreamSynchronize(Stream) { return 1; } + static constexpr auto StreamSynchronize = [](auto&&... args) { + return musaStreamSynchronize(std::forward(args)...); + }; - static int MemcpyAsync(void*, const void*, std::size_t, - decltype(MemcpyHostToDevice), Stream) { - return 1; - } + static constexpr auto StreamWaitEvent = [](auto&&... args) { + return musaStreamWaitEvent(std::forward(args)...); + }; - static constexpr int StreamCaptureModeGlobal = 0; + static constexpr auto EventCreate = [](auto&&... args) { + return musaEventCreate(std::forward(args)...); + }; - static constexpr int StreamCaptureModeThreadLocal = 1; + static constexpr auto EventCreateWithFlags = [](auto&&... args) { + return musaEventCreateWithFlags(std::forward(args)...); + }; - static constexpr int StreamCaptureModeRelaxed = 2; + static constexpr auto EventRecord = [](auto&&... args) { + return musaEventRecord(std::forward(args)...); + }; - static int StreamBeginCapture(Stream, int) { return 1; } + static constexpr auto EventQuery = [](auto&&... args) { + return musaEventQuery(std::forward(args)...); + }; - static int StreamEndCapture(Stream, Graph*) { return 1; } + static constexpr auto EventSynchronize = [](auto&&... args) { + return musaEventSynchronize(std::forward(args)...); + }; - static int GraphDestroy(Graph) { return 1; } + static constexpr auto EventDestroy = [](auto&&... args) { + return musaEventDestroy(std::forward(args)...); + }; - static int GraphInstantiate(GraphExec*, Graph) { return 1; } + static constexpr auto EventElapsedTime = [](auto&&... args) { + return musaEventElapsedTime(std::forward(args)...); + }; + + static constexpr int kStreamCaptureModeGlobal = 0; + + static constexpr int kStreamCaptureModeThreadLocal = 1; + + static constexpr int kStreamCaptureModeRelaxed = 2; + + static Error StreamBeginCapture(Stream, int) { return static_cast(1); } + + static Error StreamEndCapture(Stream, Graph*) { + return static_cast(1); + } + + static Error GraphDestroy(Graph) { return static_cast(1); } + + static Error GraphInstantiate(GraphExec*, Graph) { + return static_cast(1); + } - static int GraphExecDestroy(GraphExec) { return 1; } + static Error GraphExecDestroy(GraphExec) { return static_cast(1); } - static int GraphLaunch(GraphExec, Stream) { return 1; } + static Error GraphLaunch(GraphExec, Stream) { return static_cast(1); } }; static_assert(Runtime::Validate()); -} // namespace infini::rt +} // namespace infini::rt::runtime #endif diff --git a/src/native/cuda/nvidia/runtime_.h b/src/native/cuda/nvidia/runtime_.h index 68651ad..c9d2649 100644 --- a/src/native/cuda/nvidia/runtime_.h +++ b/src/native/cuda/nvidia/runtime_.h @@ -10,19 +10,25 @@ #include "native/cuda/nvidia/device_.h" #include "native/cuda/runtime_.h" -namespace infini::rt { +namespace infini::rt::runtime { template <> struct Runtime : CudaRuntime> { + using Error = cudaError_t; + using Stream = cudaStream_t; using Graph = cudaGraph_t; using GraphExec = cudaGraphExec_t; + using Event = cudaEvent_t; + static constexpr Device::Type kDeviceType = Device::Type::kNvidia; + static constexpr Error kSuccess = cudaSuccess; + static constexpr auto SetDevice = cudaSetDevice; static constexpr auto GetDevice = cudaGetDevice; @@ -35,20 +41,46 @@ struct Runtime return cudaMalloc(std::forward(args)...); }; - static constexpr auto Memcpy = cudaMemcpy; + static constexpr auto MallocHost = [](auto&&... args) { + return cudaMallocHost(std::forward(args)...); + }; + + static constexpr auto MallocAsync = [](auto&&... args) { + return cudaMallocAsync(std::forward(args)...); + }; static constexpr auto Free = cudaFree; - static constexpr auto MemcpyHostToHost = cudaMemcpyHostToHost; + static constexpr auto FreeHost = [](auto&&... args) { + return cudaFreeHost(std::forward(args)...); + }; + + static constexpr auto FreeAsync = [](auto&&... args) { + return cudaFreeAsync(std::forward(args)...); + }; + + static constexpr auto MemGetInfo = [](auto&&... args) { + return cudaMemGetInfo(std::forward(args)...); + }; + + static constexpr auto Memcpy = cudaMemcpy; + + static constexpr auto MemcpyAsync = cudaMemcpyAsync; + + static constexpr auto kMemcpyHostToHost = cudaMemcpyHostToHost; - static constexpr auto MemcpyHostToDevice = cudaMemcpyHostToDevice; + static constexpr auto kMemcpyHostToDevice = cudaMemcpyHostToDevice; - static constexpr auto MemcpyDeviceToHost = cudaMemcpyDeviceToHost; + static constexpr auto kMemcpyDeviceToHost = cudaMemcpyDeviceToHost; - static constexpr auto MemcpyDeviceToDevice = cudaMemcpyDeviceToDevice; + static constexpr auto kMemcpyDeviceToDevice = cudaMemcpyDeviceToDevice; static constexpr auto Memset = cudaMemset; + static constexpr auto MemsetAsync = [](auto&&... args) { + return cudaMemsetAsync(std::forward(args)...); + }; + static constexpr auto StreamCreate = [](auto&&... args) { return cudaStreamCreateWithFlags(std::forward(args)..., cudaStreamNonBlocking); @@ -62,16 +94,45 @@ struct Runtime return cudaStreamSynchronize(std::forward(args)...); }; - static constexpr auto MemcpyAsync = [](auto&&... args) { - return cudaMemcpyAsync(std::forward(args)...); + static constexpr auto StreamWaitEvent = [](auto&&... args) { + return cudaStreamWaitEvent(std::forward(args)...); + }; + + static constexpr auto EventCreate = [](auto&&... args) { + return cudaEventCreate(std::forward(args)...); + }; + + static constexpr auto EventCreateWithFlags = [](auto&&... args) { + return cudaEventCreateWithFlags(std::forward(args)...); + }; + + static constexpr auto EventRecord = [](auto&&... args) { + return cudaEventRecord(std::forward(args)...); + }; + + static constexpr auto EventQuery = [](auto&&... args) { + return cudaEventQuery(std::forward(args)...); + }; + + static constexpr auto EventSynchronize = [](auto&&... args) { + return cudaEventSynchronize(std::forward(args)...); + }; + + static constexpr auto EventDestroy = [](auto&&... args) { + return cudaEventDestroy(std::forward(args)...); + }; + + static constexpr auto EventElapsedTime = [](auto&&... args) { + return cudaEventElapsedTime(std::forward(args)...); }; - static constexpr auto StreamCaptureModeGlobal = cudaStreamCaptureModeGlobal; + static constexpr auto kStreamCaptureModeGlobal = cudaStreamCaptureModeGlobal; - static constexpr auto StreamCaptureModeThreadLocal = + static constexpr auto kStreamCaptureModeThreadLocal = cudaStreamCaptureModeThreadLocal; - static constexpr auto StreamCaptureModeRelaxed = cudaStreamCaptureModeRelaxed; + static constexpr auto kStreamCaptureModeRelaxed = + cudaStreamCaptureModeRelaxed; static constexpr auto StreamBeginCapture = [](auto&&... args) { return cudaStreamBeginCapture(std::forward(args)...); @@ -140,6 +201,6 @@ struct Runtime static_assert(Runtime::Validate()); -} // namespace infini::rt +} // namespace infini::rt::runtime #endif diff --git a/src/native/cuda/runtime_.h b/src/native/cuda/runtime_.h index 1634d0b..59dd6bc 100644 --- a/src/native/cuda/runtime_.h +++ b/src/native/cuda/runtime_.h @@ -1,12 +1,11 @@ #ifndef INFINI_RT_CUDA_RUNTIME_H_ #define INFINI_RT_CUDA_RUNTIME_H_ -#include #include #include "runtime.h" -namespace infini::rt { +namespace infini::rt::runtime { /// ## CUDA-like runtime interface enforcement via CRTP. /// @@ -18,13 +17,19 @@ struct CudaRuntime : DeviceRuntime { DeviceRuntime::Validate(); static_assert( std::is_invocable_v, + size_t, decltype(Derived::kMemcpyHostToDevice)>, "`Runtime::Memcpy` must be callable with " - "`(void*, const void*, size_t, MemcpyHostToDevice)`."); + "`(void*, const void*, size_t, kMemcpyHostToDevice)`."); + static_assert( + std::is_invocable_v, + "`Runtime::MemcpyAsync` must be callable with " + "`(void*, const void*, size_t, kMemcpyHostToDevice, Stream)`."); return true; } }; -} // namespace infini::rt +} // namespace infini::rt::runtime #endif diff --git a/src/runtime.h b/src/runtime.h index b81ca92..c3e7640 100644 --- a/src/runtime.h +++ b/src/runtime.h @@ -5,9 +5,8 @@ #include #include "device.h" -#include "graph.h" -namespace infini::rt { +namespace infini::rt::runtime { template struct Runtime; @@ -15,8 +14,9 @@ struct Runtime; /// ## Interface enforcement via CRTP. /// /// Inherit from the appropriate base to declare which interface level a -/// `Runtime` specialization implements. After the struct is fully defined, call -/// `static_assert(Runtime<...>::Validate())`. The chained `Validate()` checks +/// `runtime::Runtime` specialization implements. After the struct is fully +/// defined, call `static_assert(Runtime<...>::Validate())`. The chained +/// `Validate()` checks /// every required member's existence and signature at compile time, analogous /// to how `override` catches signature mismatches for virtual functions. /// @@ -31,6 +31,11 @@ struct RuntimeBase { std::is_same_v, Device::Type>, "`Runtime` must define `static constexpr Device::Type kDeviceType`."); + static_assert(sizeof(typename Derived::Error) > 0, + "`Runtime` must define an `Error` type alias."); + static_assert(std::is_same_v, + typename Derived::Error>, + "`Runtime` must define `static constexpr Error kSuccess`."); return true; } }; @@ -52,50 +57,6 @@ struct DeviceRuntime : RuntimeBase { } }; -enum class MemcpyKind { - kHostToHost, - kHostToDevice, - kDeviceToHost, - kDeviceToDevice, -}; - -void SetDevice(Device device); - -void GetDevice(Device* device); - -void GetDeviceCount(int* count, Device::Type type); - -void DeviceSynchronize(); - -void Malloc(void** ptr, std::size_t size); - -void Free(void* ptr); - -void Memcpy(void* dst, const void* src, std::size_t count, MemcpyKind kind); - -void MemcpyAsync(void* dst, const void* src, std::size_t count, MemcpyKind kind, - Stream stream); - -void Memset(void* ptr, int value, std::size_t count); - -void StreamCreate(Stream* stream); - -void StreamDestroy(Stream stream); - -void StreamSynchronize(Stream stream); - -void StreamBeginCapture(Stream stream, StreamCaptureMode mode); - -void StreamEndCapture(Stream stream, Graph* graph); - -void GraphDestroy(Graph graph); - -void GraphInstantiate(GraphExec* graph_exec, Graph graph); - -void GraphExecDestroy(GraphExec graph_exec); - -void GraphLaunch(GraphExec graph_exec, Stream stream); - -} // namespace infini::rt +} // namespace infini::rt::runtime #endif diff --git a/src/tensor_view.h b/src/tensor_view.h index 0b6fc59..dcf7cc9 100644 --- a/src/tensor_view.h +++ b/src/tensor_view.h @@ -3,6 +3,8 @@ #include #include +#include +#include #include #include "data_type.h" @@ -11,6 +13,22 @@ namespace infini::rt { +namespace tensor_view_detail { + +template +struct IsTensorLike : std::false_type {}; + +template +struct IsTensorLike().data()), + decltype(std::declval().shape()), + decltype(std::declval().dtype()), + decltype(std::declval().device()), + decltype(std::declval().strides())>> + : std::true_type {}; + +} // namespace tensor_view_detail + class TensorView { public: using Size = std::size_t; @@ -23,7 +41,9 @@ class TensorView { using Strides = std::vector; - template + template ::value>> TensorView(const TensorLike& tensor) : data_{const_cast(static_cast(tensor.data()))}, shape_{tensor.shape()}, diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 0d765c6..b801026 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -4,19 +4,134 @@ function(add_infini_rt_test target) add_test(NAME ${target} COMMAND ${target}) endfunction() +function(add_infini_rt_backend_runtime_test backend device_type runtime_header + expect_async_memcpy_success supports_host_memory supports_async_memory + supports_mem_get_info supports_memset_async supports_stream_wait_event + supports_event supports_event_elapsed_time) + string(TOLOWER "${backend}" backend_lower) + set(target "test_${backend_lower}_runtime") + add_infini_rt_test(${target} test_native_runtime.cc) + target_compile_definitions(${target} + PRIVATE + "INFINI_RT_TEST_BACKEND_NAME=\"${backend}\"" + "INFINI_RT_TEST_DEVICE_TYPE=${device_type}" + "INFINI_RT_TEST_RUNTIME_HEADER=\"${runtime_header}\"" + "INFINI_RT_TEST_EXPECT_ASYNC_MEMCPY_SUCCESS=${expect_async_memcpy_success}" + "INFINI_RT_TEST_SUPPORTS_HOST_MEMORY=${supports_host_memory}" + "INFINI_RT_TEST_SUPPORTS_ASYNC_MEMORY=${supports_async_memory}" + "INFINI_RT_TEST_SUPPORTS_MEM_GET_INFO=${supports_mem_get_info}" + "INFINI_RT_TEST_SUPPORTS_MEMSET_ASYNC=${supports_memset_async}" + "INFINI_RT_TEST_SUPPORTS_STREAM_WAIT_EVENT=${supports_stream_wait_event}" + "INFINI_RT_TEST_SUPPORTS_EVENT=${supports_event}" + "INFINI_RT_TEST_SUPPORTS_EVENT_ELAPSED_TIME=${supports_event_elapsed_time}") +endfunction() + add_infini_rt_test(test_smoke test_smoke.cc) add_infini_rt_test(test_core test_core.cc) +set(INFINI_RT_TEST_HAS_RUNTIME_BACKEND OFF) + if(WITH_CPU) - add_infini_rt_test(test_cpu_runtime test_cpu_runtime.cc) + set(INFINI_RT_TEST_HAS_RUNTIME_BACKEND ON) + add_infini_rt_backend_runtime_test( + CPU infini::rt::Device::Type::kCpu infini/rt/cpu/runtime_.h + 0 1 0 1 0 1 1 1) +endif() + +if(WITH_NVIDIA) + set(INFINI_RT_TEST_HAS_RUNTIME_BACKEND ON) + add_infini_rt_backend_runtime_test( + NVIDIA infini::rt::Device::Type::kNvidia + infini/rt/nvidia/runtime_.h + 1 1 1 1 1 1 1 1) +endif() + +if(WITH_ILUVATAR) + set(INFINI_RT_TEST_HAS_RUNTIME_BACKEND ON) + add_infini_rt_backend_runtime_test( + ILUVATAR infini::rt::Device::Type::kIluvatar + infini/rt/iluvatar/runtime_.h + 1 1 1 1 1 1 1 1) +endif() + +if(WITH_HYGON) + set(INFINI_RT_TEST_HAS_RUNTIME_BACKEND ON) + add_infini_rt_backend_runtime_test( + HYGON infini::rt::Device::Type::kHygon + infini/rt/hygon/runtime_.h + 1 1 1 1 1 1 1 1) +endif() + +if(WITH_METAX) + set(INFINI_RT_TEST_HAS_RUNTIME_BACKEND ON) + add_infini_rt_backend_runtime_test( + METAX infini::rt::Device::Type::kMetax + infini/rt/metax/runtime_.h + 1 1 1 1 1 1 1 1) +endif() + +if(WITH_MOORE) + set(INFINI_RT_TEST_HAS_RUNTIME_BACKEND ON) + add_infini_rt_backend_runtime_test( + MOORE infini::rt::Device::Type::kMoore + infini/rt/moore/runtime_.h + 1 1 0 1 1 1 1 1) +endif() + +if(WITH_CAMBRICON) + set(INFINI_RT_TEST_HAS_RUNTIME_BACKEND ON) + add_infini_rt_backend_runtime_test( + CAMBRICON infini::rt::Device::Type::kCambricon + infini/rt/cambricon/runtime_.h + 1 0 0 0 0 0 0 0) +endif() + +if(WITH_ASCEND) + set(INFINI_RT_TEST_HAS_RUNTIME_BACKEND ON) + add_infini_rt_backend_runtime_test( + ASCEND infini::rt::Device::Type::kAscend + infini/rt/ascend/runtime_.h + 1 0 0 0 0 0 0 0) endif() -if(WITH_CPU OR WITH_NVIDIA) +if(INFINI_RT_TEST_HAS_RUNTIME_BACKEND) add_infini_rt_test(test_runtime_dispatch test_runtime_dispatch.cc) + + if(WITH_CPU) + target_compile_definitions(test_runtime_dispatch + PRIVATE INFINI_RT_TEST_WITH_CPU=1) + endif() + if(WITH_NVIDIA) + target_compile_definitions(test_runtime_dispatch + PRIVATE INFINI_RT_TEST_WITH_NVIDIA=1) + endif() + if(WITH_ILUVATAR) + target_compile_definitions(test_runtime_dispatch + PRIVATE INFINI_RT_TEST_WITH_ILUVATAR=1) + endif() + if(WITH_HYGON) + target_compile_definitions(test_runtime_dispatch + PRIVATE INFINI_RT_TEST_WITH_HYGON=1) + endif() + if(WITH_METAX) + target_compile_definitions(test_runtime_dispatch + PRIVATE INFINI_RT_TEST_WITH_METAX=1) + endif() + if(WITH_MOORE) + target_compile_definitions(test_runtime_dispatch + PRIVATE INFINI_RT_TEST_WITH_MOORE=1) + endif() + if(WITH_CAMBRICON) + target_compile_definitions(test_runtime_dispatch + PRIVATE INFINI_RT_TEST_WITH_CAMBRICON=1) + endif() + if(WITH_ASCEND) + target_compile_definitions(test_runtime_dispatch + PRIVATE INFINI_RT_TEST_WITH_ASCEND=1) + endif() endif() if(WITH_NVIDIA) - add_infini_rt_test(test_nvidia_runtime test_nvidia_runtime.cc) add_infini_rt_test(test_nvidia_graph test_nvidia_graph.cc) add_infini_rt_test(test_nvidia_graph_c_api test_nvidia_graph_c_api.cc) endif() @@ -26,12 +141,72 @@ set(INFINI_RT_TEST_INSTALL_PREFIX set(INFINI_RT_TEST_CONSUMER_BINARY "${CMAKE_CURRENT_BINARY_DIR}/install_consumer_smoke") set(INFINI_RT_TEST_EXTRA_LIBRARY_DIRS "") +set(INFINI_RT_TEST_EXTRA_INCLUDE_DIRS "") set(INFINI_RT_TEST_CONSUMER_BACKEND NONE) -if(WITH_CPU) - set(INFINI_RT_TEST_CONSUMER_BACKEND CPU) -elseif(WITH_NVIDIA) +if(WITH_NVIDIA OR WITH_ILUVATAR OR WITH_HYGON) + if(CUDAToolkit_INCLUDE_DIRS) + list(APPEND INFINI_RT_TEST_EXTRA_INCLUDE_DIRS + ${CUDAToolkit_INCLUDE_DIRS}) + elseif(CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES) + list(APPEND INFINI_RT_TEST_EXTRA_INCLUDE_DIRS + ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) + endif() + + if(CUDAToolkit_LIBRARY_DIR) + list(APPEND INFINI_RT_TEST_EXTRA_LIBRARY_DIRS + "${CUDAToolkit_LIBRARY_DIR}") + endif() + if(CUDAToolkit_TARGET_DIR) + list(APPEND INFINI_RT_TEST_EXTRA_LIBRARY_DIRS + "${CUDAToolkit_TARGET_DIR}/lib64") + endif() + if(CUDAToolkit_ROOT_DIR) + list(APPEND INFINI_RT_TEST_EXTRA_LIBRARY_DIRS + "${CUDAToolkit_ROOT_DIR}/lib64") + endif() +endif() + +if(WITH_NVIDIA) set(INFINI_RT_TEST_CONSUMER_BACKEND NVIDIA) +elseif(WITH_ILUVATAR) + set(INFINI_RT_TEST_CONSUMER_BACKEND ILUVATAR) +elseif(WITH_HYGON) + set(INFINI_RT_TEST_CONSUMER_BACKEND HYGON) + if(HYGON_CUDA_ROOT) + list(APPEND INFINI_RT_TEST_EXTRA_INCLUDE_DIRS + "${HYGON_CUDA_ROOT}/include") + list(APPEND INFINI_RT_TEST_EXTRA_LIBRARY_DIRS + "${HYGON_CUDA_ROOT}/lib64") + endif() +elseif(WITH_METAX) + set(INFINI_RT_TEST_CONSUMER_BACKEND METAX) + list(APPEND INFINI_RT_TEST_EXTRA_INCLUDE_DIRS + "${MACA_PATH}/include") + list(APPEND INFINI_RT_TEST_EXTRA_LIBRARY_DIRS + "${MACA_PATH}/lib") +elseif(WITH_MOORE) + set(INFINI_RT_TEST_CONSUMER_BACKEND MOORE) + list(APPEND INFINI_RT_TEST_EXTRA_INCLUDE_DIRS + "${MUSA_ROOT}/include") + list(APPEND INFINI_RT_TEST_EXTRA_LIBRARY_DIRS + "${MUSA_ROOT}/lib") +elseif(WITH_CAMBRICON) + set(INFINI_RT_TEST_CONSUMER_BACKEND CAMBRICON) + list(APPEND INFINI_RT_TEST_EXTRA_INCLUDE_DIRS + "${NEUWARE_HOME}/include") + list(APPEND INFINI_RT_TEST_EXTRA_LIBRARY_DIRS + "${NEUWARE_HOME}/lib" + "${NEUWARE_HOME}/lib64") +elseif(WITH_ASCEND) + # The install-consumer smoke checks Ascend public headers and linkage. + # Runtime behavior is covered by test_ascend_runtime and dispatch tests. + list(APPEND INFINI_RT_TEST_EXTRA_INCLUDE_DIRS + "${ASCEND_HOME}/include" + "${ASCEND_HOME}/include/aclnn" + "${ASCEND_HOME}/include/aclnnop") +elseif(WITH_CPU) + set(INFINI_RT_TEST_CONSUMER_BACKEND CPU) endif() if(WITH_ASCEND) @@ -46,6 +221,7 @@ if(WITH_ASCEND) endif() list(JOIN INFINI_RT_TEST_EXTRA_LIBRARY_DIRS ":" INFINI_RT_TEST_EXTRA_LIBRARY_PATHS) +list(JOIN INFINI_RT_TEST_EXTRA_INCLUDE_DIRS ":" INFINI_RT_TEST_EXTRA_INCLUDE_PATHS) add_test( NAME test_install @@ -61,6 +237,7 @@ add_test( "-DINFINI_RT_CONSUMER_BINARY=${INFINI_RT_TEST_CONSUMER_BINARY}" "-DINFINI_RT_CXX_COMPILER=${CMAKE_CXX_COMPILER}" "-DINFINI_RT_EXTRA_LIBRARY_PATHS=${INFINI_RT_TEST_EXTRA_LIBRARY_PATHS}" + "-DINFINI_RT_EXTRA_INCLUDE_PATHS=${INFINI_RT_TEST_EXTRA_INCLUDE_PATHS}" "-DINFINI_RT_CONSUMER_BACKEND=${INFINI_RT_TEST_CONSUMER_BACKEND}" -P "${CMAKE_CURRENT_SOURCE_DIR}/compile_install_consumer.cmake") set_tests_properties(test_install_consumer PROPERTIES diff --git a/tests/compile_install_consumer.cmake b/tests/compile_install_consumer.cmake index 169311e..1bb96d1 100644 --- a/tests/compile_install_consumer.cmake +++ b/tests/compile_install_consumer.cmake @@ -17,6 +17,17 @@ set(INFINI_RT_EXTRA_LINK_ARGS "") set(INFINI_RT_EXTRA_COMPILE_ARGS "") set(INFINI_RT_LD_LIBRARY_PATH "${INFINI_RT_LIBRARY_DIR}") +if(INFINI_RT_EXTRA_INCLUDE_PATHS) + string(REPLACE ":" ";" INFINI_RT_EXTRA_INCLUDE_DIRS + "${INFINI_RT_EXTRA_INCLUDE_PATHS}") + foreach(INFINI_RT_EXTRA_INCLUDE_DIR ${INFINI_RT_EXTRA_INCLUDE_DIRS}) + if(EXISTS "${INFINI_RT_EXTRA_INCLUDE_DIR}") + list(APPEND INFINI_RT_EXTRA_COMPILE_ARGS + "-I${INFINI_RT_EXTRA_INCLUDE_DIR}") + endif() + endforeach() +endif() + if(INFINI_RT_CONSUMER_BACKEND AND NOT INFINI_RT_CONSUMER_BACKEND STREQUAL "NONE") list(APPEND INFINI_RT_EXTRA_COMPILE_ARGS "-DINFINI_RT_CONSUMER_BACKEND_${INFINI_RT_CONSUMER_BACKEND}=1") diff --git a/tests/install_consumer_smoke.cc b/tests/install_consumer_smoke.cc index 97442b5..6b72130 100644 --- a/tests/install_consumer_smoke.cc +++ b/tests/install_consumer_smoke.cc @@ -1,5 +1,6 @@ #include +#include #include #include #include @@ -18,21 +19,95 @@ int main() { return 1; } -#if defined(INFINI_RT_CONSUMER_BACKEND_CPU) || \ - defined(INFINI_RT_CONSUMER_BACKEND_NVIDIA) -#if defined(INFINI_RT_CONSUMER_BACKEND_NVIDIA) - const infini::rt::Device runtime_device{infini::rt::Device::Type::kNvidia}; -#else - const infini::rt::Device runtime_device{infini::rt::Device::Type::kCpu}; +#if defined(INFINI_RT_CONSUMER_BACKEND_CPU) || \ + defined(INFINI_RT_CONSUMER_BACKEND_NVIDIA) || \ + defined(INFINI_RT_CONSUMER_BACKEND_ILUVATAR) || \ + defined(INFINI_RT_CONSUMER_BACKEND_HYGON) || \ + defined(INFINI_RT_CONSUMER_BACKEND_METAX) || \ + defined(INFINI_RT_CONSUMER_BACKEND_MOORE) || \ + defined(INFINI_RT_CONSUMER_BACKEND_CAMBRICON) || \ + defined(INFINI_RT_CONSUMER_BACKEND_ASCEND) + namespace runtime = infini::rt::runtime; +#if defined(INFINI_RT_CONSUMER_BACKEND_CPU) + constexpr auto kExpectedDeviceType = infini::rt::Device::Type::kCpu; + constexpr bool kExpectAsyncMemcpySuccess = false; +#elif defined(INFINI_RT_CONSUMER_BACKEND_NVIDIA) + constexpr auto kExpectedDeviceType = infini::rt::Device::Type::kNvidia; + constexpr bool kExpectAsyncMemcpySuccess = true; +#elif defined(INFINI_RT_CONSUMER_BACKEND_ILUVATAR) + constexpr auto kExpectedDeviceType = infini::rt::Device::Type::kIluvatar; + constexpr bool kExpectAsyncMemcpySuccess = true; +#elif defined(INFINI_RT_CONSUMER_BACKEND_HYGON) + constexpr auto kExpectedDeviceType = infini::rt::Device::Type::kHygon; + constexpr bool kExpectAsyncMemcpySuccess = true; +#elif defined(INFINI_RT_CONSUMER_BACKEND_METAX) + constexpr auto kExpectedDeviceType = infini::rt::Device::Type::kMetax; + constexpr bool kExpectAsyncMemcpySuccess = true; +#elif defined(INFINI_RT_CONSUMER_BACKEND_MOORE) + constexpr auto kExpectedDeviceType = infini::rt::Device::Type::kMoore; + constexpr bool kExpectAsyncMemcpySuccess = true; +#elif defined(INFINI_RT_CONSUMER_BACKEND_CAMBRICON) + constexpr auto kExpectedDeviceType = infini::rt::Device::Type::kCambricon; + constexpr bool kExpectAsyncMemcpySuccess = true; +#elif defined(INFINI_RT_CONSUMER_BACKEND_ASCEND) + constexpr auto kExpectedDeviceType = infini::rt::Device::Type::kAscend; + constexpr bool kExpectAsyncMemcpySuccess = true; #endif + infini::rt::set_runtime_device_type(kExpectedDeviceType); + if (infini::rt::runtime_device_type() != kExpectedDeviceType) { + return 1; + } + std::array input{1, 2, 3, 4}; + std::array output{}; void* ptr = nullptr; - infini::rt::SetDevice(runtime_device); - infini::rt::Malloc(&ptr, sizeof(std::uint32_t)); + int device_count = 0; + if (runtime::GetDeviceCount(&device_count) != runtime::kSuccess || + device_count <= 0) { + return 0; + } + if (runtime::SetDevice(0) != runtime::kSuccess) { + return 0; + } + int current_device = -1; + if (runtime::GetDevice(¤t_device) != runtime::kSuccess) { + return 1; + } + if (current_device != 0) { + return 1; + } + if (runtime::Malloc(&ptr, input.size()) != runtime::kSuccess) { + return 1; + } if (ptr == nullptr) { return 1; } - infini::rt::Free(ptr); + if (runtime::Memcpy(ptr, input.data(), input.size(), + runtime::kMemcpyHostToDevice) != runtime::kSuccess) { + return 1; + } + runtime::Stream stream{}; + const auto async_status = runtime::MemcpyAsync( + ptr, input.data(), input.size(), runtime::kMemcpyHostToDevice, stream); + if (kExpectAsyncMemcpySuccess && async_status != runtime::kSuccess) { + return 1; + } + if (!kExpectAsyncMemcpySuccess && async_status == runtime::kSuccess) { + return 1; + } + if (runtime::DeviceSynchronize() != runtime::kSuccess) { + return 1; + } + if (runtime::Memcpy(output.data(), ptr, output.size(), + runtime::kMemcpyDeviceToHost) != runtime::kSuccess) { + return 1; + } + if (output != input) { + return 1; + } + if (runtime::Free(ptr) != runtime::kSuccess) { + return 1; + } #endif return 0; diff --git a/tests/test_core.cc b/tests/test_core.cc index e8b8071..273ecef 100644 --- a/tests/test_core.cc +++ b/tests/test_core.cc @@ -1,7 +1,9 @@ #include +#include #include #include +#include #include #include "test_helper.h" @@ -12,6 +14,9 @@ using infini::rt::DataType; using infini::rt::Device; using infini::rt::TensorView; +static_assert(!std::is_constructible_v>, + "TensorView should not treat tensor containers as tensor-like."); + void TestDevice(infini::rt::test::TestContext* context) { const Device cpu{Device::Type::kCpu}; const Device nvidia{Device::Type::kNvidia, 1}; diff --git a/tests/test_cpu_runtime.cc b/tests/test_cpu_runtime.cc deleted file mode 100644 index 20edce9..0000000 --- a/tests/test_cpu_runtime.cc +++ /dev/null @@ -1,75 +0,0 @@ -#include -#include - -#include -#include -#include - -#include "test_helper.h" - -namespace { - -using CpuRuntime = infini::rt::Runtime; - -void TestMallocAndFree(infini::rt::test::TestContext* context) { - void* ptr = nullptr; - CpuRuntime::Malloc(&ptr, 16); - - context->Expect(ptr != nullptr, "CPU runtime should allocate memory."); - - CpuRuntime::Free(ptr); -} - -void TestMemcpyRoundTrip(infini::rt::test::TestContext* context) { - std::array input{0, 1, 2, 3, 4, 5, 6, 7}; - std::array output{}; - void* ptr = nullptr; - - CpuRuntime::Malloc(&ptr, input.size()); - context->Expect(ptr != nullptr, "CPU runtime should allocate copy memory."); - if (ptr == nullptr) { - return; - } - - CpuRuntime::Memcpy(ptr, input.data(), input.size(), - CpuRuntime::MemcpyHostToDevice); - CpuRuntime::Memcpy(output.data(), ptr, output.size(), - CpuRuntime::MemcpyDeviceToHost); - CpuRuntime::Free(ptr); - - context->ExpectEqual(output, input, - "CPU runtime should copy data through runtime memory."); -} - -void TestMemset(infini::rt::test::TestContext* context) { - std::array output{}; - void* ptr = nullptr; - - CpuRuntime::Malloc(&ptr, output.size()); - context->Expect(ptr != nullptr, "CPU runtime should allocate memset memory."); - if (ptr == nullptr) { - return; - } - - CpuRuntime::Memset(ptr, 0x5A, output.size()); - CpuRuntime::Memcpy(output.data(), ptr, output.size(), - CpuRuntime::MemcpyDeviceToHost); - CpuRuntime::Free(ptr); - - for (const auto value : output) { - context->ExpectEqual(value, static_cast(0x5A), - "CPU runtime should fill memory with memset."); - } -} - -} // namespace - -int main() { - infini::rt::test::TestContext context; - - TestMallocAndFree(&context); - TestMemcpyRoundTrip(&context); - TestMemset(&context); - - return context.ExitCode(); -} diff --git a/tests/test_native_runtime.cc b/tests/test_native_runtime.cc new file mode 100644 index 0000000..3ae7ea9 --- /dev/null +++ b/tests/test_native_runtime.cc @@ -0,0 +1,485 @@ +#include +#include INFINI_RT_TEST_RUNTIME_HEADER + +#include +#include +#include +#include + +#include "test_helper.h" + +namespace { + +using Runtime = infini::rt::runtime::Runtime; + +constexpr bool kExpectAsyncMemcpySuccess = + INFINI_RT_TEST_EXPECT_ASYNC_MEMCPY_SUCCESS != 0; +constexpr bool kSupportsHostMemory = INFINI_RT_TEST_SUPPORTS_HOST_MEMORY != 0; +constexpr bool kSupportsAsyncMemory = INFINI_RT_TEST_SUPPORTS_ASYNC_MEMORY != 0; +constexpr bool kSupportsMemGetInfo = INFINI_RT_TEST_SUPPORTS_MEM_GET_INFO != 0; +constexpr bool kSupportsMemsetAsync = INFINI_RT_TEST_SUPPORTS_MEMSET_ASYNC != 0; +constexpr bool kSupportsStreamWaitEvent = + INFINI_RT_TEST_SUPPORTS_STREAM_WAIT_EVENT != 0; +constexpr bool kSupportsEvent = INFINI_RT_TEST_SUPPORTS_EVENT != 0; +constexpr bool kSupportsEventElapsedTime = + INFINI_RT_TEST_SUPPORTS_EVENT_ELAPSED_TIME != 0; + +void ExpectSuccess(infini::rt::test::TestContext* context, + typename Runtime::Error status, const char* message) { + context->Expect(status == Runtime::kSuccess, message); +} + +void ExpectFailure(infini::rt::test::TestContext* context, + typename Runtime::Error status, const char* message) { + context->Expect(status != Runtime::kSuccess, message); +} + +void ExpectStatus(infini::rt::test::TestContext* context, + typename Runtime::Error status, bool expect_success, + const char* message) { + if (expect_success) { + ExpectSuccess(context, status, message); + } else { + ExpectFailure(context, status, message); + } +} + +bool SelectDevice() { + int device_count = 0; + if (Runtime::GetDeviceCount(&device_count) != Runtime::kSuccess || + device_count <= 0) { + std::cout << INFINI_RT_TEST_BACKEND_NAME + << " runtime skipped: no available device." << std::endl; + return false; + } + + if (Runtime::SetDevice(0) != Runtime::kSuccess) { + std::cout << INFINI_RT_TEST_BACKEND_NAME + << " runtime skipped: device 0 is not available." << std::endl; + return false; + } + + return true; +} + +bool CreateStream(infini::rt::test::TestContext* context, + typename Runtime::Stream* stream) { + const auto status = Runtime::StreamCreate(stream); + ExpectSuccess(context, status, + INFINI_RT_TEST_BACKEND_NAME " runtime should create a stream."); + return status == Runtime::kSuccess; +} + +void DestroyStream(infini::rt::test::TestContext* context, + typename Runtime::Stream stream) { + ExpectSuccess(context, Runtime::StreamDestroy(stream), + INFINI_RT_TEST_BACKEND_NAME + " runtime should destroy a stream."); +} + +void TestDevice(infini::rt::test::TestContext* context) { + int current_device = -1; + ExpectSuccess(context, Runtime::GetDevice(¤t_device), + INFINI_RT_TEST_BACKEND_NAME + " runtime should get the current device."); + context->ExpectEqual(current_device, 0, + INFINI_RT_TEST_BACKEND_NAME + " runtime should keep the current device."); + + int device_count = 0; + ExpectSuccess(context, Runtime::GetDeviceCount(&device_count), + INFINI_RT_TEST_BACKEND_NAME + " runtime should get the device count."); + context->Expect(device_count > 0, INFINI_RT_TEST_BACKEND_NAME + " runtime should report at least one device."); +} + +void TestMallocAndFree(infini::rt::test::TestContext* context) { + void* ptr = nullptr; + ExpectSuccess(context, Runtime::Malloc(&ptr, 16), + INFINI_RT_TEST_BACKEND_NAME " runtime should allocate memory."); + context->Expect(ptr != nullptr, INFINI_RT_TEST_BACKEND_NAME + " runtime allocation should produce a pointer."); + if (ptr != nullptr) { + ExpectSuccess(context, Runtime::Free(ptr), + INFINI_RT_TEST_BACKEND_NAME " runtime should free memory."); + } +} + +void TestHostMemory(infini::rt::test::TestContext* context) { + void* ptr = nullptr; + const auto malloc_status = Runtime::MallocHost(&ptr, 16); + ExpectStatus(context, malloc_status, kSupportsHostMemory, + INFINI_RT_TEST_BACKEND_NAME + " runtime should report expected host memory support."); + + if constexpr (kSupportsHostMemory) { + context->Expect(ptr != nullptr, INFINI_RT_TEST_BACKEND_NAME + " runtime host allocation should produce a pointer."); + if (ptr != nullptr) { + ExpectSuccess(context, Runtime::FreeHost(ptr), + INFINI_RT_TEST_BACKEND_NAME + " runtime should free host memory."); + } + } +} + +void TestAsyncMemory(infini::rt::test::TestContext* context) { + typename Runtime::Stream stream{}; + if (!CreateStream(context, &stream)) { + return; + } + + void* ptr = nullptr; + const auto malloc_status = Runtime::MallocAsync(&ptr, 16, stream); + ExpectStatus(context, malloc_status, kSupportsAsyncMemory, + INFINI_RT_TEST_BACKEND_NAME + " runtime should report expected async allocation support."); + + if constexpr (kSupportsAsyncMemory) { + context->Expect(ptr != nullptr, INFINI_RT_TEST_BACKEND_NAME + " runtime async allocation should produce a pointer."); + if (ptr != nullptr) { + ExpectSuccess(context, Runtime::FreeAsync(ptr, stream), + INFINI_RT_TEST_BACKEND_NAME + " runtime should free async memory."); + ExpectSuccess(context, Runtime::StreamSynchronize(stream), + INFINI_RT_TEST_BACKEND_NAME + " runtime should synchronize async memory operations."); + } + } else if (ptr != nullptr) { + ExpectSuccess(context, Runtime::Free(ptr), + INFINI_RT_TEST_BACKEND_NAME + " runtime should free unexpected async allocation."); + } + + DestroyStream(context, stream); +} + +void TestMemGetInfo(infini::rt::test::TestContext* context) { + std::size_t free = 0; + std::size_t total = 0; + const auto status = Runtime::MemGetInfo(&free, &total); + ExpectStatus(context, status, kSupportsMemGetInfo, + INFINI_RT_TEST_BACKEND_NAME + " runtime should report expected memory info support."); + + if constexpr (kSupportsMemGetInfo) { + context->Expect(total > 0, INFINI_RT_TEST_BACKEND_NAME + " runtime should report total memory."); + context->Expect(free <= total, INFINI_RT_TEST_BACKEND_NAME + " runtime free memory should not exceed total memory."); + } +} + +void TestMemcpyRoundTrip(infini::rt::test::TestContext* context) { + std::array input{0, 1, 2, 3, 4, 5, 6, 7}; + std::array output{}; + void* ptr = nullptr; + + ExpectSuccess(context, Runtime::Malloc(&ptr, input.size()), + INFINI_RT_TEST_BACKEND_NAME + " runtime should allocate copy memory."); + if (ptr == nullptr) { + return; + } + + ExpectSuccess(context, + Runtime::Memcpy(ptr, input.data(), input.size(), + Runtime::kMemcpyHostToDevice), + INFINI_RT_TEST_BACKEND_NAME + " runtime should copy host data to runtime memory."); + ExpectSuccess(context, + Runtime::Memcpy(output.data(), ptr, output.size(), + Runtime::kMemcpyDeviceToHost), + INFINI_RT_TEST_BACKEND_NAME + " runtime should copy runtime memory to host."); + ExpectSuccess(context, Runtime::Free(ptr), + INFINI_RT_TEST_BACKEND_NAME + " runtime should free copy memory."); + + context->ExpectEqual(output, input, + INFINI_RT_TEST_BACKEND_NAME + " runtime should preserve copied bytes."); +} + +void TestMemcpyAsync(infini::rt::test::TestContext* context) { + std::array input{8, 9, 10, 11}; + std::array output{}; + void* ptr = nullptr; + + ExpectSuccess(context, Runtime::Malloc(&ptr, input.size()), + INFINI_RT_TEST_BACKEND_NAME + " runtime should allocate async copy memory."); + if (ptr == nullptr) { + return; + } + + typename Runtime::Stream stream{}; + const auto async_status = Runtime::MemcpyAsync( + ptr, input.data(), input.size(), Runtime::kMemcpyHostToDevice, stream); + + if constexpr (kExpectAsyncMemcpySuccess) { + ExpectSuccess(context, async_status, + INFINI_RT_TEST_BACKEND_NAME + " runtime should support async host-to-device copy."); + ExpectSuccess(context, Runtime::DeviceSynchronize(), + INFINI_RT_TEST_BACKEND_NAME + " runtime should synchronize async host-to-device copy."); + ExpectSuccess(context, + Runtime::Memcpy(output.data(), ptr, output.size(), + Runtime::kMemcpyDeviceToHost), + INFINI_RT_TEST_BACKEND_NAME + " runtime should copy async data back to host."); + context->ExpectEqual(output, input, + INFINI_RT_TEST_BACKEND_NAME + " runtime should preserve async copied bytes."); + } else { + ExpectFailure(context, async_status, + INFINI_RT_TEST_BACKEND_NAME + " runtime should not report async memcpy success."); + } + + ExpectSuccess(context, Runtime::Free(ptr), + INFINI_RT_TEST_BACKEND_NAME + " runtime should free async copy memory."); +} + +void TestMemset(infini::rt::test::TestContext* context) { + std::array output{}; + void* ptr = nullptr; + + ExpectSuccess(context, Runtime::Malloc(&ptr, output.size()), + INFINI_RT_TEST_BACKEND_NAME + " runtime should allocate memset memory."); + if (ptr == nullptr) { + return; + } + + ExpectSuccess(context, Runtime::Memset(ptr, 0x5A, output.size()), + INFINI_RT_TEST_BACKEND_NAME " runtime should fill memory."); + ExpectSuccess(context, + Runtime::Memcpy(output.data(), ptr, output.size(), + Runtime::kMemcpyDeviceToHost), + INFINI_RT_TEST_BACKEND_NAME + " runtime should copy filled memory to host."); + ExpectSuccess(context, Runtime::Free(ptr), + INFINI_RT_TEST_BACKEND_NAME + " runtime should free memset memory."); + + for (const auto value : output) { + context->ExpectEqual(value, static_cast(0x5A), + INFINI_RT_TEST_BACKEND_NAME + " runtime should preserve filled bytes."); + } +} + +void TestMemsetAsync(infini::rt::test::TestContext* context) { + std::array output{}; + void* ptr = nullptr; + + ExpectSuccess(context, Runtime::Malloc(&ptr, output.size()), + INFINI_RT_TEST_BACKEND_NAME + " runtime should allocate async memset memory."); + if (ptr == nullptr) { + return; + } + + typename Runtime::Stream stream{}; + if (!CreateStream(context, &stream)) { + ExpectSuccess(context, Runtime::Free(ptr), + INFINI_RT_TEST_BACKEND_NAME + " runtime should free async memset memory."); + return; + } + + const auto memset_status = + Runtime::MemsetAsync(ptr, 0xA5, output.size(), stream); + ExpectStatus(context, memset_status, kSupportsMemsetAsync, + INFINI_RT_TEST_BACKEND_NAME + " runtime should report expected async memset support."); + + if constexpr (kSupportsMemsetAsync) { + ExpectSuccess(context, Runtime::StreamSynchronize(stream), + INFINI_RT_TEST_BACKEND_NAME + " runtime should synchronize async memset."); + ExpectSuccess(context, + Runtime::Memcpy(output.data(), ptr, output.size(), + Runtime::kMemcpyDeviceToHost), + INFINI_RT_TEST_BACKEND_NAME + " runtime should copy async filled memory to host."); + for (const auto value : output) { + context->ExpectEqual(value, static_cast(0xA5), + INFINI_RT_TEST_BACKEND_NAME + " runtime should preserve async filled bytes."); + } + } + + DestroyStream(context, stream); + ExpectSuccess(context, Runtime::Free(ptr), + INFINI_RT_TEST_BACKEND_NAME + " runtime should free async memset memory."); +} + +void TestStream(infini::rt::test::TestContext* context) { + typename Runtime::Stream stream{}; + if (!CreateStream(context, &stream)) { + return; + } + + ExpectSuccess(context, Runtime::StreamSynchronize(stream), + INFINI_RT_TEST_BACKEND_NAME + " runtime should synchronize a stream."); + DestroyStream(context, stream); +} + +void TestEvent(infini::rt::test::TestContext* context) { + typename Runtime::Event event{}; + const auto create_status = Runtime::EventCreate(&event); + ExpectStatus(context, create_status, kSupportsEvent, + INFINI_RT_TEST_BACKEND_NAME + " runtime should report expected event support."); + + if constexpr (!kSupportsEvent) { + return; + } + + context->Expect(event != typename Runtime::Event{}, + INFINI_RT_TEST_BACKEND_NAME + " runtime event creation should produce an event."); + if (event == typename Runtime::Event{}) { + return; + } + + ExpectSuccess(context, + Runtime::EventRecord(event, typename Runtime::Stream{}), + INFINI_RT_TEST_BACKEND_NAME " runtime should record an event."); + ExpectSuccess(context, Runtime::EventSynchronize(event), + INFINI_RT_TEST_BACKEND_NAME + " runtime should synchronize an event."); + ExpectSuccess(context, Runtime::EventQuery(event), + INFINI_RT_TEST_BACKEND_NAME + " runtime should query a completed event."); + ExpectSuccess(context, Runtime::EventDestroy(event), + INFINI_RT_TEST_BACKEND_NAME + " runtime should destroy an event."); + + typename Runtime::Event flagged_event{}; + ExpectSuccess(context, Runtime::EventCreateWithFlags(&flagged_event, 0), + INFINI_RT_TEST_BACKEND_NAME + " runtime should create an event with flags."); + if (flagged_event != typename Runtime::Event{}) { + ExpectSuccess(context, Runtime::EventDestroy(flagged_event), + INFINI_RT_TEST_BACKEND_NAME + " runtime should destroy a flagged event."); + } +} + +void TestStreamWaitEvent(infini::rt::test::TestContext* context) { + typename Runtime::Stream stream{}; + if (!CreateStream(context, &stream)) { + return; + } + + typename Runtime::Event event{}; + if constexpr (kSupportsEvent) { + ExpectSuccess(context, Runtime::EventCreate(&event), + INFINI_RT_TEST_BACKEND_NAME + " runtime should create a stream wait event."); + if (event != typename Runtime::Event{}) { + ExpectSuccess(context, Runtime::EventRecord(event, stream), + INFINI_RT_TEST_BACKEND_NAME + " runtime should record a stream wait event."); + } + } + + const auto wait_status = Runtime::StreamWaitEvent(stream, event, 0); + ExpectStatus(context, wait_status, kSupportsStreamWaitEvent, + INFINI_RT_TEST_BACKEND_NAME + " runtime should report expected stream wait event support."); + ExpectSuccess(context, Runtime::StreamSynchronize(stream), + INFINI_RT_TEST_BACKEND_NAME + " runtime should synchronize after stream wait event."); + + if constexpr (kSupportsEvent) { + if (event != typename Runtime::Event{}) { + ExpectSuccess(context, Runtime::EventDestroy(event), + INFINI_RT_TEST_BACKEND_NAME + " runtime should destroy a stream wait event."); + } + } + DestroyStream(context, stream); +} + +void TestEventElapsedTime(infini::rt::test::TestContext* context) { + typename Runtime::Event start{}; + typename Runtime::Event end{}; + if constexpr (kSupportsEvent) { + ExpectSuccess(context, Runtime::EventCreate(&start), + INFINI_RT_TEST_BACKEND_NAME + " runtime should create an elapsed-time start event."); + ExpectSuccess(context, Runtime::EventCreate(&end), + INFINI_RT_TEST_BACKEND_NAME + " runtime should create an elapsed-time end event."); + } + + if constexpr (kSupportsEventElapsedTime) { + ExpectSuccess(context, + Runtime::EventRecord(start, typename Runtime::Stream{}), + INFINI_RT_TEST_BACKEND_NAME + " runtime should record an elapsed-time start event."); + ExpectSuccess(context, + Runtime::EventRecord(end, typename Runtime::Stream{}), + INFINI_RT_TEST_BACKEND_NAME + " runtime should record an elapsed-time end event."); + ExpectSuccess(context, Runtime::EventSynchronize(end), + INFINI_RT_TEST_BACKEND_NAME + " runtime should synchronize an elapsed-time event."); + } + + float elapsed_ms = 0.0f; + const auto elapsed_status = + Runtime::EventElapsedTime(&elapsed_ms, start, end); + ExpectStatus(context, elapsed_status, kSupportsEventElapsedTime, + INFINI_RT_TEST_BACKEND_NAME + " runtime should report expected event elapsed-time support."); + + if constexpr (kSupportsEvent) { + if (start != typename Runtime::Event{}) { + ExpectSuccess(context, Runtime::EventDestroy(start), + INFINI_RT_TEST_BACKEND_NAME + " runtime should destroy an elapsed-time start event."); + } + if (end != typename Runtime::Event{}) { + ExpectSuccess(context, Runtime::EventDestroy(end), + INFINI_RT_TEST_BACKEND_NAME + " runtime should destroy an elapsed-time end event."); + } + } +} + +} // namespace + +int main() { + infini::rt::test::TestContext context; + + if (!SelectDevice()) { + return context.ExitCode(); + } + + TestDevice(&context); + TestMallocAndFree(&context); + TestHostMemory(&context); + TestAsyncMemory(&context); + TestMemGetInfo(&context); + TestMemcpyRoundTrip(&context); + TestMemcpyAsync(&context); + TestMemset(&context); + TestMemsetAsync(&context); + TestStream(&context); + TestEvent(&context); + TestStreamWaitEvent(&context); + TestEventElapsedTime(&context); + + return context.ExitCode(); +} diff --git a/tests/test_nvidia_graph.cc b/tests/test_nvidia_graph.cc index 47f131b..291c46b 100644 --- a/tests/test_nvidia_graph.cc +++ b/tests/test_nvidia_graph.cc @@ -8,6 +8,13 @@ namespace { +namespace runtime = infini::rt::runtime; + +void ExpectSuccess(infini::rt::test::TestContext* context, + runtime::Error status, const char* message) { + context->Expect(status == runtime::kSuccess, message); +} + void FillPattern(std::array* input, std::uint8_t salt) { for (std::size_t i = 0; i < input->size(); ++i) { (*input)[i] = static_cast(i * 13 + salt); @@ -19,8 +26,8 @@ bool CopyDeviceToHostAndValidate(infini::rt::test::TestContext* context, const std::array& expected, std::string_view message) { std::array output{}; - infini::rt::Memcpy(output.data(), device_ptr, output.size(), - infini::rt::MemcpyKind::kDeviceToHost); + runtime::Memcpy(output.data(), device_ptr, output.size(), + runtime::kMemcpyDeviceToHost); return context->ExpectEqual(output, expected, message); } @@ -28,69 +35,106 @@ bool CopyDeviceToHostAndValidate(infini::rt::test::TestContext* context, int main() { infini::rt::test::TestContext context; - const infini::rt::Device device{infini::rt::Device::Type::kNvidia, 0}; - infini::rt::SetDevice(device); + infini::rt::set_runtime_device_type(infini::rt::Device::Type::kNvidia); + ExpectSuccess(&context, runtime::SetDevice(0), + "Failed to set NVIDIA runtime device."); void* src = nullptr; void* dst = nullptr; - infini::rt::Stream stream; - infini::rt::Graph graph; - infini::rt::GraphExec graph_exec; + runtime::Stream stream = nullptr; + runtime::Graph graph = nullptr; + runtime::GraphExec graph_exec = nullptr; std::array capture_input{}; FillPattern(&capture_input, 7); - infini::rt::Malloc(&src, capture_input.size()); - infini::rt::Malloc(&dst, capture_input.size()); - infini::rt::StreamCreate(&stream); - - context.Expect(stream.device_type() == infini::rt::Device::Type::kNvidia, - "Stream should remember its NVIDIA device type."); - - infini::rt::Memcpy(src, capture_input.data(), capture_input.size(), - infini::rt::MemcpyKind::kHostToDevice); - infini::rt::Memset(dst, 0, capture_input.size()); - - infini::rt::StreamBeginCapture(stream, - infini::rt::StreamCaptureMode::kRelaxed); - infini::rt::MemcpyAsync(dst, src, capture_input.size(), - infini::rt::MemcpyKind::kDeviceToDevice, stream); - infini::rt::StreamEndCapture(stream, &graph); - - context.Expect(graph.device_type() == infini::rt::Device::Type::kNvidia, - "Graph should remember its NVIDIA device type."); - - infini::rt::GraphInstantiate(&graph_exec, graph); - context.Expect(graph_exec.device_type() == infini::rt::Device::Type::kNvidia, - "GraphExec should remember its NVIDIA device type."); + ExpectSuccess(&context, runtime::Malloc(&src, capture_input.size()), + "Failed to allocate source buffer."); + ExpectSuccess(&context, runtime::Malloc(&dst, capture_input.size()), + "Failed to allocate destination buffer."); + ExpectSuccess(&context, runtime::StreamCreate(&stream), + "Failed to create stream."); + + ExpectSuccess(&context, + runtime::Memcpy(src, capture_input.data(), capture_input.size(), + runtime::kMemcpyHostToDevice), + "Failed to initialize source buffer."); + ExpectSuccess(&context, runtime::Memset(dst, 0, capture_input.size()), + "Failed to initialize destination buffer."); + + ExpectSuccess( + &context, + runtime::StreamBeginCapture( + stream, runtime::StreamCaptureMode::kStreamCaptureModeRelaxed), + "Failed to begin stream capture."); + ExpectSuccess(&context, + runtime::MemcpyAsync(dst, src, capture_input.size(), + runtime::kMemcpyDeviceToDevice, stream), + "Failed to record device-to-device copy."); + ExpectSuccess(&context, runtime::StreamEndCapture(stream, &graph), + "Failed to end stream capture."); + + ExpectSuccess(&context, runtime::GraphInstantiate(&graph_exec, graph), + "Failed to instantiate graph."); std::array replay_input_1{}; std::array replay_input_2{}; FillPattern(&replay_input_1, 31); FillPattern(&replay_input_2, 53); - infini::rt::Memcpy(src, replay_input_1.data(), replay_input_1.size(), - infini::rt::MemcpyKind::kHostToDevice); - infini::rt::Memset(dst, 0, replay_input_1.size()); - infini::rt::GraphLaunch(graph_exec, stream); - infini::rt::StreamSynchronize(stream); + ExpectSuccess( + &context, + runtime::Memcpy(src, replay_input_1.data(), replay_input_1.size(), + runtime::kMemcpyHostToDevice), + "Failed to refresh first source buffer."); + ExpectSuccess(&context, runtime::Memset(dst, 0, replay_input_1.size()), + "Failed to clear first destination buffer."); + ExpectSuccess(&context, runtime::DeviceSynchronize(), + "Failed to synchronize first replay inputs."); + ExpectSuccess(&context, runtime::GraphLaunch(graph_exec, stream), + "Failed to launch first graph replay."); + ExpectSuccess(&context, runtime::StreamSynchronize(stream), + "Failed to synchronize first graph replay."); CopyDeviceToHostAndValidate(&context, dst, replay_input_1, "First graph replay should copy D2D data."); - infini::rt::Memcpy(src, replay_input_2.data(), replay_input_2.size(), - infini::rt::MemcpyKind::kHostToDevice); - infini::rt::Memset(dst, 0, replay_input_2.size()); - infini::rt::GraphLaunch(graph_exec, stream); - infini::rt::StreamSynchronize(stream); + ExpectSuccess( + &context, + runtime::Memcpy(src, replay_input_2.data(), replay_input_2.size(), + runtime::kMemcpyHostToDevice), + "Failed to refresh second source buffer."); + ExpectSuccess(&context, runtime::Memset(dst, 0, replay_input_2.size()), + "Failed to clear second destination buffer."); + ExpectSuccess(&context, runtime::DeviceSynchronize(), + "Failed to synchronize second replay inputs."); + ExpectSuccess(&context, runtime::GraphLaunch(graph_exec, stream), + "Failed to launch second graph replay."); + ExpectSuccess(&context, runtime::StreamSynchronize(stream), + "Failed to synchronize second graph replay."); CopyDeviceToHostAndValidate(&context, dst, replay_input_2, "Second graph replay should copy D2D data."); - infini::rt::GraphExecDestroy(graph_exec); - infini::rt::GraphDestroy(graph); - infini::rt::StreamDestroy(stream); - infini::rt::Free(dst); - infini::rt::Free(src); + if (graph_exec != nullptr) { + ExpectSuccess(&context, runtime::GraphExecDestroy(graph_exec), + "Failed to destroy graph exec."); + } + if (graph != nullptr) { + ExpectSuccess(&context, runtime::GraphDestroy(graph), + "Failed to destroy graph."); + } + if (stream != nullptr) { + ExpectSuccess(&context, runtime::StreamDestroy(stream), + "Failed to destroy stream."); + } + if (dst != nullptr) { + ExpectSuccess(&context, runtime::Free(dst), + "Failed to free destination buffer."); + } + if (src != nullptr) { + ExpectSuccess(&context, runtime::Free(src), + "Failed to free source buffer."); + } return context.ExitCode(); } diff --git a/tests/test_nvidia_runtime.cc b/tests/test_nvidia_runtime.cc deleted file mode 100644 index 2d3dac2..0000000 --- a/tests/test_nvidia_runtime.cc +++ /dev/null @@ -1,92 +0,0 @@ -#include -#include - -#include -#include -#include - -#include "test_helper.h" - -namespace { - -using NvidiaRuntime = infini::rt::Runtime; - -void ExpectCudaSuccess(infini::rt::test::TestContext* context, - cudaError_t status, const char* message) { - context->Expect(status == cudaSuccess, message); -} - -void TestMallocAndFree(infini::rt::test::TestContext* context) { - void* ptr = nullptr; - ExpectCudaSuccess(context, NvidiaRuntime::Malloc(&ptr, 16), - "NVIDIA runtime should allocate device memory."); - context->Expect(ptr != nullptr, - "NVIDIA runtime allocation should produce a pointer."); - if (ptr != nullptr) { - ExpectCudaSuccess(context, NvidiaRuntime::Free(ptr), - "NVIDIA runtime should free device memory."); - } -} - -void TestMemcpyRoundTrip(infini::rt::test::TestContext* context) { - std::array input{0, 1, 2, 3, 4, 5, 6, 7}; - std::array output{}; - void* ptr = nullptr; - - ExpectCudaSuccess(context, NvidiaRuntime::Malloc(&ptr, input.size()), - "NVIDIA runtime should allocate copy memory."); - if (ptr == nullptr) { - return; - } - - ExpectCudaSuccess(context, - NvidiaRuntime::Memcpy(ptr, input.data(), input.size(), - NvidiaRuntime::MemcpyHostToDevice), - "NVIDIA runtime should copy host data to device memory."); - ExpectCudaSuccess(context, - NvidiaRuntime::Memcpy(output.data(), ptr, output.size(), - NvidiaRuntime::MemcpyDeviceToHost), - "NVIDIA runtime should copy device data to host memory."); - ExpectCudaSuccess(context, NvidiaRuntime::Free(ptr), - "NVIDIA runtime should free copy memory."); - - context->ExpectEqual( - output, input, "NVIDIA runtime should copy data through device memory."); -} - -void TestMemset(infini::rt::test::TestContext* context) { - std::array output{}; - void* ptr = nullptr; - - ExpectCudaSuccess(context, NvidiaRuntime::Malloc(&ptr, output.size()), - "NVIDIA runtime should allocate memset memory."); - if (ptr == nullptr) { - return; - } - - ExpectCudaSuccess(context, NvidiaRuntime::Memset(ptr, 0x5A, output.size()), - "NVIDIA runtime should memset device memory."); - ExpectCudaSuccess(context, - NvidiaRuntime::Memcpy(output.data(), ptr, output.size(), - NvidiaRuntime::MemcpyDeviceToHost), - "NVIDIA runtime should copy memset data to host memory."); - ExpectCudaSuccess(context, NvidiaRuntime::Free(ptr), - "NVIDIA runtime should free memset memory."); - - for (const auto value : output) { - context->ExpectEqual(value, static_cast(0x5A), - "NVIDIA runtime should fill memory with memset."); - } -} - -} // namespace - -int main() { - infini::rt::test::TestContext context; - - TestMallocAndFree(&context); - TestMemcpyRoundTrip(&context); - TestMemset(&context); - - return context.ExitCode(); -} diff --git a/tests/test_runtime_dispatch.cc b/tests/test_runtime_dispatch.cc index 9be92a7..be49b3b 100644 --- a/tests/test_runtime_dispatch.cc +++ b/tests/test_runtime_dispatch.cc @@ -3,63 +3,442 @@ #include #include #include +#include +#include +#include #include "test_helper.h" namespace { -infini::rt::Device RuntimeTestDevice() { -#if defined(WITH_NVIDIA) - return infini::rt::Device{infini::rt::Device::Type::kNvidia}; -#else - return infini::rt::Device{infini::rt::Device::Type::kCpu}; -#endif +namespace runtime = infini::rt::runtime; + +struct RuntimeApiExpectations { + bool async_memcpy; + bool host_memory; + bool async_memory; + bool mem_get_info; + bool memset_async; + bool stream_wait_event; + bool event; + bool event_elapsed_time; +}; + +void ExpectSuccess(infini::rt::test::TestContext* context, + runtime::Error status, std::string_view message) { + context->Expect(status == runtime::kSuccess, message); } -} // namespace +void ExpectFailure(infini::rt::test::TestContext* context, + runtime::Error status, std::string_view message) { + context->Expect(status != runtime::kSuccess, message); +} -int main() { - infini::rt::test::TestContext context; - const infini::rt::Device device = RuntimeTestDevice(); +void ExpectStatus(infini::rt::test::TestContext* context, runtime::Error status, + bool expect_success, std::string_view message) { + if (expect_success) { + ExpectSuccess(context, status, message); + } else { + ExpectFailure(context, status, message); + } +} + +std::string Message(std::string_view backend_name, std::string_view message) { + return std::string{backend_name} + " dispatch should " + + std::string{message} + "."; +} + +bool SelectDevice(infini::rt::test::TestContext* context, + infini::rt::Device::Type device_type, + const char* backend_name) { + infini::rt::set_runtime_device_type(device_type); + if (!context->Expect(infini::rt::runtime_device_type() == device_type, + Message(backend_name, "report the selected backend"))) { + return false; + } + + int device_count = 0; + if (runtime::GetDeviceCount(&device_count) != runtime::kSuccess || + device_count <= 0) { + std::cout << backend_name << " dispatch skipped: no available device." + << std::endl; + return false; + } + + if (runtime::SetDevice(0) != runtime::kSuccess) { + std::cout << backend_name << " dispatch skipped: device 0 is not available." + << std::endl; + return false; + } + + return true; +} + +bool CreateStream(infini::rt::test::TestContext* context, + const char* backend_name, runtime::Stream* stream) { + const auto status = runtime::StreamCreate(stream); + ExpectSuccess(context, status, Message(backend_name, "create a stream")); + return status == runtime::kSuccess; +} + +void DestroyStream(infini::rt::test::TestContext* context, + const char* backend_name, runtime::Stream stream) { + ExpectSuccess(context, runtime::StreamDestroy(stream), + Message(backend_name, "destroy a stream")); +} + +void TestCoreDispatch(infini::rt::test::TestContext* context, + const char* backend_name, + const RuntimeApiExpectations& expectations) { std::array input{1, 2, 3, 4}; std::array output{}; void* ptr = nullptr; - infini::rt::SetDevice(device); + int current_device = -1; + ExpectSuccess(context, runtime::GetDevice(¤t_device), + Message(backend_name, "get the current device")); + context->ExpectEqual(current_device, 0, + Message(backend_name, "keep the current device")); - infini::rt::Device current_device; - infini::rt::GetDevice(¤t_device); - context.ExpectEqual(current_device, device, - "Runtime dispatch should keep the current device."); + ExpectSuccess(context, runtime::Malloc(&ptr, input.size()), + Message(backend_name, "allocate memory")); + if (ptr == nullptr) { + return; + } - int device_count = 0; - infini::rt::GetDeviceCount(&device_count, device.type()); - context.Expect(device_count > 0, - "Runtime dispatch should report at least one device."); + ExpectSuccess(context, + runtime::Memcpy(ptr, input.data(), input.size(), + runtime::kMemcpyHostToDevice), + Message(backend_name, "copy host data to runtime memory")); - infini::rt::Malloc(&ptr, input.size()); - context.Expect(ptr != nullptr, "Runtime dispatch should allocate memory."); - if (ptr == nullptr) { - return context.ExitCode(); + runtime::Stream stream{}; + const auto async_status = runtime::MemcpyAsync( + ptr, input.data(), input.size(), runtime::kMemcpyHostToDevice, stream); + ExpectStatus(context, async_status, expectations.async_memcpy, + Message(backend_name, "report expected async memcpy support")); + if (expectations.async_memcpy) { + ExpectSuccess(context, runtime::DeviceSynchronize(), + Message(backend_name, "synchronize async copy")); } - infini::rt::Memcpy(ptr, input.data(), input.size(), - infini::rt::MemcpyKind::kHostToDevice); - infini::rt::Memcpy(output.data(), ptr, output.size(), - infini::rt::MemcpyKind::kDeviceToHost); - context.ExpectEqual(output, input, - "Runtime dispatch should copy data through memory."); + ExpectSuccess(context, + runtime::Memcpy(output.data(), ptr, output.size(), + runtime::kMemcpyDeviceToHost), + Message(backend_name, "copy runtime memory to host")); - infini::rt::Memset(ptr, 0x5A, output.size()); - infini::rt::Memcpy(output.data(), ptr, output.size(), - infini::rt::MemcpyKind::kDeviceToHost); + context->ExpectEqual(output, input, + Message(backend_name, "preserve copied bytes")); + + ExpectSuccess(context, runtime::Memset(ptr, 0x5A, output.size()), + Message(backend_name, "fill runtime memory")); + ExpectSuccess(context, runtime::DeviceSynchronize(), + Message(backend_name, "synchronize filled memory")); + ExpectSuccess(context, + runtime::Memcpy(output.data(), ptr, output.size(), + runtime::kMemcpyDeviceToHost), + Message(backend_name, "copy filled memory to host")); for (const auto value : output) { - context.ExpectEqual(value, static_cast(0x5A), - "Runtime dispatch should fill memory."); + context->ExpectEqual(value, static_cast(0x5A), + Message(backend_name, "preserve filled bytes")); + } + + ExpectSuccess(context, runtime::Free(ptr), + Message(backend_name, "free memory")); +} + +void TestHostMemory(infini::rt::test::TestContext* context, + const char* backend_name, + const RuntimeApiExpectations& expectations) { + void* ptr = nullptr; + const auto malloc_status = runtime::MallocHost(&ptr, 16); + ExpectStatus(context, malloc_status, expectations.host_memory, + Message(backend_name, "report expected host memory support")); + + if (expectations.host_memory) { + context->Expect( + ptr != nullptr, + Message(backend_name, "produce a pointer from host allocation")); + if (ptr != nullptr) { + ExpectSuccess(context, runtime::FreeHost(ptr), + Message(backend_name, "free host memory")); + } + } +} + +void TestAsyncMemory(infini::rt::test::TestContext* context, + const char* backend_name, + const RuntimeApiExpectations& expectations) { + runtime::Stream stream{}; + if (!CreateStream(context, backend_name, &stream)) { + return; + } + + void* ptr = nullptr; + const auto malloc_status = runtime::MallocAsync(&ptr, 16, stream); + ExpectStatus( + context, malloc_status, expectations.async_memory, + Message(backend_name, "report expected async allocation support")); + + if (expectations.async_memory) { + context->Expect( + ptr != nullptr, + Message(backend_name, "produce a pointer from async allocation")); + if (ptr != nullptr) { + ExpectSuccess(context, runtime::FreeAsync(ptr, stream), + Message(backend_name, "free async memory")); + ExpectSuccess( + context, runtime::StreamSynchronize(stream), + Message(backend_name, "synchronize async memory operations")); + } + } else if (ptr != nullptr) { + ExpectSuccess(context, runtime::Free(ptr), + Message(backend_name, "free unexpected async allocation")); + } + + DestroyStream(context, backend_name, stream); +} + +void TestMemGetInfo(infini::rt::test::TestContext* context, + const char* backend_name, + const RuntimeApiExpectations& expectations) { + std::size_t free = 0; + std::size_t total = 0; + const auto status = runtime::MemGetInfo(&free, &total); + ExpectStatus(context, status, expectations.mem_get_info, + Message(backend_name, "report expected memory info support")); + + if (expectations.mem_get_info) { + context->Expect(total > 0, Message(backend_name, "report total memory")); + context->Expect( + free <= total, + Message(backend_name, "report free memory not exceeding total memory")); + } +} + +void TestMemsetAsync(infini::rt::test::TestContext* context, + const char* backend_name, + const RuntimeApiExpectations& expectations) { + std::array output{}; + void* ptr = nullptr; + + ExpectSuccess(context, runtime::Malloc(&ptr, output.size()), + Message(backend_name, "allocate async memset memory")); + if (ptr == nullptr) { + return; + } + + runtime::Stream stream{}; + if (!CreateStream(context, backend_name, &stream)) { + ExpectSuccess(context, runtime::Free(ptr), + Message(backend_name, "free async memset memory")); + return; + } + + const auto memset_status = + runtime::MemsetAsync(ptr, 0xA5, output.size(), stream); + ExpectStatus(context, memset_status, expectations.memset_async, + Message(backend_name, "report expected async memset support")); + + if (expectations.memset_async) { + ExpectSuccess(context, runtime::StreamSynchronize(stream), + Message(backend_name, "synchronize async memset")); + ExpectSuccess(context, + runtime::Memcpy(output.data(), ptr, output.size(), + runtime::kMemcpyDeviceToHost), + Message(backend_name, "copy async filled memory to host")); + for (const auto value : output) { + context->ExpectEqual( + value, static_cast(0xA5), + Message(backend_name, "preserve async filled bytes")); + } + } + + DestroyStream(context, backend_name, stream); + ExpectSuccess(context, runtime::Free(ptr), + Message(backend_name, "free async memset memory")); +} + +void TestStream(infini::rt::test::TestContext* context, + const char* backend_name) { + runtime::Stream stream{}; + if (!CreateStream(context, backend_name, &stream)) { + return; + } + + ExpectSuccess(context, runtime::StreamSynchronize(stream), + Message(backend_name, "synchronize a stream")); + DestroyStream(context, backend_name, stream); +} + +void TestEvent(infini::rt::test::TestContext* context, const char* backend_name, + const RuntimeApiExpectations& expectations) { + runtime::Event event{}; + const auto create_status = runtime::EventCreate(&event); + ExpectStatus(context, create_status, expectations.event, + Message(backend_name, "report expected event support")); + + if (!expectations.event) { + return; + } + + context->Expect( + event != nullptr, + Message(backend_name, "produce an event from event creation")); + if (event == nullptr) { + return; + } + + ExpectSuccess(context, runtime::EventRecord(event, runtime::Stream{}), + Message(backend_name, "record an event")); + ExpectSuccess(context, runtime::EventSynchronize(event), + Message(backend_name, "synchronize an event")); + ExpectSuccess(context, runtime::EventQuery(event), + Message(backend_name, "query a completed event")); + ExpectSuccess(context, runtime::EventDestroy(event), + Message(backend_name, "destroy an event")); + + runtime::Event flagged_event{}; + ExpectSuccess(context, runtime::EventCreateWithFlags(&flagged_event, 0), + Message(backend_name, "create an event with flags")); + if (flagged_event != nullptr) { + ExpectSuccess(context, runtime::EventDestroy(flagged_event), + Message(backend_name, "destroy a flagged event")); + } +} + +void TestStreamWaitEvent(infini::rt::test::TestContext* context, + const char* backend_name, + const RuntimeApiExpectations& expectations) { + runtime::Stream stream{}; + if (!CreateStream(context, backend_name, &stream)) { + return; + } + + runtime::Event event{}; + if (expectations.event) { + ExpectSuccess(context, runtime::EventCreate(&event), + Message(backend_name, "create a stream wait event")); + if (event != nullptr) { + ExpectSuccess(context, runtime::EventRecord(event, stream), + Message(backend_name, "record a stream wait event")); + } + } + + const auto wait_status = runtime::StreamWaitEvent(stream, event, 0); + ExpectStatus( + context, wait_status, expectations.stream_wait_event, + Message(backend_name, "report expected stream wait event support")); + ExpectSuccess(context, runtime::StreamSynchronize(stream), + Message(backend_name, "synchronize after stream wait event")); + + if (event != nullptr) { + ExpectSuccess(context, runtime::EventDestroy(event), + Message(backend_name, "destroy a stream wait event")); + } + DestroyStream(context, backend_name, stream); +} + +void TestEventElapsedTime(infini::rt::test::TestContext* context, + const char* backend_name, + const RuntimeApiExpectations& expectations) { + runtime::Event start{}; + runtime::Event end{}; + if (expectations.event) { + ExpectSuccess(context, runtime::EventCreate(&start), + Message(backend_name, "create an elapsed-time start event")); + ExpectSuccess(context, runtime::EventCreate(&end), + Message(backend_name, "create an elapsed-time end event")); + } + + if (expectations.event_elapsed_time) { + ExpectSuccess(context, runtime::EventRecord(start, runtime::Stream{}), + Message(backend_name, "record an elapsed-time start event")); + ExpectSuccess(context, runtime::EventRecord(end, runtime::Stream{}), + Message(backend_name, "record an elapsed-time end event")); + ExpectSuccess(context, runtime::EventSynchronize(end), + Message(backend_name, "synchronize an elapsed-time event")); } - infini::rt::DeviceSynchronize(); - infini::rt::Free(ptr); + float elapsed_ms = 0.0f; + const auto elapsed_status = + runtime::EventElapsedTime(&elapsed_ms, start, end); + ExpectStatus( + context, elapsed_status, expectations.event_elapsed_time, + Message(backend_name, "report expected event elapsed-time support")); + + if (start != nullptr) { + ExpectSuccess(context, runtime::EventDestroy(start), + Message(backend_name, "destroy an elapsed-time start event")); + } + if (end != nullptr) { + ExpectSuccess(context, runtime::EventDestroy(end), + Message(backend_name, "destroy an elapsed-time end event")); + } +} + +void TestDispatch(infini::rt::test::TestContext* context, + infini::rt::Device::Type device_type, + const char* backend_name, + const RuntimeApiExpectations& expectations) { + if (!SelectDevice(context, device_type, backend_name)) { + return; + } + + TestCoreDispatch(context, backend_name, expectations); + TestHostMemory(context, backend_name, expectations); + TestAsyncMemory(context, backend_name, expectations); + TestMemGetInfo(context, backend_name, expectations); + TestMemsetAsync(context, backend_name, expectations); + TestStream(context, backend_name); + TestEvent(context, backend_name, expectations); + TestStreamWaitEvent(context, backend_name, expectations); + TestEventElapsedTime(context, backend_name, expectations); +} + +} // namespace + +int main() { + infini::rt::test::TestContext context; + +#if defined(INFINI_RT_TEST_WITH_CPU) + TestDispatch(&context, infini::rt::Device::Type::kCpu, "CPU", + {false, true, false, true, false, true, true, true}); +#endif + +#if defined(INFINI_RT_TEST_WITH_NVIDIA) + TestDispatch(&context, infini::rt::Device::Type::kNvidia, "NVIDIA", + {true, true, true, true, true, true, true, true}); +#endif + +#if defined(INFINI_RT_TEST_WITH_ILUVATAR) + TestDispatch(&context, infini::rt::Device::Type::kIluvatar, "ILUVATAR", + {true, true, true, true, true, true, true, true}); +#endif + +#if defined(INFINI_RT_TEST_WITH_HYGON) + TestDispatch(&context, infini::rt::Device::Type::kHygon, "HYGON", + {true, true, true, true, true, true, true, true}); +#endif + +#if defined(INFINI_RT_TEST_WITH_METAX) + TestDispatch(&context, infini::rt::Device::Type::kMetax, "METAX", + {true, true, true, true, true, true, true, true}); +#endif + +#if defined(INFINI_RT_TEST_WITH_MOORE) + TestDispatch(&context, infini::rt::Device::Type::kMoore, "MOORE", + {true, true, false, true, true, true, true, true}); +#endif + +#if defined(INFINI_RT_TEST_WITH_CAMBRICON) + TestDispatch(&context, infini::rt::Device::Type::kCambricon, "CAMBRICON", + {true, false, false, false, false, false, false, false}); +#endif + +#if defined(INFINI_RT_TEST_WITH_ASCEND) + TestDispatch(&context, infini::rt::Device::Type::kAscend, "ASCEND", + {true, false, false, false, false, false, false, false}); +#endif return context.ExitCode(); }