diff --git a/README.md b/README.md index b93b7be..5d41768 100644 --- a/README.md +++ b/README.md @@ -66,25 +66,42 @@ cmake --install build #include int main() { - infini::rt::SetDevice({infini::rt::Device::Type::kCpu, 0}); + infini::rt::runtime::SetDevice(0); constexpr std::size_t size = 1024; void* ptr = nullptr; - infini::rt::Malloc(&ptr, size); - infini::rt::Memset(ptr, 0, size); - infini::rt::Free(ptr); + infini::rt::runtime::Malloc(&ptr, size); + infini::rt::runtime::Memset(ptr, 0, size); + infini::rt::runtime::Free(ptr); return 0; } ``` -For NVIDIA: +The CUDA Runtime API-aligned layer lives under `infini::rt::runtime`. The +top-level `infini::rt::set_runtime_device_type` and +`infini::rt::runtime_device_type` APIs select which enabled backend receives +those runtime calls. A GPU backend is selected initially when one is enabled; +otherwise CPU is selected. ```cpp -infini::rt::SetDevice({infini::rt::Device::Type::kNvidia, 0}); +constexpr std::size_t size = 1024; +void* ptr = nullptr; + +infini::rt::set_runtime_device_type(infini::rt::Device::Type::kCpu); +infini::rt::runtime::Malloc(&ptr, size); +infini::rt::runtime::Free(ptr); + +infini::rt::set_runtime_device_type(infini::rt::Device::Type::kNvidia); +infini::rt::runtime::Malloc(&ptr, size); +infini::rt::runtime::Free(ptr); ``` +Use `infini::rt::runtime::Runtime` when CPU +runtime calls are needed explicitly in a build that also enables an accelerator +backend. + ## Using Installed InfiniRT From Another Project Downstream projects should consume the installed headers and libraries instead of depending on the InfiniRT source tree. diff --git a/scripts/generate_public_headers.py b/scripts/generate_public_headers.py index 93ceee3..e929682 100644 --- a/scripts/generate_public_headers.py +++ b/scripts/generate_public_headers.py @@ -53,21 +53,23 @@ "cpu": "Device::Type::kCpu", "nvidia": "Device::Type::kNvidia", "iluvatar": "Device::Type::kIluvatar", + "hygon": "Device::Type::kHygon", "metax": "Device::Type::kMetax", "moore": "Device::Type::kMoore", "cambricon": "Device::Type::kCambricon", "ascend": "Device::Type::kAscend", } -_RUNTIME_HEADERS = { - "cpu": "native/cpu/runtime_.h", - "nvidia": "native/cuda/nvidia/runtime_.h", - "iluvatar": "native/cuda/iluvatar/runtime_.h", - "metax": "native/cuda/metax/runtime_.h", - "moore": "native/cuda/moore/runtime_.h", - "cambricon": "native/cambricon/runtime_.h", - "ascend": "native/ascend/runtime_.h", -} +_DEFAULT_DEVICE_PRIORITY = ( + "nvidia", + "iluvatar", + "hygon", + "metax", + "moore", + "cambricon", + "ascend", + "cpu", +) def _guard(path): @@ -155,7 +157,11 @@ def _write_detail_headers(include_root, source_root, devices): def _write_generated_header(include_root, devices): + default_device = _default_device(devices) + default_device_type = _DEVICE_TYPES[default_device] includes = [ + "#include ", + "#include ", f"#include {_detail_include('data_type.h')}", f"#include {_detail_include('device.h')}", f"#include {_detail_include('hash.h')}", @@ -166,6 +172,13 @@ def _write_generated_header(include_root, devices): for device in devices: includes.append(f"#include ") + for device in devices: + includes.append(f"#include ") + + runtime_declarations = "\n\n".join( + f"{function.signature()};" for function in _PUBLIC_RUNTIME_FUNCTIONS + ) + path = include_root / "infini" / "rt" / "generated.h" path.parent.mkdir(parents=True, exist_ok=True) path.write_text( @@ -174,6 +187,47 @@ def _write_generated_header(include_root, devices): {chr(10).join(includes)} +namespace infini::rt {{ + +void set_runtime_device_type(Device::Type device_type); + +Device::Type runtime_device_type(); + +namespace runtime {{ +namespace generated_detail {{ + +using DefaultErrorRuntime = Runtime<{default_device_type}>; + +inline constexpr Device::Type kDefaultDeviceType = {default_device_type}; + +}} // namespace generated_detail + +using Error = typename generated_detail::DefaultErrorRuntime::Error; + +using Stream = typename generated_detail::DefaultErrorRuntime::Stream; + +using MemcpyKind = std::remove_cv_t< + decltype(generated_detail::DefaultErrorRuntime::kMemcpyHostToHost)>; + +inline constexpr Error kSuccess = generated_detail::DefaultErrorRuntime::kSuccess; + +inline constexpr MemcpyKind kMemcpyHostToHost = + generated_detail::DefaultErrorRuntime::kMemcpyHostToHost; + +inline constexpr MemcpyKind kMemcpyHostToDevice = + generated_detail::DefaultErrorRuntime::kMemcpyHostToDevice; + +inline constexpr MemcpyKind kMemcpyDeviceToHost = + generated_detail::DefaultErrorRuntime::kMemcpyDeviceToHost; + +inline constexpr MemcpyKind kMemcpyDeviceToDevice = + generated_detail::DefaultErrorRuntime::kMemcpyDeviceToDevice; + +{runtime_declarations} + +}} // namespace runtime +}} // namespace infini::rt + #endif """ ) @@ -198,202 +252,189 @@ def params_decl(self): return ", ".join(f"{param.type} {param.name}" for param in self.params) -def _parse_param(param): - param_type, param_name = param.strip().rsplit(" ", 1) - - return _Param(param_type, param_name) - - -def _parse_runtime_functions(runtime_header): - text = pathlib.Path(runtime_header).read_text() - return tuple( - _Function( - return_type, - name, - tuple(_parse_param(param) for param in params.split(", ") if param), - ) - for return_type, name, params in re.findall( - r"^(void) ([A-Z]\w*)\(([^()]*)\);$", text, re.MULTILINE - ) - ) - - -def _abort_statement(message): - return f""" assert(false && "{message}"); - std::abort();""" - - -def _dispatch_cases(devices, statements): - return "\n".join( - f""" case {_DEVICE_TYPES[device]}: {{ -{statements.replace("__DEVICE_TYPE__", _DEVICE_TYPES[device])} - return; - }}""" - for device in devices - ) +_PUBLIC_RUNTIME_FUNCTIONS = ( + _Function("Error", "SetDevice", (_Param("int", "device"),)), + _Function("Error", "GetDevice", (_Param("int*", "device"),)), + _Function("Error", "GetDeviceCount", (_Param("int*", "count"),)), + _Function("Error", "DeviceSynchronize", ()), + _Function( + "Error", + "Malloc", + (_Param("void**", "ptr"), _Param("std::size_t", "size")), + ), + _Function("Error", "Free", (_Param("void*", "ptr"),)), + _Function( + "Error", + "Memset", + ( + _Param("void*", "ptr"), + _Param("int", "value"), + _Param("std::size_t", "count"), + ), + ), + _Function( + "Error", + "Memcpy", + ( + _Param("void*", "dst"), + _Param("const void*", "src"), + _Param("std::size_t", "count"), + _Param("MemcpyKind", "kind"), + ), + ), + _Function( + "Error", + "MemcpyAsync", + ( + _Param("void*", "dst"), + _Param("const void*", "src"), + _Param("std::size_t", "count"), + _Param("MemcpyKind", "kind"), + _Param("Stream", "stream"), + ), + ), +) -def _selector(function): - for param in function.params: - if param.type == "Device": - return f"{param.name}.type()" - if param.type == "Device::Type": - return param.name +def _default_device(devices): + for device in _DEFAULT_DEVICE_PRIORITY: + if device in devices: + return device - return "current_device.type()" + raise ValueError("at least one device is required") -def _runtime_arg(param): - if param.type == "Device": - return f"{param.name}.index()" - if param.type == "Device::Type": - return None +def _runtime_arg(param, device): + device_type = _DEVICE_TYPES[device] if param.type == "MemcpyKind": - return f"RuntimeMemcpyKind<__DEVICE_TYPE__>({param.name})" + return f"RuntimeMemcpyKind<{device_type}>({param.name})" + if param.type == "Stream": + return ( + f"reinterpret_cast::Stream>({param.name})" + ) return param.name -def _runtime_args(function): - args = (_runtime_arg(param) for param in function.params) +def _runtime_args(function, device): + args = (_runtime_arg(param, device) for param in function.params) return ", ".join(arg for arg in args if arg is not None) -def _preconditions(function): - required_pointer_names = { - "GetDevice": {"device"}, - "GetDeviceCount": {"count"}, - } - checks = [] - for param in function.params: - if param.type.endswith("**") or param.name in required_pointer_names.get( - function.name, set() - ): - checks.append(f" assert({param.name} != nullptr);") - - return "\n".join(checks) - - -def _post_dispatch(function): - if function.name == "SetDevice": - return "\n current_device = device;" - - return "" - - -def _runtime_call(function): - args = _runtime_args(function) +def _runtime_call(function, device): + device_type = _DEVICE_TYPES[device] + args = _runtime_args(function, device) if args: - return f"Runtime<__DEVICE_TYPE__>::{function.name}({args})" + return f"Runtime<{device_type}>::{function.name}({args})" - return f"Runtime<__DEVICE_TYPE__>::{function.name}()" + return f"Runtime<{device_type}>::{function.name}()" -def _write_get_device(function, devices): - device_param = function.params[0].name - cases = _dispatch_cases( - devices, - f""" int index = current_device.index(); - CheckCall([&] {{ return Runtime<__DEVICE_TYPE__>::GetDevice(&index); }}); - current_device = Device{{current_device.type(), index}}; - *{device_param} = current_device;""", +def _dispatch_cases(devices, function): + return "\n".join( + f""" case {_DEVICE_TYPES[device]}: + return CheckCall([&] {{ return {_runtime_call(function, device)}; }});""" + for device in devices ) - return f"""void GetDevice(Device* {device_param}) {{ - assert({device_param} != nullptr); - - switch (current_device.type()) {{ -{cases} - default: -{_abort_statement("runtime device is not enabled")} - }} -}} -""" - - -def _write_dispatch_function(function, devices): - if function.name == "GetDevice": - return _write_get_device(function, devices) - - cases = _dispatch_cases( - devices, - f""" CheckCall([&] {{ return {_runtime_call(function)}; }});{_post_dispatch(function)}""", - ) - preconditions = _preconditions(function) - if preconditions: - preconditions = f"{preconditions}\n\n" +def _write_runtime_dispatch_function(function, devices): return f"""{function.signature()} {{ -{preconditions} switch ({_selector(function)}) {{ -{cases} - default: -{_abort_statement("runtime device is not enabled")} + switch (infini::rt::runtime_device_type()) {{ +{_dispatch_cases(devices, function)} }} + + assert(false && "unsupported runtime device type"); + return InvalidValueError(); }} """ -def _write_runtime_dispatch(source_path, runtime_header, devices): - first_device_type = _DEVICE_TYPES[devices[0]] - includes = ['#include "runtime.h"'] - includes.extend(f'#include "{_RUNTIME_HEADERS[device]}"' for device in devices) - functions = _parse_runtime_functions(runtime_header) +def _write_runtime_dispatch(source_path, devices): + functions = _PUBLIC_RUNTIME_FUNCTIONS dispatch_functions = "\n".join( - _write_dispatch_function(function, devices) for function in functions + _write_runtime_dispatch_function(function, devices=devices) + for function in functions + ) + set_device_type_cases = "\n".join( + f""" case {_DEVICE_TYPES[device]}: + runtime_device_type_ = device_type; + return;""" + for device in devices ) source_path.parent.mkdir(parents=True, exist_ok=True) source_path.write_text( f"""#include -#include +#include #include #include -{chr(10).join(includes)} +#include namespace infini::rt {{ namespace {{ -thread_local Device current_device{{{first_device_type}, 0}}; +thread_local Device::Type runtime_device_type_ = + runtime::generated_detail::kDefaultDeviceType; + +}} // namespace + +void set_runtime_device_type(Device::Type device_type) {{ + switch (device_type) {{ +{set_device_type_cases} + }} + + assert(false && "unsupported runtime device type"); +}} + +Device::Type runtime_device_type() {{ + return runtime_device_type_; +}} + +}} // namespace infini::rt + +namespace infini::rt::runtime {{ +namespace {{ template -void CheckCall(Func&& func) {{ +Error CheckCall(Func&& func) {{ using ReturnType = decltype(std::forward(func)()); if constexpr (std::is_void_v) {{ std::forward(func)(); + return kSuccess; }} else {{ - ReturnType status = std::forward(func)(); - if (status != ReturnType{{}}) {{ - assert(false && "runtime call failed"); - std::abort(); - }} + return static_cast(std::forward(func)()); }} }} -template +Error InvalidValueError() {{ return static_cast(1); }} + +template auto RuntimeMemcpyKind(MemcpyKind kind) {{ + using DeviceRuntime = Runtime; + switch (kind) {{ - case MemcpyKind::kHostToHost: - return Runtime::MemcpyHostToHost; - case MemcpyKind::kHostToDevice: - return Runtime::MemcpyHostToDevice; - case MemcpyKind::kDeviceToHost: - return Runtime::MemcpyDeviceToHost; - case MemcpyKind::kDeviceToDevice: - return Runtime::MemcpyDeviceToDevice; + case kMemcpyHostToHost: + return DeviceRuntime::kMemcpyHostToHost; + case kMemcpyHostToDevice: + return DeviceRuntime::kMemcpyHostToDevice; + case kMemcpyDeviceToHost: + return DeviceRuntime::kMemcpyDeviceToHost; + case kMemcpyDeviceToDevice: + return DeviceRuntime::kMemcpyDeviceToDevice; }} assert(false && "unsupported memcpy kind"); - std::abort(); - return Runtime::MemcpyHostToHost; + return DeviceRuntime::kMemcpyHostToHost; }} }} // namespace {dispatch_functions} -}} // namespace infini::rt +}} // namespace infini::rt::runtime """ ) @@ -422,9 +463,7 @@ def main(): _write_wrapper(include_root, wrapper_device, header_name, target) _write_generated_header(include_root, devices) - _write_runtime_dispatch( - pathlib.Path(args.source_output), args.runtime_header, devices - ) + _write_runtime_dispatch(pathlib.Path(args.source_output), devices) if __name__ == "__main__": diff --git a/src/native/ascend/runtime_.h b/src/native/ascend/runtime_.h index 8b33e54..065a9d3 100644 --- a/src/native/ascend/runtime_.h +++ b/src/native/ascend/runtime_.h @@ -11,15 +11,19 @@ #include "native/ascend/device_.h" #include "runtime.h" -namespace infini::rt { +namespace infini::rt::runtime { template <> struct Runtime : DeviceRuntime> { + using Error = aclError; + using Stream = aclrtStream; static constexpr Device::Type kDeviceType = Device::Type::kAscend; + static constexpr Error kSuccess = ACL_SUCCESS; + static constexpr auto SetDevice = aclrtSetDevice; static constexpr auto GetDevice = aclrtGetDevice; @@ -45,13 +49,19 @@ struct Runtime return aclrtMemcpy(dst, count, src, count, kind); }; - static constexpr auto MemcpyHostToHost = ACL_MEMCPY_HOST_TO_HOST; + static constexpr auto MemcpyAsync = [](void* dst, const void* src, + size_t count, aclrtMemcpyKind kind, + Stream stream) { + return aclrtMemcpyAsync(dst, count, src, count, kind, stream); + }; + + static constexpr auto kMemcpyHostToHost = ACL_MEMCPY_HOST_TO_HOST; - static constexpr auto MemcpyHostToDevice = ACL_MEMCPY_HOST_TO_DEVICE; + static constexpr auto kMemcpyHostToDevice = ACL_MEMCPY_HOST_TO_DEVICE; - static constexpr auto MemcpyDeviceToHost = ACL_MEMCPY_DEVICE_TO_HOST; + static constexpr auto kMemcpyDeviceToHost = ACL_MEMCPY_DEVICE_TO_HOST; - static constexpr auto MemcpyDeviceToDevice = ACL_MEMCPY_DEVICE_TO_DEVICE; + static constexpr auto kMemcpyDeviceToDevice = ACL_MEMCPY_DEVICE_TO_DEVICE; static constexpr auto Memset = [](void* ptr, int value, size_t count) { return aclrtMemset(ptr, count, value, count); @@ -60,6 +70,6 @@ struct Runtime static_assert(Runtime::Validate()); -} // namespace infini::rt +} // namespace infini::rt::runtime #endif diff --git a/src/native/cambricon/runtime_.h b/src/native/cambricon/runtime_.h index 4db4920..927e2c5 100644 --- a/src/native/cambricon/runtime_.h +++ b/src/native/cambricon/runtime_.h @@ -9,15 +9,19 @@ #include "native/cambricon/device_.h" #include "runtime.h" -namespace infini::rt { +namespace infini::rt::runtime { template <> struct Runtime : DeviceRuntime> { + using Error = cnrtRet_t; + using Stream = cnrtQueue_t; static constexpr Device::Type kDeviceType = Device::Type::kCambricon; + static constexpr Error kSuccess = CNRT_RET_SUCCESS; + static constexpr auto SetDevice = cnrtSetDevice; static constexpr auto GetDevice = cnrtGetDevice; @@ -41,19 +45,25 @@ struct Runtime return cnrtMemcpy(dst, const_cast(src), size, kind); }; - static constexpr auto MemcpyHostToHost = cnrtMemcpyHostToHost; + static constexpr auto MemcpyAsync = [](void* dst, const void* src, + std::size_t size, auto kind, + Stream stream) { + return cnrtMemcpyAsync(dst, const_cast(src), size, kind, stream); + }; + + static constexpr auto kMemcpyHostToHost = cnrtMemcpyHostToHost; - static constexpr auto MemcpyHostToDevice = cnrtMemcpyHostToDev; + static constexpr auto kMemcpyHostToDevice = cnrtMemcpyHostToDev; - static constexpr auto MemcpyDeviceToHost = cnrtMemcpyDevToHost; + static constexpr auto kMemcpyDeviceToHost = cnrtMemcpyDevToHost; - static constexpr auto MemcpyDeviceToDevice = cnrtMemcpyDevToDev; + static constexpr auto kMemcpyDeviceToDevice = cnrtMemcpyDevToDev; static constexpr auto Memset = cnrtMemset; }; static_assert(Runtime::Validate()); -} // namespace infini::rt +} // namespace infini::rt::runtime #endif diff --git a/src/native/cpu/runtime_.h b/src/native/cpu/runtime_.h index bf5a81c..f4b18bf 100644 --- a/src/native/cpu/runtime_.h +++ b/src/native/cpu/runtime_.h @@ -1,60 +1,113 @@ #ifndef INFINI_RT_CPU_RUNTIME__H_ #define INFINI_RT_CPU_RUNTIME__H_ -#include #include #include #include "runtime.h" -namespace infini::rt { +namespace infini::rt::runtime { template <> struct Runtime : RuntimeBase> { static constexpr Device::Type kDeviceType = Device::Type::kCpu; - static void SetDevice(int index) { - if (index != 0) { - assert(false && "CPU device index must be 0"); - std::abort(); + using Error = int; + + using Stream = void*; + + static constexpr Error kSuccess = 0; + + static constexpr Error kErrorInvalidValue = 1; + + static constexpr Error kErrorMemoryAllocation = 2; + + static Error SetDevice(int device) { + if (device != 0) { + return kErrorInvalidValue; } + + return kSuccess; } - static void GetDevice(int* index) { - assert(index != nullptr); - *index = 0; + static Error GetDevice(int* device) { + if (device == nullptr) { + return kErrorInvalidValue; + } + + *device = 0; + + return kSuccess; } - static void GetDeviceCount(int* count) { - assert(count != nullptr); + static Error GetDeviceCount(int* count) { + if (count == nullptr) { + return kErrorInvalidValue; + } + *count = 1; + + return kSuccess; + } + + static Error DeviceSynchronize() { return kSuccess; } + + static Error Malloc(void** ptr, std::size_t size) { + if (ptr == nullptr) { + return kErrorInvalidValue; + } + + *ptr = std::malloc(size); + + if (size != 0 && *ptr == nullptr) { + return kErrorMemoryAllocation; + } + + return kSuccess; } - static void DeviceSynchronize() {} + static Error Free(void* ptr) { + std::free(ptr); - static void Malloc(void** ptr, std::size_t size) { *ptr = std::malloc(size); } + return kSuccess; + } - static void Free(void* ptr) { std::free(ptr); } + static Error Memcpy(void* dst, const void* src, std::size_t size, int) { + if ((dst == nullptr || src == nullptr) && size != 0) { + return kErrorInvalidValue; + } - static void Memcpy(void* dst, const void* src, std::size_t size, int) { std::memcpy(dst, src, size); + + return kSuccess; } - static void Memset(void* ptr, int value, std::size_t count) { + static Error Memset(void* ptr, int value, std::size_t count) { + if (ptr == nullptr && count != 0) { + return kErrorInvalidValue; + } + std::memset(ptr, value, count); + + return kSuccess; + } + + static Error MemcpyAsync(void* dst, const void* src, std::size_t size, + int kind, Stream) { + return kErrorInvalidValue; } - static constexpr int MemcpyHostToHost = 0; + static constexpr int kMemcpyHostToHost = 0; - static constexpr int MemcpyHostToDevice = 0; + static constexpr int kMemcpyHostToDevice = 1; - static constexpr int MemcpyDeviceToHost = 1; + static constexpr int kMemcpyDeviceToHost = 2; - static constexpr int MemcpyDeviceToDevice = 0; + static constexpr int kMemcpyDeviceToDevice = 3; }; static_assert(Runtime::Validate()); -} // namespace infini::rt +} // namespace infini::rt::runtime #endif diff --git a/src/native/cuda/hygon/runtime_.h b/src/native/cuda/hygon/runtime_.h index a9f41ec..52c47eb 100644 --- a/src/native/cuda/hygon/runtime_.h +++ b/src/native/cuda/hygon/runtime_.h @@ -10,34 +10,52 @@ #include "native/cuda/hygon/device_.h" #include "native/cuda/runtime_.h" -namespace infini::rt { +namespace infini::rt::runtime { template <> struct Runtime : CudaRuntime> { + using Error = cudaError_t; + using Stream = cudaStream_t; static constexpr Device::Type kDeviceType = Device::Type::kHygon; + static constexpr Error kSuccess = cudaSuccess; + + static constexpr auto SetDevice = cudaSetDevice; + + static constexpr auto GetDevice = cudaGetDevice; + + static constexpr auto GetDeviceCount = cudaGetDeviceCount; + + static constexpr auto DeviceSynchronize = cudaDeviceSynchronize; + static constexpr auto Malloc = [](auto&&... args) { return cudaMalloc(std::forward(args)...); }; static constexpr auto Memcpy = cudaMemcpy; + static constexpr auto MemcpyAsync = cudaMemcpyAsync; + static constexpr auto Free = [](auto&&... args) { return cudaFree(std::forward(args)...); }; - static constexpr auto MemcpyHostToDevice = cudaMemcpyHostToDevice; + static constexpr auto kMemcpyHostToHost = cudaMemcpyHostToHost; + + static constexpr auto kMemcpyHostToDevice = cudaMemcpyHostToDevice; + + static constexpr auto kMemcpyDeviceToHost = cudaMemcpyDeviceToHost; - static constexpr auto MemcpyDeviceToHost = cudaMemcpyDeviceToHost; + static constexpr auto kMemcpyDeviceToDevice = cudaMemcpyDeviceToDevice; static constexpr auto Memset = cudaMemset; }; static_assert(Runtime::Validate()); -} // namespace infini::rt +} // namespace infini::rt::runtime #endif diff --git a/src/native/cuda/iluvatar/runtime_.h b/src/native/cuda/iluvatar/runtime_.h index 8a1b649..f49db23 100644 --- a/src/native/cuda/iluvatar/runtime_.h +++ b/src/native/cuda/iluvatar/runtime_.h @@ -10,15 +10,19 @@ #include "native/cuda/iluvatar/device_.h" #include "native/cuda/runtime_.h" -namespace infini::rt { +namespace infini::rt::runtime { template <> struct Runtime : CudaRuntime> { + using Error = cudaError_t; + using Stream = cudaStream_t; static constexpr Device::Type kDeviceType = Device::Type::kIluvatar; + static constexpr Error kSuccess = cudaSuccess; + static constexpr auto SetDevice = cudaSetDevice; static constexpr auto GetDevice = cudaGetDevice; @@ -33,21 +37,23 @@ struct Runtime static constexpr auto Memcpy = cudaMemcpy; + static constexpr auto MemcpyAsync = cudaMemcpyAsync; + static constexpr auto Free = cudaFree; - static constexpr auto MemcpyHostToHost = cudaMemcpyHostToHost; + static constexpr auto kMemcpyHostToHost = cudaMemcpyHostToHost; - static constexpr auto MemcpyHostToDevice = cudaMemcpyHostToDevice; + static constexpr auto kMemcpyHostToDevice = cudaMemcpyHostToDevice; - static constexpr auto MemcpyDeviceToHost = cudaMemcpyDeviceToHost; + static constexpr auto kMemcpyDeviceToHost = cudaMemcpyDeviceToHost; - static constexpr auto MemcpyDeviceToDevice = cudaMemcpyDeviceToDevice; + static constexpr auto kMemcpyDeviceToDevice = cudaMemcpyDeviceToDevice; static constexpr auto Memset = cudaMemset; }; static_assert(Runtime::Validate()); -} // namespace infini::rt +} // namespace infini::rt::runtime #endif diff --git a/src/native/cuda/metax/runtime_.h b/src/native/cuda/metax/runtime_.h index 5785a51..c1f19c0 100644 --- a/src/native/cuda/metax/runtime_.h +++ b/src/native/cuda/metax/runtime_.h @@ -8,15 +8,19 @@ #include "native/cuda/metax/device_.h" #include "native/cuda/runtime_.h" -namespace infini::rt { +namespace infini::rt::runtime { template <> struct Runtime : CudaRuntime> { + using Error = mcError_t; + using Stream = mcStream_t; static constexpr Device::Type kDeviceType = Device::Type::kMetax; + static constexpr Error kSuccess = mcSuccess; + static constexpr auto SetDevice = mcSetDevice; static constexpr auto GetDevice = mcGetDevice; @@ -33,23 +37,27 @@ struct Runtime return mcMemcpy(std::forward(args)...); }; + static constexpr auto MemcpyAsync = [](auto&&... args) { + return mcMemcpyAsync(std::forward(args)...); + }; + static constexpr auto Free = [](auto&&... args) { return mcFree(std::forward(args)...); }; - static constexpr auto MemcpyHostToHost = mcMemcpyHostToHost; + static constexpr auto kMemcpyHostToHost = mcMemcpyHostToHost; - static constexpr auto MemcpyHostToDevice = mcMemcpyHostToDevice; + static constexpr auto kMemcpyHostToDevice = mcMemcpyHostToDevice; - static constexpr auto MemcpyDeviceToHost = mcMemcpyDeviceToHost; + static constexpr auto kMemcpyDeviceToHost = mcMemcpyDeviceToHost; - static constexpr auto MemcpyDeviceToDevice = mcMemcpyDeviceToDevice; + static constexpr auto kMemcpyDeviceToDevice = mcMemcpyDeviceToDevice; static constexpr auto Memset = mcMemset; }; static_assert(Runtime::Validate()); -} // namespace infini::rt +} // namespace infini::rt::runtime #endif diff --git a/src/native/cuda/moore/runtime_.h b/src/native/cuda/moore/runtime_.h index 8ced2ed..5beffcf 100644 --- a/src/native/cuda/moore/runtime_.h +++ b/src/native/cuda/moore/runtime_.h @@ -8,15 +8,19 @@ #include "native/cuda/moore/device_.h" #include "native/cuda/runtime_.h" -namespace infini::rt { +namespace infini::rt::runtime { template <> struct Runtime : CudaRuntime> { + using Error = musaError_t; + using Stream = musaStream_t; static constexpr Device::Type kDeviceType = Device::Type::kMoore; + static constexpr Error kSuccess = musaSuccess; + static constexpr auto SetDevice = musaSetDevice; static constexpr auto GetDevice = [](auto&&... args) { @@ -39,23 +43,27 @@ struct Runtime return musaMemcpy(std::forward(args)...); }; + static constexpr auto MemcpyAsync = [](auto&&... args) { + return musaMemcpyAsync(std::forward(args)...); + }; + static constexpr auto Free = [](auto&&... args) { return musaFree(std::forward(args)...); }; - static constexpr auto MemcpyHostToHost = musaMemcpyHostToHost; + static constexpr auto kMemcpyHostToHost = musaMemcpyHostToHost; - static constexpr auto MemcpyHostToDevice = musaMemcpyHostToDevice; + static constexpr auto kMemcpyHostToDevice = musaMemcpyHostToDevice; - static constexpr auto MemcpyDeviceToHost = musaMemcpyDeviceToHost; + static constexpr auto kMemcpyDeviceToHost = musaMemcpyDeviceToHost; - static constexpr auto MemcpyDeviceToDevice = musaMemcpyDeviceToDevice; + static constexpr auto kMemcpyDeviceToDevice = musaMemcpyDeviceToDevice; static constexpr auto Memset = musaMemset; }; static_assert(Runtime::Validate()); -} // namespace infini::rt +} // namespace infini::rt::runtime #endif diff --git a/src/native/cuda/nvidia/runtime_.h b/src/native/cuda/nvidia/runtime_.h index f6a9f2d..1786e08 100644 --- a/src/native/cuda/nvidia/runtime_.h +++ b/src/native/cuda/nvidia/runtime_.h @@ -10,15 +10,19 @@ #include "native/cuda/nvidia/device_.h" #include "native/cuda/runtime_.h" -namespace infini::rt { +namespace infini::rt::runtime { template <> struct Runtime : CudaRuntime> { + using Error = cudaError_t; + using Stream = cudaStream_t; static constexpr Device::Type kDeviceType = Device::Type::kNvidia; + static constexpr Error kSuccess = cudaSuccess; + static constexpr auto SetDevice = cudaSetDevice; static constexpr auto GetDevice = cudaGetDevice; @@ -33,21 +37,23 @@ struct Runtime static constexpr auto Memcpy = cudaMemcpy; + static constexpr auto MemcpyAsync = cudaMemcpyAsync; + static constexpr auto Free = cudaFree; - static constexpr auto MemcpyHostToHost = cudaMemcpyHostToHost; + static constexpr auto kMemcpyHostToHost = cudaMemcpyHostToHost; - static constexpr auto MemcpyHostToDevice = cudaMemcpyHostToDevice; + static constexpr auto kMemcpyHostToDevice = cudaMemcpyHostToDevice; - static constexpr auto MemcpyDeviceToHost = cudaMemcpyDeviceToHost; + static constexpr auto kMemcpyDeviceToHost = cudaMemcpyDeviceToHost; - static constexpr auto MemcpyDeviceToDevice = cudaMemcpyDeviceToDevice; + static constexpr auto kMemcpyDeviceToDevice = cudaMemcpyDeviceToDevice; static constexpr auto Memset = cudaMemset; }; static_assert(Runtime::Validate()); -} // namespace infini::rt +} // namespace infini::rt::runtime #endif diff --git a/src/native/cuda/runtime_.h b/src/native/cuda/runtime_.h index 8765a05..59dd6bc 100644 --- a/src/native/cuda/runtime_.h +++ b/src/native/cuda/runtime_.h @@ -5,7 +5,7 @@ #include "runtime.h" -namespace infini::rt { +namespace infini::rt::runtime { /// ## CUDA-like runtime interface enforcement via CRTP. /// @@ -17,13 +17,19 @@ struct CudaRuntime : DeviceRuntime { DeviceRuntime::Validate(); static_assert( std::is_invocable_v, + size_t, decltype(Derived::kMemcpyHostToDevice)>, "`Runtime::Memcpy` must be callable with " - "`(void*, const void*, size_t, MemcpyHostToDevice)`."); + "`(void*, const void*, size_t, kMemcpyHostToDevice)`."); + static_assert( + std::is_invocable_v, + "`Runtime::MemcpyAsync` must be callable with " + "`(void*, const void*, size_t, kMemcpyHostToDevice, Stream)`."); return true; } }; -} // namespace infini::rt +} // namespace infini::rt::runtime #endif diff --git a/src/runtime.h b/src/runtime.h index ebc2698..c3e7640 100644 --- a/src/runtime.h +++ b/src/runtime.h @@ -6,7 +6,7 @@ #include "device.h" -namespace infini::rt { +namespace infini::rt::runtime { template struct Runtime; @@ -14,8 +14,9 @@ struct Runtime; /// ## Interface enforcement via CRTP. /// /// Inherit from the appropriate base to declare which interface level a -/// `Runtime` specialization implements. After the struct is fully defined, call -/// `static_assert(Runtime<...>::Validate())`. The chained `Validate()` checks +/// `runtime::Runtime` specialization implements. After the struct is fully +/// defined, call `static_assert(Runtime<...>::Validate())`. The chained +/// `Validate()` checks /// every required member's existence and signature at compile time, analogous /// to how `override` catches signature mismatches for virtual functions. /// @@ -30,6 +31,11 @@ struct RuntimeBase { std::is_same_v, Device::Type>, "`Runtime` must define `static constexpr Device::Type kDeviceType`."); + static_assert(sizeof(typename Derived::Error) > 0, + "`Runtime` must define an `Error` type alias."); + static_assert(std::is_same_v, + typename Derived::Error>, + "`Runtime` must define `static constexpr Error kSuccess`."); return true; } }; @@ -51,29 +57,6 @@ struct DeviceRuntime : RuntimeBase { } }; -enum class MemcpyKind { - kHostToHost, - kHostToDevice, - kDeviceToHost, - kDeviceToDevice, -}; - -void SetDevice(Device device); - -void GetDevice(Device* device); - -void GetDeviceCount(int* count, Device::Type type); - -void DeviceSynchronize(); - -void Malloc(void** ptr, std::size_t size); - -void Free(void* ptr); - -void Memset(void* ptr, int value, std::size_t count); - -void Memcpy(void* dst, const void* src, std::size_t count, MemcpyKind kind); - -} // namespace infini::rt +} // namespace infini::rt::runtime #endif diff --git a/src/tensor_view.h b/src/tensor_view.h index 0b6fc59..dcf7cc9 100644 --- a/src/tensor_view.h +++ b/src/tensor_view.h @@ -3,6 +3,8 @@ #include #include +#include +#include #include #include "data_type.h" @@ -11,6 +13,22 @@ namespace infini::rt { +namespace tensor_view_detail { + +template +struct IsTensorLike : std::false_type {}; + +template +struct IsTensorLike().data()), + decltype(std::declval().shape()), + decltype(std::declval().dtype()), + decltype(std::declval().device()), + decltype(std::declval().strides())>> + : std::true_type {}; + +} // namespace tensor_view_detail + class TensorView { public: using Size = std::size_t; @@ -23,7 +41,9 @@ class TensorView { using Strides = std::vector; - template + template ::value>> TensorView(const TensorLike& tensor) : data_{const_cast(static_cast(tensor.data()))}, shape_{tensor.shape()}, diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index ab54530..f24954c 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -7,12 +7,20 @@ endfunction() add_infini_rt_test(test_smoke test_smoke.cc) add_infini_rt_test(test_core test_core.cc) -if(WITH_CPU) - add_infini_rt_test(test_cpu_runtime test_cpu_runtime.cc) -endif() - if(WITH_CPU OR WITH_NVIDIA) add_infini_rt_test(test_runtime_dispatch test_runtime_dispatch.cc) + if(WITH_CPU) + target_compile_definitions(test_runtime_dispatch + PRIVATE INFINI_RT_TEST_WITH_CPU=1) + endif() + if(WITH_NVIDIA) + target_compile_definitions(test_runtime_dispatch + PRIVATE INFINI_RT_TEST_WITH_NVIDIA=1) + endif() +endif() + +if(WITH_CPU) + add_infini_rt_test(test_cpu_runtime test_cpu_runtime.cc) endif() if(WITH_NVIDIA) @@ -24,12 +32,20 @@ set(INFINI_RT_TEST_INSTALL_PREFIX set(INFINI_RT_TEST_CONSUMER_BINARY "${CMAKE_CURRENT_BINARY_DIR}/install_consumer_smoke") set(INFINI_RT_TEST_EXTRA_LIBRARY_DIRS "") +set(INFINI_RT_TEST_EXTRA_INCLUDE_DIRS "") set(INFINI_RT_TEST_CONSUMER_BACKEND NONE) -if(WITH_CPU) - set(INFINI_RT_TEST_CONSUMER_BACKEND CPU) -elseif(WITH_NVIDIA) +if(WITH_NVIDIA) set(INFINI_RT_TEST_CONSUMER_BACKEND NVIDIA) + if(CUDAToolkit_INCLUDE_DIRS) + list(APPEND INFINI_RT_TEST_EXTRA_INCLUDE_DIRS + ${CUDAToolkit_INCLUDE_DIRS}) + elseif(CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES) + list(APPEND INFINI_RT_TEST_EXTRA_INCLUDE_DIRS + ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) + endif() +elseif(WITH_CPU) + set(INFINI_RT_TEST_CONSUMER_BACKEND CPU) endif() if(WITH_ASCEND) @@ -44,6 +60,7 @@ if(WITH_ASCEND) endif() list(JOIN INFINI_RT_TEST_EXTRA_LIBRARY_DIRS ":" INFINI_RT_TEST_EXTRA_LIBRARY_PATHS) +list(JOIN INFINI_RT_TEST_EXTRA_INCLUDE_DIRS ":" INFINI_RT_TEST_EXTRA_INCLUDE_PATHS) add_test( NAME test_install @@ -59,6 +76,7 @@ add_test( "-DINFINI_RT_CONSUMER_BINARY=${INFINI_RT_TEST_CONSUMER_BINARY}" "-DINFINI_RT_CXX_COMPILER=${CMAKE_CXX_COMPILER}" "-DINFINI_RT_EXTRA_LIBRARY_PATHS=${INFINI_RT_TEST_EXTRA_LIBRARY_PATHS}" + "-DINFINI_RT_EXTRA_INCLUDE_PATHS=${INFINI_RT_TEST_EXTRA_INCLUDE_PATHS}" "-DINFINI_RT_CONSUMER_BACKEND=${INFINI_RT_TEST_CONSUMER_BACKEND}" -P "${CMAKE_CURRENT_SOURCE_DIR}/compile_install_consumer.cmake") set_tests_properties(test_install_consumer PROPERTIES diff --git a/tests/compile_install_consumer.cmake b/tests/compile_install_consumer.cmake index 169311e..1bb96d1 100644 --- a/tests/compile_install_consumer.cmake +++ b/tests/compile_install_consumer.cmake @@ -17,6 +17,17 @@ set(INFINI_RT_EXTRA_LINK_ARGS "") set(INFINI_RT_EXTRA_COMPILE_ARGS "") set(INFINI_RT_LD_LIBRARY_PATH "${INFINI_RT_LIBRARY_DIR}") +if(INFINI_RT_EXTRA_INCLUDE_PATHS) + string(REPLACE ":" ";" INFINI_RT_EXTRA_INCLUDE_DIRS + "${INFINI_RT_EXTRA_INCLUDE_PATHS}") + foreach(INFINI_RT_EXTRA_INCLUDE_DIR ${INFINI_RT_EXTRA_INCLUDE_DIRS}) + if(EXISTS "${INFINI_RT_EXTRA_INCLUDE_DIR}") + list(APPEND INFINI_RT_EXTRA_COMPILE_ARGS + "-I${INFINI_RT_EXTRA_INCLUDE_DIR}") + endif() + endforeach() +endif() + if(INFINI_RT_CONSUMER_BACKEND AND NOT INFINI_RT_CONSUMER_BACKEND STREQUAL "NONE") list(APPEND INFINI_RT_EXTRA_COMPILE_ARGS "-DINFINI_RT_CONSUMER_BACKEND_${INFINI_RT_CONSUMER_BACKEND}=1") diff --git a/tests/install_consumer_smoke.cc b/tests/install_consumer_smoke.cc index 97442b5..0684763 100644 --- a/tests/install_consumer_smoke.cc +++ b/tests/install_consumer_smoke.cc @@ -1,5 +1,6 @@ #include +#include #include #include #include @@ -20,19 +21,73 @@ int main() { #if defined(INFINI_RT_CONSUMER_BACKEND_CPU) || \ defined(INFINI_RT_CONSUMER_BACKEND_NVIDIA) -#if defined(INFINI_RT_CONSUMER_BACKEND_NVIDIA) - const infini::rt::Device runtime_device{infini::rt::Device::Type::kNvidia}; + namespace runtime = infini::rt::runtime; +#if defined(INFINI_RT_CONSUMER_BACKEND_CPU) + constexpr auto kExpectedDeviceType = infini::rt::Device::Type::kCpu; #else - const infini::rt::Device runtime_device{infini::rt::Device::Type::kCpu}; + constexpr auto kExpectedDeviceType = infini::rt::Device::Type::kNvidia; #endif + infini::rt::set_runtime_device_type(kExpectedDeviceType); + if (infini::rt::runtime_device_type() != kExpectedDeviceType) { + return 1; + } + std::array input{1, 2, 3, 4}; + std::array output{}; void* ptr = nullptr; - infini::rt::SetDevice(runtime_device); - infini::rt::Malloc(&ptr, sizeof(std::uint32_t)); + if (runtime::SetDevice(0) != runtime::kSuccess) { + return 1; + } + int current_device = -1; + if (runtime::GetDevice(¤t_device) != runtime::kSuccess) { + return 1; + } + if (current_device != 0) { + return 1; + } + int device_count = 0; + if (runtime::GetDeviceCount(&device_count) != runtime::kSuccess) { + return 1; + } + if (device_count <= 0) { + return 1; + } + if (runtime::Malloc(&ptr, input.size()) != runtime::kSuccess) { + return 1; + } if (ptr == nullptr) { return 1; } - infini::rt::Free(ptr); + if (runtime::Memcpy(ptr, input.data(), input.size(), + runtime::kMemcpyHostToDevice) != runtime::kSuccess) { + return 1; + } +#if defined(INFINI_RT_CONSUMER_BACKEND_CPU) + if (runtime::MemcpyAsync(ptr, input.data(), input.size(), + runtime::kMemcpyHostToDevice, + nullptr) == runtime::kSuccess) { + return 1; + } +#else + if (runtime::MemcpyAsync(ptr, input.data(), input.size(), + runtime::kMemcpyHostToDevice, + nullptr) != runtime::kSuccess) { + return 1; + } +#endif + if (runtime::DeviceSynchronize() != runtime::kSuccess) { + return 1; + } + if (runtime::Memcpy(output.data(), ptr, output.size(), + runtime::kMemcpyDeviceToHost) != runtime::kSuccess) { + return 1; + } + if (output != input) { + return 1; + } + if (runtime::Free(ptr) != runtime::kSuccess) { + return 1; + } #endif return 0; diff --git a/tests/test_core.cc b/tests/test_core.cc index e8b8071..273ecef 100644 --- a/tests/test_core.cc +++ b/tests/test_core.cc @@ -1,7 +1,9 @@ #include +#include #include #include +#include #include #include "test_helper.h" @@ -12,6 +14,9 @@ using infini::rt::DataType; using infini::rt::Device; using infini::rt::TensorView; +static_assert(!std::is_constructible_v>, + "TensorView should not treat tensor containers as tensor-like."); + void TestDevice(infini::rt::test::TestContext* context) { const Device cpu{Device::Type::kCpu}; const Device nvidia{Device::Type::kNvidia, 1}; diff --git a/tests/test_cpu_runtime.cc b/tests/test_cpu_runtime.cc index 20edce9..e066c85 100644 --- a/tests/test_cpu_runtime.cc +++ b/tests/test_cpu_runtime.cc @@ -9,7 +9,7 @@ namespace { -using CpuRuntime = infini::rt::Runtime; +using CpuRuntime = infini::rt::runtime::Runtime; void TestMallocAndFree(infini::rt::test::TestContext* context) { void* ptr = nullptr; @@ -32,15 +32,26 @@ void TestMemcpyRoundTrip(infini::rt::test::TestContext* context) { } CpuRuntime::Memcpy(ptr, input.data(), input.size(), - CpuRuntime::MemcpyHostToDevice); + CpuRuntime::kMemcpyHostToDevice); CpuRuntime::Memcpy(output.data(), ptr, output.size(), - CpuRuntime::MemcpyDeviceToHost); + CpuRuntime::kMemcpyDeviceToHost); CpuRuntime::Free(ptr); context->ExpectEqual(output, input, "CPU runtime should copy data through runtime memory."); } +void TestMemcpyAsyncUnsupported(infini::rt::test::TestContext* context) { + std::array input{1}; + std::array output{}; + + context->Expect( + CpuRuntime::MemcpyAsync(output.data(), input.data(), input.size(), + CpuRuntime::kMemcpyHostToHost, + nullptr) != CpuRuntime::kSuccess, + "CPU runtime should not report async memcpy success."); +} + void TestMemset(infini::rt::test::TestContext* context) { std::array output{}; void* ptr = nullptr; @@ -53,7 +64,7 @@ void TestMemset(infini::rt::test::TestContext* context) { CpuRuntime::Memset(ptr, 0x5A, output.size()); CpuRuntime::Memcpy(output.data(), ptr, output.size(), - CpuRuntime::MemcpyDeviceToHost); + CpuRuntime::kMemcpyDeviceToHost); CpuRuntime::Free(ptr); for (const auto value : output) { @@ -69,6 +80,7 @@ int main() { TestMallocAndFree(&context); TestMemcpyRoundTrip(&context); + TestMemcpyAsyncUnsupported(&context); TestMemset(&context); return context.ExitCode(); diff --git a/tests/test_nvidia_runtime.cc b/tests/test_nvidia_runtime.cc index 2d3dac2..fa10858 100644 --- a/tests/test_nvidia_runtime.cc +++ b/tests/test_nvidia_runtime.cc @@ -9,7 +9,8 @@ namespace { -using NvidiaRuntime = infini::rt::Runtime; +using NvidiaRuntime = + infini::rt::runtime::Runtime; void ExpectCudaSuccess(infini::rt::test::TestContext* context, cudaError_t status, const char* message) { @@ -41,11 +42,11 @@ void TestMemcpyRoundTrip(infini::rt::test::TestContext* context) { ExpectCudaSuccess(context, NvidiaRuntime::Memcpy(ptr, input.data(), input.size(), - NvidiaRuntime::MemcpyHostToDevice), + NvidiaRuntime::kMemcpyHostToDevice), "NVIDIA runtime should copy host data to device memory."); ExpectCudaSuccess(context, NvidiaRuntime::Memcpy(output.data(), ptr, output.size(), - NvidiaRuntime::MemcpyDeviceToHost), + NvidiaRuntime::kMemcpyDeviceToHost), "NVIDIA runtime should copy device data to host memory."); ExpectCudaSuccess(context, NvidiaRuntime::Free(ptr), "NVIDIA runtime should free copy memory."); @@ -68,7 +69,7 @@ void TestMemset(infini::rt::test::TestContext* context) { "NVIDIA runtime should memset device memory."); ExpectCudaSuccess(context, NvidiaRuntime::Memcpy(output.data(), ptr, output.size(), - NvidiaRuntime::MemcpyDeviceToHost), + NvidiaRuntime::kMemcpyDeviceToHost), "NVIDIA runtime should copy memset data to host memory."); ExpectCudaSuccess(context, NvidiaRuntime::Free(ptr), "NVIDIA runtime should free memset memory."); diff --git a/tests/test_runtime_dispatch.cc b/tests/test_runtime_dispatch.cc index 9be92a7..347d6e8 100644 --- a/tests/test_runtime_dispatch.cc +++ b/tests/test_runtime_dispatch.cc @@ -8,58 +8,152 @@ namespace { -infini::rt::Device RuntimeTestDevice() { -#if defined(WITH_NVIDIA) - return infini::rt::Device{infini::rt::Device::Type::kNvidia}; -#else - return infini::rt::Device{infini::rt::Device::Type::kCpu}; -#endif +namespace runtime = infini::rt::runtime; + +void ExpectSuccess(infini::rt::test::TestContext* context, + runtime::Error status, const char* message) { + context->Expect(status == runtime::kSuccess, message); } -} // namespace +#if defined(INFINI_RT_TEST_WITH_CPU) +void TestCpuDispatch(infini::rt::test::TestContext* context) { + infini::rt::set_runtime_device_type(infini::rt::Device::Type::kCpu); + context->Expect( + infini::rt::runtime_device_type() == infini::rt::Device::Type::kCpu, + "Runtime dispatch should report CPU dispatch."); -int main() { - infini::rt::test::TestContext context; - const infini::rt::Device device = RuntimeTestDevice(); std::array input{1, 2, 3, 4}; std::array output{}; void* ptr = nullptr; - infini::rt::SetDevice(device); + ExpectSuccess(context, runtime::SetDevice(0), + "CPU dispatch should set device 0."); + int current_device = -1; + ExpectSuccess(context, runtime::GetDevice(¤t_device), + "CPU dispatch should get the current device."); + context->ExpectEqual(current_device, 0, + "CPU dispatch should keep the current device."); + + int device_count = 0; + ExpectSuccess(context, runtime::GetDeviceCount(&device_count), + "CPU dispatch should get the device count."); + context->Expect(device_count > 0, + "CPU dispatch should report at least one device."); + + ExpectSuccess(context, runtime::Malloc(&ptr, input.size()), + "CPU dispatch should allocate memory."); + if (ptr == nullptr) { + return; + } + + ExpectSuccess(context, + runtime::Memcpy(ptr, input.data(), input.size(), + runtime::kMemcpyHostToDevice), + "CPU dispatch should copy host data to runtime memory."); + context->Expect(runtime::MemcpyAsync(ptr, input.data(), input.size(), + runtime::kMemcpyHostToDevice, + nullptr) != runtime::kSuccess, + "CPU dispatch should not report async memcpy success."); + ExpectSuccess(context, + runtime::Memcpy(output.data(), ptr, output.size(), + runtime::kMemcpyDeviceToHost), + "CPU dispatch should copy runtime memory to host."); + + context->ExpectEqual(output, input, + "CPU dispatch should preserve copied bytes."); + + ExpectSuccess(context, runtime::Memset(ptr, 0x5A, output.size()), + "CPU dispatch should fill runtime memory."); + ExpectSuccess(context, + runtime::Memcpy(output.data(), ptr, output.size(), + runtime::kMemcpyDeviceToHost), + "CPU dispatch should copy filled memory to host."); + for (const auto value : output) { + context->ExpectEqual(value, static_cast(0x5A), + "CPU dispatch should preserve filled bytes."); + } + + ExpectSuccess(context, runtime::Free(ptr), + "CPU dispatch should free memory."); +} +#endif + +#if defined(INFINI_RT_TEST_WITH_NVIDIA) +void TestNvidiaDispatch(infini::rt::test::TestContext* context) { + infini::rt::set_runtime_device_type(infini::rt::Device::Type::kNvidia); + context->Expect( + infini::rt::runtime_device_type() == infini::rt::Device::Type::kNvidia, + "Runtime dispatch should report NVIDIA dispatch."); + + std::array input{5, 6, 7, 8}; + std::array output{}; + void* ptr = nullptr; - infini::rt::Device current_device; - infini::rt::GetDevice(¤t_device); - context.ExpectEqual(current_device, device, - "Runtime dispatch should keep the current device."); + ExpectSuccess(context, runtime::SetDevice(0), + "NVIDIA dispatch should set device 0."); + int current_device = -1; + ExpectSuccess(context, runtime::GetDevice(¤t_device), + "NVIDIA dispatch should get the current device."); + context->ExpectEqual(current_device, 0, + "NVIDIA dispatch should keep the current device."); int device_count = 0; - infini::rt::GetDeviceCount(&device_count, device.type()); - context.Expect(device_count > 0, - "Runtime dispatch should report at least one device."); + ExpectSuccess(context, runtime::GetDeviceCount(&device_count), + "NVIDIA dispatch should get the device count."); + context->Expect(device_count > 0, + "NVIDIA dispatch should report at least one device."); - infini::rt::Malloc(&ptr, input.size()); - context.Expect(ptr != nullptr, "Runtime dispatch should allocate memory."); + ExpectSuccess(context, runtime::Malloc(&ptr, input.size()), + "NVIDIA dispatch should allocate memory."); if (ptr == nullptr) { - return context.ExitCode(); + return; } - infini::rt::Memcpy(ptr, input.data(), input.size(), - infini::rt::MemcpyKind::kHostToDevice); - infini::rt::Memcpy(output.data(), ptr, output.size(), - infini::rt::MemcpyKind::kDeviceToHost); - context.ExpectEqual(output, input, - "Runtime dispatch should copy data through memory."); + ExpectSuccess( + context, + runtime::MemcpyAsync(ptr, input.data(), input.size(), + runtime::kMemcpyHostToDevice, nullptr), + "NVIDIA dispatch should support async host-to-device copy."); + ExpectSuccess(context, runtime::DeviceSynchronize(), + "NVIDIA dispatch should synchronize the device."); + ExpectSuccess(context, + runtime::Memcpy(output.data(), ptr, output.size(), + runtime::kMemcpyDeviceToHost), + "NVIDIA dispatch should copy device data to host."); - infini::rt::Memset(ptr, 0x5A, output.size()); - infini::rt::Memcpy(output.data(), ptr, output.size(), - infini::rt::MemcpyKind::kDeviceToHost); + context->ExpectEqual(output, input, + "NVIDIA dispatch should preserve copied bytes."); + + ExpectSuccess(context, runtime::Memset(ptr, 0x5A, output.size()), + "NVIDIA dispatch should fill runtime memory."); + ExpectSuccess(context, runtime::DeviceSynchronize(), + "NVIDIA dispatch should synchronize filled memory."); + ExpectSuccess(context, + runtime::Memcpy(output.data(), ptr, output.size(), + runtime::kMemcpyDeviceToHost), + "NVIDIA dispatch should copy filled memory to host."); for (const auto value : output) { - context.ExpectEqual(value, static_cast(0x5A), - "Runtime dispatch should fill memory."); + context->ExpectEqual(value, static_cast(0x5A), + "NVIDIA dispatch should preserve filled bytes."); } - infini::rt::DeviceSynchronize(); - infini::rt::Free(ptr); + ExpectSuccess(context, runtime::Free(ptr), + "NVIDIA dispatch should free memory."); +} +#endif + +} // namespace + +int main() { + infini::rt::test::TestContext context; + +#if defined(INFINI_RT_TEST_WITH_CPU) + TestCpuDispatch(&context); +#endif + +#if defined(INFINI_RT_TEST_WITH_NVIDIA) + TestNvidiaDispatch(&context); +#endif return context.ExitCode(); }