diff --git a/scripts/generate_public_headers.py b/scripts/generate_public_headers.py index e929682..6620b42 100644 --- a/scripts/generate_public_headers.py +++ b/scripts/generate_public_headers.py @@ -161,6 +161,7 @@ def _write_generated_header(include_root, devices): default_device_type = _DEVICE_TYPES[default_device] includes = [ "#include ", + "#include ", "#include ", f"#include {_detail_include('data_type.h')}", f"#include {_detail_include('device.h')}", @@ -206,6 +207,8 @@ def _write_generated_header(include_root, devices): using Stream = typename generated_detail::DefaultErrorRuntime::Stream; +using Event = void*; + using MemcpyKind = std::remove_cv_t< decltype(generated_detail::DefaultErrorRuntime::kMemcpyHostToHost)>; @@ -262,16 +265,32 @@ def params_decl(self): "Malloc", (_Param("void**", "ptr"), _Param("std::size_t", "size")), ), - _Function("Error", "Free", (_Param("void*", "ptr"),)), _Function( "Error", - "Memset", + "MallocHost", + (_Param("void**", "ptr"), _Param("std::size_t", "size")), + ), + _Function( + "Error", + "MallocAsync", ( - _Param("void*", "ptr"), - _Param("int", "value"), - _Param("std::size_t", "count"), + _Param("void**", "ptr"), + _Param("std::size_t", "size"), + _Param("Stream", "stream"), ), ), + _Function("Error", "Free", (_Param("void*", "ptr"),)), + _Function("Error", "FreeHost", (_Param("void*", "ptr"),)), + _Function( + "Error", + "FreeAsync", + (_Param("void*", "ptr"), _Param("Stream", "stream")), + ), + _Function( + "Error", + "MemGetInfo", + (_Param("std::size_t*", "free"), _Param("std::size_t*", "total")), + ), _Function( "Error", "Memcpy", @@ -293,6 +312,60 @@ def params_decl(self): _Param("Stream", "stream"), ), ), + _Function( + "Error", + "Memset", + ( + _Param("void*", "ptr"), + _Param("int", "value"), + _Param("std::size_t", "count"), + ), + ), + _Function( + "Error", + "MemsetAsync", + ( + _Param("void*", "ptr"), + _Param("int", "value"), + _Param("std::size_t", "count"), + _Param("Stream", "stream"), + ), + ), + _Function("Error", "StreamCreate", (_Param("Stream*", "stream"),)), + _Function("Error", "StreamDestroy", (_Param("Stream", "stream"),)), + _Function("Error", "StreamSynchronize", (_Param("Stream", "stream"),)), + _Function( + "Error", + "StreamWaitEvent", + ( + _Param("Stream", "stream"), + _Param("Event", "event"), + _Param("unsigned int", "flags"), + ), + ), + _Function("Error", "EventCreate", (_Param("Event*", "event"),)), + _Function( + "Error", + "EventCreateWithFlags", + (_Param("Event*", "event"), _Param("unsigned int", "flags")), + ), + _Function( + "Error", + "EventRecord", + (_Param("Event", "event"), _Param("Stream", "stream")), + ), + _Function("Error", "EventQuery", (_Param("Event", "event"),)), + _Function("Error", "EventSynchronize", (_Param("Event", "event"),)), + _Function("Error", "EventDestroy", (_Param("Event", "event"),)), + _Function( + "Error", + "EventElapsedTime", + ( + _Param("float*", "ms"), + _Param("Event", "start"), + _Param("Event", "end"), + ), + ), ) @@ -312,6 +385,16 @@ def _runtime_arg(param, device): return ( f"reinterpret_cast::Stream>({param.name})" ) + if param.type == "Stream*": + return ( + f"reinterpret_cast::Stream*>({param.name})" + ) + if param.type == "Event": + return f"reinterpret_cast::Event>({param.name})" + if param.type == "Event*": + return ( + f"reinterpret_cast::Event*>({param.name})" + ) return param.name @@ -351,10 +434,31 @@ def _write_runtime_dispatch_function(function, devices): """ -def _write_runtime_dispatch(source_path, devices): +def _runtime_header_for_device(source_root, device): + for _, header_name, target in _DEVICE_HEADERS[device]: + if header_name == "runtime_.h": + return source_root / target + + raise ValueError(f"device {device!r} does not have a runtime header") + + +def _devices_for_function(function, devices, source_root): + pattern = re.compile(r"\b" + re.escape(function.name) + r"\b") + + return tuple( + device + for device in devices + if pattern.search(_runtime_header_for_device(source_root, device).read_text()) + ) + + +def _write_runtime_dispatch(source_path, source_root, devices): functions = _PUBLIC_RUNTIME_FUNCTIONS dispatch_functions = "\n".join( - _write_runtime_dispatch_function(function, devices=devices) + _write_runtime_dispatch_function( + function, + devices=_devices_for_function(function, devices, source_root), + ) for function in functions ) set_device_type_cases = "\n".join( @@ -463,7 +567,7 @@ def main(): _write_wrapper(include_root, wrapper_device, header_name, target) _write_generated_header(include_root, devices) - _write_runtime_dispatch(pathlib.Path(args.source_output), devices) + _write_runtime_dispatch(pathlib.Path(args.source_output), source_root, devices) if __name__ == "__main__": diff --git a/src/native/cpu/runtime_.h b/src/native/cpu/runtime_.h index f4b18bf..c928e1f 100644 --- a/src/native/cpu/runtime_.h +++ b/src/native/cpu/runtime_.h @@ -1,8 +1,11 @@ #ifndef INFINI_RT_CPU_RUNTIME__H_ #define INFINI_RT_CPU_RUNTIME__H_ +#include +#include #include #include +#include #include "runtime.h" @@ -16,6 +19,8 @@ struct Runtime : RuntimeBase> { using Stream = void*; + using Event = void*; + static constexpr Error kSuccess = 0; static constexpr Error kErrorInvalidValue = 1; @@ -66,12 +71,53 @@ struct Runtime : RuntimeBase> { return kSuccess; } + static Error MallocHost(void** ptr, std::size_t size) { + return Malloc(ptr, size); + } + + static Error MallocAsync(void** ptr, std::size_t size, Stream) { + return kErrorInvalidValue; + } + static Error Free(void* ptr) { std::free(ptr); return kSuccess; } + static Error FreeHost(void* ptr) { return Free(ptr); } + + static Error FreeAsync(void* ptr, Stream) { return kErrorInvalidValue; } + + static Error MemGetInfo(std::size_t* free, std::size_t* total) { + if (free == nullptr || total == nullptr) { + return kErrorInvalidValue; + } + + *free = 0; + *total = 0; + +#ifndef _WIN32 + FILE* fp = std::fopen("/proc/meminfo", "r"); + if (fp == nullptr) { + return kErrorInvalidValue; + } + + char label[64]; + std::size_t value = 0; + while (std::fscanf(fp, "%63s %zu %*s", label, &value) == 2) { + if (std::strcmp(label, "MemTotal:") == 0) { + *total = value * 1024; + } else if (std::strcmp(label, "MemAvailable:") == 0) { + *free = value * 1024; + } + } + std::fclose(fp); +#endif + + return *total == 0 ? kErrorInvalidValue : kSuccess; + } + static Error Memcpy(void* dst, const void* src, std::size_t size, int) { if ((dst == nullptr || src == nullptr) && size != 0) { return kErrorInvalidValue; @@ -82,6 +128,11 @@ struct Runtime : RuntimeBase> { return kSuccess; } + static Error MemcpyAsync(void* dst, const void* src, std::size_t size, + int kind, Stream) { + return kErrorInvalidValue; + } + static Error Memset(void* ptr, int value, std::size_t count) { if (ptr == nullptr && count != 0) { return kErrorInvalidValue; @@ -92,11 +143,80 @@ struct Runtime : RuntimeBase> { return kSuccess; } - static Error MemcpyAsync(void* dst, const void* src, std::size_t size, - int kind, Stream) { + static Error MemsetAsync(void* ptr, int value, std::size_t count, Stream) { return kErrorInvalidValue; } + static Error StreamCreate(Stream* stream) { + if (stream == nullptr) { + return kErrorInvalidValue; + } + + *stream = nullptr; + + return kSuccess; + } + + static Error StreamDestroy(Stream) { return kSuccess; } + + static Error StreamSynchronize(Stream) { return kSuccess; } + + static Error StreamWaitEvent(Stream, Event, unsigned int) { return kSuccess; } + + using CpuEvent = std::chrono::steady_clock::time_point; + + static Error EventCreate(Event* event) { + if (event == nullptr) { + return kErrorInvalidValue; + } + + *event = new (std::nothrow) CpuEvent(std::chrono::steady_clock::now()); + + return *event == nullptr ? kErrorMemoryAllocation : kSuccess; + } + + static Error EventCreateWithFlags(Event* event, unsigned int) { + return EventCreate(event); + } + + static Error EventRecord(Event event, Stream) { + if (event == nullptr) { + return kErrorInvalidValue; + } + + *static_cast(event) = std::chrono::steady_clock::now(); + + return kSuccess; + } + + static Error EventQuery(Event event) { + return event == nullptr ? kErrorInvalidValue : kSuccess; + } + + static Error EventSynchronize(Event event) { + return event == nullptr ? kErrorInvalidValue : kSuccess; + } + + static Error EventDestroy(Event event) { + delete static_cast(event); + + return kSuccess; + } + + static Error EventElapsedTime(float* ms, Event start, Event end) { + if (ms == nullptr || start == nullptr || end == nullptr) { + return kErrorInvalidValue; + } + + const auto* start_time = static_cast(start); + const auto* end_time = static_cast(end); + const auto duration = std::chrono::duration_cast( + *end_time - *start_time); + *ms = static_cast(duration.count()) / 1000.0f; + + return kSuccess; + } + static constexpr int kMemcpyHostToHost = 0; static constexpr int kMemcpyHostToDevice = 1;