From 187c34a40e97aa4bd85e11aea4ef7cb3d53194ff Mon Sep 17 00:00:00 2001 From: zhushuang Date: Thu, 18 Jun 2026 03:11:05 +0000 Subject: [PATCH] feat: refactor InfiniCore cpu runtime to InfiniRT --- scripts/generate_public_headers.py | 115 +++++++++++------ src/native/cpu/runtime_.h | 196 ++++++++++++++++++++++++++--- src/runtime.h | 81 ++++++++++-- 3 files changed, 321 insertions(+), 71 deletions(-) diff --git a/scripts/generate_public_headers.py b/scripts/generate_public_headers.py index 93ceee3..71b2c4f 100644 --- a/scripts/generate_public_headers.py +++ b/scripts/generate_public_headers.py @@ -210,24 +210,28 @@ def _parse_runtime_functions(runtime_header): _Function( return_type, name, - tuple(_parse_param(param) for param in params.split(", ") if param), + tuple( + _parse_param(param) + for param in re.split(r",\s*", params.strip()) + if param + ), ) for return_type, name, params in re.findall( - r"^(void) ([A-Z]\w*)\(([^()]*)\);$", text, re.MULTILINE + r"^(infini::rt::Error|Error|void) ([A-Z]\w*)\(([^()]*)\);$", + text, + re.MULTILINE, ) ) -def _abort_statement(message): - return f""" assert(false && "{message}"); - std::abort();""" +def _unsupported_statement(): + return " return infini::rt::kUnSuccess;" def _dispatch_cases(devices, statements): return "\n".join( f""" case {_DEVICE_TYPES[device]}: {{ {statements.replace("__DEVICE_TYPE__", _DEVICE_TYPES[device])} - return; }}""" for device in devices ) @@ -243,19 +247,21 @@ def _selector(function): return "current_device.type()" -def _runtime_arg(param): +def _runtime_arg(function, param): if param.type == "Device": - return f"{param.name}.index()" + if function.name in {"SetDevice", "GetDeviceResourceSnapshot"}: + return f"{param.name}.index()" + return None if param.type == "Device::Type": return None - if param.type == "MemcpyKind": + if param.type in {"MemcpyKind", "infini::rt::MemcpyKind"}: return f"RuntimeMemcpyKind<__DEVICE_TYPE__>({param.name})" return param.name def _runtime_args(function): - args = (_runtime_arg(param) for param in function.params) + args = (_runtime_arg(function, param) for param in function.params) return ", ".join(arg for arg in args if arg is not None) @@ -267,17 +273,21 @@ def _preconditions(function): } checks = [] for param in function.params: - if param.type.endswith("**") or param.name in required_pointer_names.get( - function.name, set() + if ( + param.type.endswith("**") + or param.type.endswith("*") + or param.name in required_pointer_names.get(function.name, set()) ): - checks.append(f" assert({param.name} != nullptr);") + checks.append(f" if ({param.name} == nullptr) {{") + checks.append(" return infini::rt::kUnSuccess;") + checks.append(" }") return "\n".join(checks) def _post_dispatch(function): if function.name == "SetDevice": - return "\n current_device = device;" + return "\n if (rt_status == infini::rt::kSuccess) {\n current_device = Device{current_device.type(), device};\n }" return "" @@ -294,19 +304,23 @@ def _write_get_device(function, devices): device_param = function.params[0].name cases = _dispatch_cases( devices, - f""" int index = current_device.index(); - CheckCall([&] {{ return Runtime<__DEVICE_TYPE__>::GetDevice(&index); }}); - current_device = Device{{current_device.type(), index}}; - *{device_param} = current_device;""", + f""" infini::rt::Error status = CheckCall([&] {{ return Runtime<__DEVICE_TYPE__>::GetDevice({device_param}); }}); + if (status != infini::rt::kSuccess) {{ + return status; + }} + current_device = Device{{current_device.type(), *{device_param}}}; + return infini::rt::kSuccess;""", ) - return f"""void GetDevice(Device* {device_param}) {{ - assert({device_param} != nullptr); + return f"""{function.return_type} GetDevice(int* {device_param}) {{ + if ({device_param} == nullptr) {{ + return infini::rt::kUnSuccess; + }} switch (current_device.type()) {{ {cases} default: -{_abort_statement("runtime device is not enabled")} +{_unsupported_statement()} }} }} """ @@ -318,7 +332,8 @@ def _write_dispatch_function(function, devices): cases = _dispatch_cases( devices, - f""" CheckCall([&] {{ return {_runtime_call(function)}; }});{_post_dispatch(function)}""", + f""" infini::rt::Error rt_status = CheckCall([&] {{ return {_runtime_call(function)}; }});{_post_dispatch(function)} + return rt_status;""", ) preconditions = _preconditions(function) if preconditions: @@ -328,25 +343,42 @@ def _write_dispatch_function(function, devices): {preconditions} switch ({_selector(function)}) {{ {cases} default: -{_abort_statement("runtime device is not enabled")} +{_unsupported_statement()} }} }} """ +def _runtime_header_for_device(source_root, device): + return source_root / _RUNTIME_HEADERS[device] + + +def _devices_for_function(function, devices, source_root): + enabled = [] + pattern = re.compile(r"\b" + re.escape(function.name) + r"\b") + for device in devices: + runtime_header = _runtime_header_for_device(source_root, device) + if runtime_header.exists() and pattern.search(runtime_header.read_text()): + enabled.append(device) + return enabled + + def _write_runtime_dispatch(source_path, runtime_header, devices): first_device_type = _DEVICE_TYPES[devices[0]] includes = ['#include "runtime.h"'] includes.extend(f'#include "{_RUNTIME_HEADERS[device]}"' for device in devices) functions = _parse_runtime_functions(runtime_header) + source_root = pathlib.Path(runtime_header).parent dispatch_functions = "\n".join( - _write_dispatch_function(function, devices) for function in functions + _write_dispatch_function( + function, _devices_for_function(function, devices, source_root) + ) + for function in functions ) source_path.parent.mkdir(parents=True, exist_ok=True) source_path.write_text( f"""#include -#include #include #include @@ -358,36 +390,39 @@ def _write_runtime_dispatch(source_path, runtime_header, devices): thread_local Device current_device{{{first_device_type}, 0}}; template -void CheckCall(Func&& func) {{ +infini::rt::Error CheckCall(Func&& func) {{ using ReturnType = decltype(std::forward(func)()); if constexpr (std::is_void_v) {{ std::forward(func)(); + return infini::rt::kSuccess; }} else {{ ReturnType status = std::forward(func)(); - if (status != ReturnType{{}}) {{ - assert(false && "runtime call failed"); - std::abort(); + if constexpr (std::is_same_v) {{ + return status == infini::rt::kSuccess ? infini::rt::kSuccess + : infini::rt::kUnSuccess; + }} else {{ + return status == ReturnType{{}} ? infini::rt::kSuccess + : infini::rt::kUnSuccess; }} }} }} template -auto RuntimeMemcpyKind(MemcpyKind kind) {{ +auto RuntimeMemcpyKind(infini::rt::MemcpyKind kind) {{ switch (kind) {{ - case MemcpyKind::kHostToHost: - return Runtime::MemcpyHostToHost; - case MemcpyKind::kHostToDevice: - return Runtime::MemcpyHostToDevice; - case MemcpyKind::kDeviceToHost: - return Runtime::MemcpyDeviceToHost; - case MemcpyKind::kDeviceToDevice: - return Runtime::MemcpyDeviceToDevice; + case infini::rt::MemcpyKind::kMemcpyHostToHost: + return Runtime::kMemcpyHostToHost; + case infini::rt::MemcpyKind::kMemcpyHostToDevice: + return Runtime::kMemcpyHostToDevice; + case infini::rt::MemcpyKind::kMemcpyDeviceToHost: + return Runtime::kMemcpyDeviceToHost; + case infini::rt::MemcpyKind::kMemcpyDeviceToDevice: + return Runtime::kMemcpyDeviceToDevice; }} assert(false && "unsupported memcpy kind"); - std::abort(); - return Runtime::MemcpyHostToHost; + return Runtime::kMemcpyHostToHost; }} }} // namespace diff --git a/src/native/cpu/runtime_.h b/src/native/cpu/runtime_.h index bf5a81c..d757672 100644 --- a/src/native/cpu/runtime_.h +++ b/src/native/cpu/runtime_.h @@ -1,9 +1,12 @@ #ifndef INFINI_RT_CPU_RUNTIME__H_ #define INFINI_RT_CPU_RUNTIME__H_ -#include +#include +#include +#include #include #include +#include #include "runtime.h" @@ -13,44 +16,199 @@ template <> struct Runtime : RuntimeBase> { static constexpr Device::Type kDeviceType = Device::Type::kCpu; - static void SetDevice(int index) { - if (index != 0) { - assert(false && "CPU device index must be 0"); - std::abort(); - } + static infini::rt::Error SetDevice(int index) { + return index == 0 ? infini::rt::kSuccess : infini::rt::kUnSuccess; } - static void GetDevice(int* index) { - assert(index != nullptr); + static infini::rt::Error GetDevice(int* index) { + if (index == nullptr) { + return infini::rt::kUnSuccess; + } *index = 0; + return infini::rt::kSuccess; } - static void GetDeviceCount(int* count) { - assert(count != nullptr); + static infini::rt::Error GetDeviceCount(int* count) { + if (count == nullptr) { + return infini::rt::kUnSuccess; + } *count = 1; + return infini::rt::kSuccess; } - static void DeviceSynchronize() {} + static infini::rt::Error DeviceSynchronize() { return infini::rt::kSuccess; } - static void Malloc(void** ptr, std::size_t size) { *ptr = std::malloc(size); } + static infini::rt::Error Malloc(void** ptr, std::size_t size) { + if (ptr == nullptr) { + return infini::rt::kUnSuccess; + } + *ptr = std::malloc(size); + return size != 0 && *ptr == nullptr ? infini::rt::kUnSuccess + : infini::rt::kSuccess; + } - static void Free(void* ptr) { std::free(ptr); } + static infini::rt::Error Free(void* ptr) { + std::free(ptr); + return infini::rt::kSuccess; + } - static void Memcpy(void* dst, const void* src, std::size_t size, int) { + static infini::rt::Error Memcpy(void* dst, const void* src, std::size_t size, + infini::rt::MemcpyKind) { + if ((dst == nullptr || src == nullptr) && size != 0) { + return infini::rt::kUnSuccess; + } std::memcpy(dst, src, size); + return infini::rt::kSuccess; } - static void Memset(void* ptr, int value, std::size_t count) { + static infini::rt::Error Memset(void* ptr, int value, std::size_t count) { + if (ptr == nullptr && count != 0) { + return infini::rt::kUnSuccess; + } std::memset(ptr, value, count); + return infini::rt::kSuccess; + } + + static infini::rt::Error MemGetInfo(std::size_t* free_bytes, + std::size_t* total_bytes) { + if (free_bytes == nullptr || total_bytes == nullptr) { + return infini::rt::kUnSuccess; + } + *free_bytes = 0; + *total_bytes = 0; + +#ifndef _WIN32 + FILE* fp = std::fopen("/proc/meminfo", "r"); + if (fp == nullptr) { + return infini::rt::kUnSuccess; + } + + char label[64]; + std::size_t value = 0; + while (std::fscanf(fp, "%63s %zu %*s", label, &value) == 2) { + if (std::strcmp(label, "MemTotal:") == 0) { + *total_bytes = value * 1024; + } else if (std::strcmp(label, "MemAvailable:") == 0) { + *free_bytes = value * 1024; + } + } + std::fclose(fp); +#endif + if (*total_bytes == 0) { + return infini::rt::kUnSuccess; + } + return infini::rt::kSuccess; + } + + static infini::rt::Error StreamCreate(infini::rt::Stream* stream) { + if (stream == nullptr) { + return infini::rt::kUnSuccess; + } + *stream = nullptr; + return infini::rt::kSuccess; + } + + static infini::rt::Error StreamDestroy(infini::rt::Stream) { + return infini::rt::kSuccess; + } + + static infini::rt::Error StreamSynchronize(infini::rt::Stream) { + return infini::rt::kSuccess; + } + + static infini::rt::Error StreamWaitEvent(infini::rt::Stream, + infini::rt::Event, std::uint32_t) { + return infini::rt::kSuccess; + } + + using CpuEvent = std::chrono::steady_clock::time_point; + + static infini::rt::Error EventCreate(infini::rt::Event* event) { + if (event == nullptr) { + return infini::rt::kUnSuccess; + } + *event = new (std::nothrow) CpuEvent(std::chrono::steady_clock::now()); + return *event == nullptr ? infini::rt::kUnSuccess : infini::rt::kSuccess; + } + + static infini::rt::Error EventCreateWithFlags(infini::rt::Event* event, + std::uint32_t) { + return EventCreate(event); + } + + static infini::rt::Error EventRecord(infini::rt::Event event, + infini::rt::Stream) { + if (event == nullptr) { + return infini::rt::kUnSuccess; + } + *static_cast(event) = std::chrono::steady_clock::now(); + return infini::rt::kSuccess; + } + + static infini::rt::Error EventQuery(infini::rt::Event event) { + return event == nullptr ? infini::rt::kUnSuccess : infini::rt::kSuccess; + } + + static infini::rt::Error EventSynchronize(infini::rt::Event event) { + return event == nullptr ? infini::rt::kUnSuccess : infini::rt::kSuccess; + } + + static infini::rt::Error EventDestroy(infini::rt::Event event) { + delete static_cast(event); + return infini::rt::kSuccess; + } + + static infini::rt::Error EventElapsedTime(float* ms, infini::rt::Event start, + infini::rt::Event end) { + if (ms == nullptr || start == nullptr || end == nullptr) { + return infini::rt::kUnSuccess; + } + const auto* start_time = static_cast(start); + const auto* end_time = static_cast(end); + const auto duration = std::chrono::duration_cast( + *end_time - *start_time); + *ms = static_cast(duration.count()) / 1000.0f; + return infini::rt::kSuccess; + } + + static infini::rt::Error MallocHost(void** ptr, std::size_t size) { + return Malloc(ptr, size); + } + + static infini::rt::Error FreeHost(void* ptr) { return Free(ptr); } + + static infini::rt::Error MemcpyAsync(void* dst, const void* src, + std::size_t size, + infini::rt::MemcpyKind kind, + infini::rt::Stream) { + return Memcpy(dst, src, size, kind); + } + + static infini::rt::Error MallocAsync(void** ptr, std::size_t size, + infini::rt::Stream) { + return Malloc(ptr, size); + } + + static infini::rt::Error FreeAsync(void* ptr, infini::rt::Stream) { + return Free(ptr); + } + + static infini::rt::Error MemsetAsync(void* ptr, int value, std::size_t count, + infini::rt::Stream) { + return Memset(ptr, value, count); } - static constexpr int MemcpyHostToHost = 0; + static constexpr auto kMemcpyHostToHost = + infini::rt::MemcpyKind::kMemcpyHostToHost; - static constexpr int MemcpyHostToDevice = 0; + static constexpr auto kMemcpyHostToDevice = + infini::rt::MemcpyKind::kMemcpyHostToDevice; - static constexpr int MemcpyDeviceToHost = 1; + static constexpr auto kMemcpyDeviceToHost = + infini::rt::MemcpyKind::kMemcpyDeviceToHost; - static constexpr int MemcpyDeviceToDevice = 0; + static constexpr auto kMemcpyDeviceToDevice = + infini::rt::MemcpyKind::kMemcpyDeviceToDevice; }; static_assert(Runtime::Validate()); diff --git a/src/runtime.h b/src/runtime.h index ebc2698..e983dca 100644 --- a/src/runtime.h +++ b/src/runtime.h @@ -2,6 +2,7 @@ #define INFINI_RT_RUNTIME_H_ #include +#include #include #include "device.h" @@ -51,28 +52,84 @@ struct DeviceRuntime : RuntimeBase { } }; +enum class Error { + kSuccess = 0, + kUnSuccess = -1, +}; + +inline constexpr infini::rt::Error kSuccess = infini::rt::Error::kSuccess; +inline constexpr infini::rt::Error kUnSuccess = infini::rt::Error::kUnSuccess; + enum class MemcpyKind { - kHostToHost, - kHostToDevice, - kDeviceToHost, - kDeviceToDevice, + kMemcpyHostToHost, + kMemcpyHostToDevice, + kMemcpyDeviceToHost, + kMemcpyDeviceToDevice, }; -void SetDevice(Device device); +using Stream = void*; +using Event = void*; + +infini::rt::Error SetDevice(int device); + +infini::rt::Error GetDevice(int* device); + +infini::rt::Error GetDeviceCount(int* count); + +infini::rt::Error DeviceSynchronize(); + +infini::rt::Error Malloc(void** ptr, std::size_t size); + +infini::rt::Error Free(void* ptr); + +infini::rt::Error Memset(void* ptr, int value, std::size_t count); + +infini::rt::Error Memcpy(void* dst, const void* src, std::size_t count, + infini::rt::MemcpyKind kind); + +infini::rt::Error MallocHost(void** ptr, std::size_t size); + +infini::rt::Error FreeHost(void* ptr); + +infini::rt::Error MemcpyAsync(void* dst, const void* src, std::size_t count, + infini::rt::MemcpyKind kind, + infini::rt::Stream stream); + +infini::rt::Error MallocAsync(void** ptr, std::size_t size, + infini::rt::Stream stream); + +infini::rt::Error FreeAsync(void* ptr, infini::rt::Stream stream); + +infini::rt::Error MemsetAsync(void* ptr, int value, std::size_t count, + infini::rt::Stream stream); + +infini::rt::Error MemGetInfo(std::size_t* free_bytes, std::size_t* total_bytes); + +infini::rt::Error StreamCreate(infini::rt::Stream* stream); + +infini::rt::Error StreamDestroy(infini::rt::Stream stream); + +infini::rt::Error StreamSynchronize(infini::rt::Stream stream); + +infini::rt::Error StreamWaitEvent(infini::rt::Stream stream, + infini::rt::Event event, std::uint32_t flags); -void GetDevice(Device* device); +infini::rt::Error EventCreate(infini::rt::Event* event); -void GetDeviceCount(int* count, Device::Type type); +infini::rt::Error EventCreateWithFlags(infini::rt::Event* event, + std::uint32_t flags); -void DeviceSynchronize(); +infini::rt::Error EventRecord(infini::rt::Event event, + infini::rt::Stream stream); -void Malloc(void** ptr, std::size_t size); +infini::rt::Error EventQuery(infini::rt::Event event); -void Free(void* ptr); +infini::rt::Error EventSynchronize(infini::rt::Event event); -void Memset(void* ptr, int value, std::size_t count); +infini::rt::Error EventDestroy(infini::rt::Event event); -void Memcpy(void* dst, const void* src, std::size_t count, MemcpyKind kind); +infini::rt::Error EventElapsedTime(float* ms, infini::rt::Event start, + infini::rt::Event end); } // namespace infini::rt