diff --git a/src/native/ascend/runtime_.h b/src/native/ascend/runtime_.h index 065a9d3..0642027 100644 --- a/src/native/ascend/runtime_.h +++ b/src/native/ascend/runtime_.h @@ -2,6 +2,7 @@ #define INFINI_RT_ASCEND_RUNTIME__H_ #include +#include #include // clang-format off @@ -20,6 +21,8 @@ struct Runtime using Stream = aclrtStream; + using Event = void*; + static constexpr Device::Type kDeviceType = Device::Type::kAscend; static constexpr Error kSuccess = ACL_SUCCESS; @@ -42,8 +45,20 @@ struct Runtime return aclrtMalloc(ptr, size, ACL_MEM_MALLOC_HUGE_FIRST); }; + static Error MallocHost(void**, std::size_t) { return Unsupported(); } + + static Error MallocAsync(void**, std::size_t, Stream) { + return Unsupported(); + } + static constexpr auto Free = aclrtFree; + static Error FreeHost(void*) { return Unsupported(); } + + static Error FreeAsync(void*, Stream) { return Unsupported(); } + + static Error MemGetInfo(std::size_t*, std::size_t*) { return Unsupported(); } + static constexpr auto Memcpy = [](void* dst, const void* src, size_t count, aclrtMemcpyKind kind) { return aclrtMemcpy(dst, count, src, count, kind); @@ -66,6 +81,39 @@ struct Runtime static constexpr auto Memset = [](void* ptr, int value, size_t count) { return aclrtMemset(ptr, count, value, count); }; + + static Error MemsetAsync(void*, int, std::size_t, Stream) { + return Unsupported(); + } + + static constexpr auto StreamCreate = aclrtCreateStream; + + static constexpr auto StreamDestroy = aclrtDestroyStream; + + static constexpr auto StreamSynchronize = aclrtSynchronizeStream; + + static Error StreamWaitEvent(Stream, Event, unsigned int) { + return Unsupported(); + } + + static Error EventCreate(Event*) { return Unsupported(); } + + static Error EventCreateWithFlags(Event*, unsigned int) { + return Unsupported(); + } + + static Error EventRecord(Event, Stream) { return Unsupported(); } + + static Error EventQuery(Event) { return Unsupported(); } + + static Error EventSynchronize(Event) { return Unsupported(); } + + static Error EventDestroy(Event) { return Unsupported(); } + + static Error EventElapsedTime(float*, Event, Event) { return Unsupported(); } + + private: + static Error Unsupported() { return static_cast(1); } }; static_assert(Runtime::Validate()); diff --git a/src/native/cambricon/runtime_.h b/src/native/cambricon/runtime_.h index d655cd9..c4f4112 100644 --- a/src/native/cambricon/runtime_.h +++ b/src/native/cambricon/runtime_.h @@ -19,6 +19,8 @@ struct Runtime using Stream = cnrtQueue_t; + using Event = void*; + static constexpr Device::Type kDeviceType = Device::Type::kCambricon; #ifdef CNRT_RET_SUCCESS @@ -43,8 +45,20 @@ struct Runtime static constexpr auto Malloc = cnrtMalloc; + static Error MallocHost(void**, std::size_t) { return Unsupported(); } + + static Error MallocAsync(void**, std::size_t, Stream) { + return Unsupported(); + } + static constexpr auto Free = cnrtFree; + static Error FreeHost(void*) { return Unsupported(); } + + static Error FreeAsync(void*, Stream) { return Unsupported(); } + + static Error MemGetInfo(std::size_t*, std::size_t*) { return Unsupported(); } + static constexpr auto Memcpy = [](void* dst, const void* src, std::size_t size, auto kind) { return cnrtMemcpy(dst, const_cast(src), size, kind); @@ -70,6 +84,39 @@ struct Runtime static constexpr auto kMemcpyDeviceToDevice = cnrtMemcpyDevToDev; static constexpr auto Memset = cnrtMemset; + + static Error MemsetAsync(void*, int, std::size_t, Stream) { + return Unsupported(); + } + + static constexpr auto StreamCreate = cnrtQueueCreate; + + static constexpr auto StreamDestroy = cnrtQueueDestroy; + + static constexpr auto StreamSynchronize = cnrtQueueSync; + + static Error StreamWaitEvent(Stream, Event, unsigned int) { + return Unsupported(); + } + + static Error EventCreate(Event*) { return Unsupported(); } + + static Error EventCreateWithFlags(Event*, unsigned int) { + return Unsupported(); + } + + static Error EventRecord(Event, Stream) { return Unsupported(); } + + static Error EventQuery(Event) { return Unsupported(); } + + static Error EventSynchronize(Event) { return Unsupported(); } + + static Error EventDestroy(Event) { return Unsupported(); } + + static Error EventElapsedTime(float*, Event, Event) { return Unsupported(); } + + private: + static Error Unsupported() { return static_cast(1); } }; static_assert(Runtime::Validate()); diff --git a/src/native/cuda/hygon/runtime_.h b/src/native/cuda/hygon/runtime_.h index 52c47eb..5051a87 100644 --- a/src/native/cuda/hygon/runtime_.h +++ b/src/native/cuda/hygon/runtime_.h @@ -19,6 +19,8 @@ struct Runtime using Stream = cudaStream_t; + using Event = cudaEvent_t; + static constexpr Device::Type kDeviceType = Device::Type::kHygon; static constexpr Error kSuccess = cudaSuccess; @@ -35,14 +37,34 @@ struct Runtime return cudaMalloc(std::forward(args)...); }; - static constexpr auto Memcpy = cudaMemcpy; + static constexpr auto MallocHost = [](auto&&... args) { + return cudaMallocHost(std::forward(args)...); + }; - static constexpr auto MemcpyAsync = cudaMemcpyAsync; + static constexpr auto MallocAsync = [](auto&&... args) { + return cudaMallocAsync(std::forward(args)...); + }; static constexpr auto Free = [](auto&&... args) { return cudaFree(std::forward(args)...); }; + static constexpr auto FreeHost = [](auto&&... args) { + return cudaFreeHost(std::forward(args)...); + }; + + static constexpr auto FreeAsync = [](auto&&... args) { + return cudaFreeAsync(std::forward(args)...); + }; + + static constexpr auto MemGetInfo = [](auto&&... args) { + return cudaMemGetInfo(std::forward(args)...); + }; + + static constexpr auto Memcpy = cudaMemcpy; + + static constexpr auto MemcpyAsync = cudaMemcpyAsync; + static constexpr auto kMemcpyHostToHost = cudaMemcpyHostToHost; static constexpr auto kMemcpyHostToDevice = cudaMemcpyHostToDevice; @@ -52,6 +74,54 @@ struct Runtime static constexpr auto kMemcpyDeviceToDevice = cudaMemcpyDeviceToDevice; static constexpr auto Memset = cudaMemset; + + static constexpr auto MemsetAsync = [](auto&&... args) { + return cudaMemsetAsync(std::forward(args)...); + }; + + static constexpr auto StreamCreate = [](auto&&... args) { + return cudaStreamCreate(std::forward(args)...); + }; + + static constexpr auto StreamDestroy = [](auto&&... args) { + return cudaStreamDestroy(std::forward(args)...); + }; + + static constexpr auto StreamSynchronize = [](auto&&... args) { + return cudaStreamSynchronize(std::forward(args)...); + }; + + static constexpr auto StreamWaitEvent = [](auto&&... args) { + return cudaStreamWaitEvent(std::forward(args)...); + }; + + static constexpr auto EventCreate = [](auto&&... args) { + return cudaEventCreate(std::forward(args)...); + }; + + static constexpr auto EventCreateWithFlags = [](auto&&... args) { + return cudaEventCreateWithFlags(std::forward(args)...); + }; + + static constexpr auto EventRecord = [](auto&&... args) { + return cudaEventRecord(std::forward(args)...); + }; + + static constexpr auto EventQuery = [](auto&&... args) { + return cudaEventQuery(std::forward(args)...); + }; + + static constexpr auto EventSynchronize = [](auto&&... args) { + return cudaEventSynchronize(std::forward(args)...); + }; + + static constexpr auto EventDestroy = [](auto&&... args) { + return cudaEventDestroy(std::forward(args)...); + }; + + static constexpr auto EventElapsedTime = [](auto&&... args) { + return cudaEventElapsedTime(std::forward(args)...); + }; }; static_assert(Runtime::Validate()); diff --git a/src/native/cuda/iluvatar/runtime_.h b/src/native/cuda/iluvatar/runtime_.h index f49db23..9425559 100644 --- a/src/native/cuda/iluvatar/runtime_.h +++ b/src/native/cuda/iluvatar/runtime_.h @@ -19,6 +19,8 @@ struct Runtime using Stream = cudaStream_t; + using Event = cudaEvent_t; + static constexpr Device::Type kDeviceType = Device::Type::kIluvatar; static constexpr Error kSuccess = cudaSuccess; @@ -35,12 +37,32 @@ struct Runtime return cudaMalloc(std::forward(args)...); }; - static constexpr auto Memcpy = cudaMemcpy; + static constexpr auto MallocHost = [](auto&&... args) { + return cudaMallocHost(std::forward(args)...); + }; - static constexpr auto MemcpyAsync = cudaMemcpyAsync; + static constexpr auto MallocAsync = [](auto&&... args) { + return cudaMallocAsync(std::forward(args)...); + }; static constexpr auto Free = cudaFree; + static constexpr auto FreeHost = [](auto&&... args) { + return cudaFreeHost(std::forward(args)...); + }; + + static constexpr auto FreeAsync = [](auto&&... args) { + return cudaFreeAsync(std::forward(args)...); + }; + + static constexpr auto MemGetInfo = [](auto&&... args) { + return cudaMemGetInfo(std::forward(args)...); + }; + + static constexpr auto Memcpy = cudaMemcpy; + + static constexpr auto MemcpyAsync = cudaMemcpyAsync; + static constexpr auto kMemcpyHostToHost = cudaMemcpyHostToHost; static constexpr auto kMemcpyHostToDevice = cudaMemcpyHostToDevice; @@ -50,6 +72,54 @@ struct Runtime static constexpr auto kMemcpyDeviceToDevice = cudaMemcpyDeviceToDevice; static constexpr auto Memset = cudaMemset; + + static constexpr auto MemsetAsync = [](auto&&... args) { + return cudaMemsetAsync(std::forward(args)...); + }; + + static constexpr auto StreamCreate = [](auto&&... args) { + return cudaStreamCreate(std::forward(args)...); + }; + + static constexpr auto StreamDestroy = [](auto&&... args) { + return cudaStreamDestroy(std::forward(args)...); + }; + + static constexpr auto StreamSynchronize = [](auto&&... args) { + return cudaStreamSynchronize(std::forward(args)...); + }; + + static constexpr auto StreamWaitEvent = [](auto&&... args) { + return cudaStreamWaitEvent(std::forward(args)...); + }; + + static constexpr auto EventCreate = [](auto&&... args) { + return cudaEventCreate(std::forward(args)...); + }; + + static constexpr auto EventCreateWithFlags = [](auto&&... args) { + return cudaEventCreateWithFlags(std::forward(args)...); + }; + + static constexpr auto EventRecord = [](auto&&... args) { + return cudaEventRecord(std::forward(args)...); + }; + + static constexpr auto EventQuery = [](auto&&... args) { + return cudaEventQuery(std::forward(args)...); + }; + + static constexpr auto EventSynchronize = [](auto&&... args) { + return cudaEventSynchronize(std::forward(args)...); + }; + + static constexpr auto EventDestroy = [](auto&&... args) { + return cudaEventDestroy(std::forward(args)...); + }; + + static constexpr auto EventElapsedTime = [](auto&&... args) { + return cudaEventElapsedTime(std::forward(args)...); + }; }; static_assert(Runtime::Validate()); diff --git a/src/native/cuda/metax/runtime_.h b/src/native/cuda/metax/runtime_.h index c1f19c0..462f5f1 100644 --- a/src/native/cuda/metax/runtime_.h +++ b/src/native/cuda/metax/runtime_.h @@ -17,6 +17,8 @@ struct Runtime using Stream = mcStream_t; + using Event = mcEvent_t; + static constexpr Device::Type kDeviceType = Device::Type::kMetax; static constexpr Error kSuccess = mcSuccess; @@ -33,18 +35,38 @@ struct Runtime return mcMalloc(std::forward(args)...); }; - static constexpr auto Memcpy = [](auto&&... args) { - return mcMemcpy(std::forward(args)...); + static constexpr auto MallocHost = [](auto&&... args) { + return mcMallocHost(std::forward(args)...); }; - static constexpr auto MemcpyAsync = [](auto&&... args) { - return mcMemcpyAsync(std::forward(args)...); + static constexpr auto MallocAsync = [](auto&&... args) { + return mcMallocAsync(std::forward(args)...); }; static constexpr auto Free = [](auto&&... args) { return mcFree(std::forward(args)...); }; + static constexpr auto FreeHost = [](auto&&... args) { + return mcFreeHost(std::forward(args)...); + }; + + static constexpr auto FreeAsync = [](auto&&... args) { + return mcFreeAsync(std::forward(args)...); + }; + + static constexpr auto MemGetInfo = [](auto&&... args) { + return mcMemGetInfo(std::forward(args)...); + }; + + static constexpr auto Memcpy = [](auto&&... args) { + return mcMemcpy(std::forward(args)...); + }; + + static constexpr auto MemcpyAsync = [](auto&&... args) { + return mcMemcpyAsync(std::forward(args)...); + }; + static constexpr auto kMemcpyHostToHost = mcMemcpyHostToHost; static constexpr auto kMemcpyHostToDevice = mcMemcpyHostToDevice; @@ -54,6 +76,54 @@ struct Runtime static constexpr auto kMemcpyDeviceToDevice = mcMemcpyDeviceToDevice; static constexpr auto Memset = mcMemset; + + static constexpr auto MemsetAsync = [](auto&&... args) { + return mcMemsetAsync(std::forward(args)...); + }; + + static constexpr auto StreamCreate = [](auto&&... args) { + return mcStreamCreate(std::forward(args)...); + }; + + static constexpr auto StreamDestroy = [](auto&&... args) { + return mcStreamDestroy(std::forward(args)...); + }; + + static constexpr auto StreamSynchronize = [](auto&&... args) { + return mcStreamSynchronize(std::forward(args)...); + }; + + static constexpr auto StreamWaitEvent = [](auto&&... args) { + return mcStreamWaitEvent(std::forward(args)...); + }; + + static constexpr auto EventCreate = [](auto&&... args) { + return mcEventCreate(std::forward(args)...); + }; + + static constexpr auto EventCreateWithFlags = [](auto&&... args) { + return mcEventCreateWithFlags(std::forward(args)...); + }; + + static constexpr auto EventRecord = [](auto&&... args) { + return mcEventRecord(std::forward(args)...); + }; + + static constexpr auto EventQuery = [](auto&&... args) { + return mcEventQuery(std::forward(args)...); + }; + + static constexpr auto EventSynchronize = [](auto&&... args) { + return mcEventSynchronize(std::forward(args)...); + }; + + static constexpr auto EventDestroy = [](auto&&... args) { + return mcEventDestroy(std::forward(args)...); + }; + + static constexpr auto EventElapsedTime = [](auto&&... args) { + return mcEventElapsedTime(std::forward(args)...); + }; }; static_assert(Runtime::Validate()); diff --git a/src/native/cuda/moore/runtime_.h b/src/native/cuda/moore/runtime_.h index 5beffcf..81ccd15 100644 --- a/src/native/cuda/moore/runtime_.h +++ b/src/native/cuda/moore/runtime_.h @@ -3,6 +3,7 @@ #include +#include #include #include "native/cuda/moore/device_.h" @@ -17,6 +18,8 @@ struct Runtime using Stream = musaStream_t; + using Event = musaEvent_t; + static constexpr Device::Type kDeviceType = Device::Type::kMoore; static constexpr Error kSuccess = musaSuccess; @@ -39,18 +42,38 @@ struct Runtime return musaMalloc(std::forward(args)...); }; - static constexpr auto Memcpy = [](auto&&... args) { - return musaMemcpy(std::forward(args)...); + static constexpr auto MallocHost = [](auto&&... args) { + return musaMallocHost(std::forward(args)...); }; - static constexpr auto MemcpyAsync = [](auto&&... args) { - return musaMemcpyAsync(std::forward(args)...); + static constexpr auto MallocAsync = [](void**, std::size_t, Stream) { + return static_cast(1); }; static constexpr auto Free = [](auto&&... args) { return musaFree(std::forward(args)...); }; + static constexpr auto FreeHost = [](auto&&... args) { + return musaFreeHost(std::forward(args)...); + }; + + static constexpr auto FreeAsync = [](void*, Stream) { + return static_cast(1); + }; + + static constexpr auto MemGetInfo = [](auto&&... args) { + return musaMemGetInfo(std::forward(args)...); + }; + + static constexpr auto Memcpy = [](auto&&... args) { + return musaMemcpy(std::forward(args)...); + }; + + static constexpr auto MemcpyAsync = [](auto&&... args) { + return musaMemcpyAsync(std::forward(args)...); + }; + static constexpr auto kMemcpyHostToHost = musaMemcpyHostToHost; static constexpr auto kMemcpyHostToDevice = musaMemcpyHostToDevice; @@ -60,6 +83,54 @@ struct Runtime static constexpr auto kMemcpyDeviceToDevice = musaMemcpyDeviceToDevice; static constexpr auto Memset = musaMemset; + + static constexpr auto MemsetAsync = [](auto&&... args) { + return musaMemsetAsync(std::forward(args)...); + }; + + static constexpr auto StreamCreate = [](auto&&... args) { + return musaStreamCreate(std::forward(args)...); + }; + + static constexpr auto StreamDestroy = [](auto&&... args) { + return musaStreamDestroy(std::forward(args)...); + }; + + static constexpr auto StreamSynchronize = [](auto&&... args) { + return musaStreamSynchronize(std::forward(args)...); + }; + + static constexpr auto StreamWaitEvent = [](auto&&... args) { + return musaStreamWaitEvent(std::forward(args)...); + }; + + static constexpr auto EventCreate = [](auto&&... args) { + return musaEventCreate(std::forward(args)...); + }; + + static constexpr auto EventCreateWithFlags = [](auto&&... args) { + return musaEventCreateWithFlags(std::forward(args)...); + }; + + static constexpr auto EventRecord = [](auto&&... args) { + return musaEventRecord(std::forward(args)...); + }; + + static constexpr auto EventQuery = [](auto&&... args) { + return musaEventQuery(std::forward(args)...); + }; + + static constexpr auto EventSynchronize = [](auto&&... args) { + return musaEventSynchronize(std::forward(args)...); + }; + + static constexpr auto EventDestroy = [](auto&&... args) { + return musaEventDestroy(std::forward(args)...); + }; + + static constexpr auto EventElapsedTime = [](auto&&... args) { + return musaEventElapsedTime(std::forward(args)...); + }; }; static_assert(Runtime::Validate()); diff --git a/src/native/cuda/nvidia/runtime_.h b/src/native/cuda/nvidia/runtime_.h index 1786e08..840f0bd 100644 --- a/src/native/cuda/nvidia/runtime_.h +++ b/src/native/cuda/nvidia/runtime_.h @@ -19,6 +19,8 @@ struct Runtime using Stream = cudaStream_t; + using Event = cudaEvent_t; + static constexpr Device::Type kDeviceType = Device::Type::kNvidia; static constexpr Error kSuccess = cudaSuccess; @@ -35,12 +37,32 @@ struct Runtime return cudaMalloc(std::forward(args)...); }; - static constexpr auto Memcpy = cudaMemcpy; + static constexpr auto MallocHost = [](auto&&... args) { + return cudaMallocHost(std::forward(args)...); + }; - static constexpr auto MemcpyAsync = cudaMemcpyAsync; + static constexpr auto MallocAsync = [](auto&&... args) { + return cudaMallocAsync(std::forward(args)...); + }; static constexpr auto Free = cudaFree; + static constexpr auto FreeHost = [](auto&&... args) { + return cudaFreeHost(std::forward(args)...); + }; + + static constexpr auto FreeAsync = [](auto&&... args) { + return cudaFreeAsync(std::forward(args)...); + }; + + static constexpr auto MemGetInfo = [](auto&&... args) { + return cudaMemGetInfo(std::forward(args)...); + }; + + static constexpr auto Memcpy = cudaMemcpy; + + static constexpr auto MemcpyAsync = cudaMemcpyAsync; + static constexpr auto kMemcpyHostToHost = cudaMemcpyHostToHost; static constexpr auto kMemcpyHostToDevice = cudaMemcpyHostToDevice; @@ -50,6 +72,54 @@ struct Runtime static constexpr auto kMemcpyDeviceToDevice = cudaMemcpyDeviceToDevice; static constexpr auto Memset = cudaMemset; + + static constexpr auto MemsetAsync = [](auto&&... args) { + return cudaMemsetAsync(std::forward(args)...); + }; + + static constexpr auto StreamCreate = [](auto&&... args) { + return cudaStreamCreate(std::forward(args)...); + }; + + static constexpr auto StreamDestroy = [](auto&&... args) { + return cudaStreamDestroy(std::forward(args)...); + }; + + static constexpr auto StreamSynchronize = [](auto&&... args) { + return cudaStreamSynchronize(std::forward(args)...); + }; + + static constexpr auto StreamWaitEvent = [](auto&&... args) { + return cudaStreamWaitEvent(std::forward(args)...); + }; + + static constexpr auto EventCreate = [](auto&&... args) { + return cudaEventCreate(std::forward(args)...); + }; + + static constexpr auto EventCreateWithFlags = [](auto&&... args) { + return cudaEventCreateWithFlags(std::forward(args)...); + }; + + static constexpr auto EventRecord = [](auto&&... args) { + return cudaEventRecord(std::forward(args)...); + }; + + static constexpr auto EventQuery = [](auto&&... args) { + return cudaEventQuery(std::forward(args)...); + }; + + static constexpr auto EventSynchronize = [](auto&&... args) { + return cudaEventSynchronize(std::forward(args)...); + }; + + static constexpr auto EventDestroy = [](auto&&... args) { + return cudaEventDestroy(std::forward(args)...); + }; + + static constexpr auto EventElapsedTime = [](auto&&... args) { + return cudaEventElapsedTime(std::forward(args)...); + }; }; static_assert(Runtime::Validate()); diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 88b3fc0..2fcf5b4 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -5,7 +5,9 @@ function(add_infini_rt_test target) endfunction() function(add_infini_rt_backend_runtime_test backend device_type runtime_header - expect_async_memcpy_success) + expect_async_memcpy_success supports_host_memory supports_async_memory + supports_mem_get_info supports_memset_async supports_stream_wait_event + supports_event supports_event_elapsed_time) string(TOLOWER "${backend}" backend_lower) set(target "test_${backend_lower}_runtime") add_infini_rt_test(${target} test_native_runtime.cc) @@ -14,7 +16,14 @@ function(add_infini_rt_backend_runtime_test backend device_type runtime_header "INFINI_RT_TEST_BACKEND_NAME=\"${backend}\"" "INFINI_RT_TEST_DEVICE_TYPE=${device_type}" "INFINI_RT_TEST_RUNTIME_HEADER=\"${runtime_header}\"" - "INFINI_RT_TEST_EXPECT_ASYNC_MEMCPY_SUCCESS=${expect_async_memcpy_success}") + "INFINI_RT_TEST_EXPECT_ASYNC_MEMCPY_SUCCESS=${expect_async_memcpy_success}" + "INFINI_RT_TEST_SUPPORTS_HOST_MEMORY=${supports_host_memory}" + "INFINI_RT_TEST_SUPPORTS_ASYNC_MEMORY=${supports_async_memory}" + "INFINI_RT_TEST_SUPPORTS_MEM_GET_INFO=${supports_mem_get_info}" + "INFINI_RT_TEST_SUPPORTS_MEMSET_ASYNC=${supports_memset_async}" + "INFINI_RT_TEST_SUPPORTS_STREAM_WAIT_EVENT=${supports_stream_wait_event}" + "INFINI_RT_TEST_SUPPORTS_EVENT=${supports_event}" + "INFINI_RT_TEST_SUPPORTS_EVENT_ELAPSED_TIME=${supports_event_elapsed_time}") endfunction() add_infini_rt_test(test_smoke test_smoke.cc) @@ -25,56 +34,64 @@ set(INFINI_RT_TEST_HAS_RUNTIME_BACKEND OFF) if(WITH_CPU) set(INFINI_RT_TEST_HAS_RUNTIME_BACKEND ON) add_infini_rt_backend_runtime_test( - CPU infini::rt::Device::Type::kCpu infini/rt/cpu/runtime_.h 0) + CPU infini::rt::Device::Type::kCpu infini/rt/cpu/runtime_.h + 0 1 0 1 0 1 1 1) endif() if(WITH_NVIDIA) set(INFINI_RT_TEST_HAS_RUNTIME_BACKEND ON) add_infini_rt_backend_runtime_test( NVIDIA infini::rt::Device::Type::kNvidia - infini/rt/nvidia/runtime_.h 1) + infini/rt/nvidia/runtime_.h + 1 1 1 1 1 1 1 1) endif() if(WITH_ILUVATAR) set(INFINI_RT_TEST_HAS_RUNTIME_BACKEND ON) add_infini_rt_backend_runtime_test( ILUVATAR infini::rt::Device::Type::kIluvatar - infini/rt/iluvatar/runtime_.h 1) + infini/rt/iluvatar/runtime_.h + 1 1 1 1 1 1 1 1) endif() if(WITH_HYGON) set(INFINI_RT_TEST_HAS_RUNTIME_BACKEND ON) add_infini_rt_backend_runtime_test( HYGON infini::rt::Device::Type::kHygon - infini/rt/hygon/runtime_.h 1) + infini/rt/hygon/runtime_.h + 1 1 1 1 1 1 1 1) endif() if(WITH_METAX) set(INFINI_RT_TEST_HAS_RUNTIME_BACKEND ON) add_infini_rt_backend_runtime_test( METAX infini::rt::Device::Type::kMetax - infini/rt/metax/runtime_.h 1) + infini/rt/metax/runtime_.h + 1 1 1 1 1 1 1 1) endif() if(WITH_MOORE) set(INFINI_RT_TEST_HAS_RUNTIME_BACKEND ON) add_infini_rt_backend_runtime_test( MOORE infini::rt::Device::Type::kMoore - infini/rt/moore/runtime_.h 1) + infini/rt/moore/runtime_.h + 1 1 0 1 1 1 1 1) endif() if(WITH_CAMBRICON) set(INFINI_RT_TEST_HAS_RUNTIME_BACKEND ON) add_infini_rt_backend_runtime_test( CAMBRICON infini::rt::Device::Type::kCambricon - infini/rt/cambricon/runtime_.h 1) + infini/rt/cambricon/runtime_.h + 1 0 0 0 0 0 0 0) endif() if(WITH_ASCEND) set(INFINI_RT_TEST_HAS_RUNTIME_BACKEND ON) add_infini_rt_backend_runtime_test( ASCEND infini::rt::Device::Type::kAscend - infini/rt/ascend/runtime_.h 1) + infini/rt/ascend/runtime_.h + 1 0 0 0 0 0 0 0) endif() if(INFINI_RT_TEST_HAS_RUNTIME_BACKEND) diff --git a/tests/test_native_runtime.cc b/tests/test_native_runtime.cc index cdd39cb..3ae7ea9 100644 --- a/tests/test_native_runtime.cc +++ b/tests/test_native_runtime.cc @@ -14,6 +14,15 @@ using Runtime = infini::rt::runtime::Runtime; constexpr bool kExpectAsyncMemcpySuccess = INFINI_RT_TEST_EXPECT_ASYNC_MEMCPY_SUCCESS != 0; +constexpr bool kSupportsHostMemory = INFINI_RT_TEST_SUPPORTS_HOST_MEMORY != 0; +constexpr bool kSupportsAsyncMemory = INFINI_RT_TEST_SUPPORTS_ASYNC_MEMORY != 0; +constexpr bool kSupportsMemGetInfo = INFINI_RT_TEST_SUPPORTS_MEM_GET_INFO != 0; +constexpr bool kSupportsMemsetAsync = INFINI_RT_TEST_SUPPORTS_MEMSET_ASYNC != 0; +constexpr bool kSupportsStreamWaitEvent = + INFINI_RT_TEST_SUPPORTS_STREAM_WAIT_EVENT != 0; +constexpr bool kSupportsEvent = INFINI_RT_TEST_SUPPORTS_EVENT != 0; +constexpr bool kSupportsEventElapsedTime = + INFINI_RT_TEST_SUPPORTS_EVENT_ELAPSED_TIME != 0; void ExpectSuccess(infini::rt::test::TestContext* context, typename Runtime::Error status, const char* message) { @@ -25,6 +34,16 @@ void ExpectFailure(infini::rt::test::TestContext* context, context->Expect(status != Runtime::kSuccess, message); } +void ExpectStatus(infini::rt::test::TestContext* context, + typename Runtime::Error status, bool expect_success, + const char* message) { + if (expect_success) { + ExpectSuccess(context, status, message); + } else { + ExpectFailure(context, status, message); + } +} + bool SelectDevice() { int device_count = 0; if (Runtime::GetDeviceCount(&device_count) != Runtime::kSuccess || @@ -43,6 +62,21 @@ bool SelectDevice() { return true; } +bool CreateStream(infini::rt::test::TestContext* context, + typename Runtime::Stream* stream) { + const auto status = Runtime::StreamCreate(stream); + ExpectSuccess(context, status, + INFINI_RT_TEST_BACKEND_NAME " runtime should create a stream."); + return status == Runtime::kSuccess; +} + +void DestroyStream(infini::rt::test::TestContext* context, + typename Runtime::Stream stream) { + ExpectSuccess(context, Runtime::StreamDestroy(stream), + INFINI_RT_TEST_BACKEND_NAME + " runtime should destroy a stream."); +} + void TestDevice(infini::rt::test::TestContext* context) { int current_device = -1; ExpectSuccess(context, Runtime::GetDevice(¤t_device), @@ -72,6 +106,72 @@ void TestMallocAndFree(infini::rt::test::TestContext* context) { } } +void TestHostMemory(infini::rt::test::TestContext* context) { + void* ptr = nullptr; + const auto malloc_status = Runtime::MallocHost(&ptr, 16); + ExpectStatus(context, malloc_status, kSupportsHostMemory, + INFINI_RT_TEST_BACKEND_NAME + " runtime should report expected host memory support."); + + if constexpr (kSupportsHostMemory) { + context->Expect(ptr != nullptr, INFINI_RT_TEST_BACKEND_NAME + " runtime host allocation should produce a pointer."); + if (ptr != nullptr) { + ExpectSuccess(context, Runtime::FreeHost(ptr), + INFINI_RT_TEST_BACKEND_NAME + " runtime should free host memory."); + } + } +} + +void TestAsyncMemory(infini::rt::test::TestContext* context) { + typename Runtime::Stream stream{}; + if (!CreateStream(context, &stream)) { + return; + } + + void* ptr = nullptr; + const auto malloc_status = Runtime::MallocAsync(&ptr, 16, stream); + ExpectStatus(context, malloc_status, kSupportsAsyncMemory, + INFINI_RT_TEST_BACKEND_NAME + " runtime should report expected async allocation support."); + + if constexpr (kSupportsAsyncMemory) { + context->Expect(ptr != nullptr, INFINI_RT_TEST_BACKEND_NAME + " runtime async allocation should produce a pointer."); + if (ptr != nullptr) { + ExpectSuccess(context, Runtime::FreeAsync(ptr, stream), + INFINI_RT_TEST_BACKEND_NAME + " runtime should free async memory."); + ExpectSuccess(context, Runtime::StreamSynchronize(stream), + INFINI_RT_TEST_BACKEND_NAME + " runtime should synchronize async memory operations."); + } + } else if (ptr != nullptr) { + ExpectSuccess(context, Runtime::Free(ptr), + INFINI_RT_TEST_BACKEND_NAME + " runtime should free unexpected async allocation."); + } + + DestroyStream(context, stream); +} + +void TestMemGetInfo(infini::rt::test::TestContext* context) { + std::size_t free = 0; + std::size_t total = 0; + const auto status = Runtime::MemGetInfo(&free, &total); + ExpectStatus(context, status, kSupportsMemGetInfo, + INFINI_RT_TEST_BACKEND_NAME + " runtime should report expected memory info support."); + + if constexpr (kSupportsMemGetInfo) { + context->Expect(total > 0, INFINI_RT_TEST_BACKEND_NAME + " runtime should report total memory."); + context->Expect(free <= total, INFINI_RT_TEST_BACKEND_NAME + " runtime free memory should not exceed total memory."); + } +} + void TestMemcpyRoundTrip(infini::rt::test::TestContext* context) { std::array input{0, 1, 2, 3, 4, 5, 6, 7}; std::array output{}; @@ -174,6 +274,190 @@ void TestMemset(infini::rt::test::TestContext* context) { } } +void TestMemsetAsync(infini::rt::test::TestContext* context) { + std::array output{}; + void* ptr = nullptr; + + ExpectSuccess(context, Runtime::Malloc(&ptr, output.size()), + INFINI_RT_TEST_BACKEND_NAME + " runtime should allocate async memset memory."); + if (ptr == nullptr) { + return; + } + + typename Runtime::Stream stream{}; + if (!CreateStream(context, &stream)) { + ExpectSuccess(context, Runtime::Free(ptr), + INFINI_RT_TEST_BACKEND_NAME + " runtime should free async memset memory."); + return; + } + + const auto memset_status = + Runtime::MemsetAsync(ptr, 0xA5, output.size(), stream); + ExpectStatus(context, memset_status, kSupportsMemsetAsync, + INFINI_RT_TEST_BACKEND_NAME + " runtime should report expected async memset support."); + + if constexpr (kSupportsMemsetAsync) { + ExpectSuccess(context, Runtime::StreamSynchronize(stream), + INFINI_RT_TEST_BACKEND_NAME + " runtime should synchronize async memset."); + ExpectSuccess(context, + Runtime::Memcpy(output.data(), ptr, output.size(), + Runtime::kMemcpyDeviceToHost), + INFINI_RT_TEST_BACKEND_NAME + " runtime should copy async filled memory to host."); + for (const auto value : output) { + context->ExpectEqual(value, static_cast(0xA5), + INFINI_RT_TEST_BACKEND_NAME + " runtime should preserve async filled bytes."); + } + } + + DestroyStream(context, stream); + ExpectSuccess(context, Runtime::Free(ptr), + INFINI_RT_TEST_BACKEND_NAME + " runtime should free async memset memory."); +} + +void TestStream(infini::rt::test::TestContext* context) { + typename Runtime::Stream stream{}; + if (!CreateStream(context, &stream)) { + return; + } + + ExpectSuccess(context, Runtime::StreamSynchronize(stream), + INFINI_RT_TEST_BACKEND_NAME + " runtime should synchronize a stream."); + DestroyStream(context, stream); +} + +void TestEvent(infini::rt::test::TestContext* context) { + typename Runtime::Event event{}; + const auto create_status = Runtime::EventCreate(&event); + ExpectStatus(context, create_status, kSupportsEvent, + INFINI_RT_TEST_BACKEND_NAME + " runtime should report expected event support."); + + if constexpr (!kSupportsEvent) { + return; + } + + context->Expect(event != typename Runtime::Event{}, + INFINI_RT_TEST_BACKEND_NAME + " runtime event creation should produce an event."); + if (event == typename Runtime::Event{}) { + return; + } + + ExpectSuccess(context, + Runtime::EventRecord(event, typename Runtime::Stream{}), + INFINI_RT_TEST_BACKEND_NAME " runtime should record an event."); + ExpectSuccess(context, Runtime::EventSynchronize(event), + INFINI_RT_TEST_BACKEND_NAME + " runtime should synchronize an event."); + ExpectSuccess(context, Runtime::EventQuery(event), + INFINI_RT_TEST_BACKEND_NAME + " runtime should query a completed event."); + ExpectSuccess(context, Runtime::EventDestroy(event), + INFINI_RT_TEST_BACKEND_NAME + " runtime should destroy an event."); + + typename Runtime::Event flagged_event{}; + ExpectSuccess(context, Runtime::EventCreateWithFlags(&flagged_event, 0), + INFINI_RT_TEST_BACKEND_NAME + " runtime should create an event with flags."); + if (flagged_event != typename Runtime::Event{}) { + ExpectSuccess(context, Runtime::EventDestroy(flagged_event), + INFINI_RT_TEST_BACKEND_NAME + " runtime should destroy a flagged event."); + } +} + +void TestStreamWaitEvent(infini::rt::test::TestContext* context) { + typename Runtime::Stream stream{}; + if (!CreateStream(context, &stream)) { + return; + } + + typename Runtime::Event event{}; + if constexpr (kSupportsEvent) { + ExpectSuccess(context, Runtime::EventCreate(&event), + INFINI_RT_TEST_BACKEND_NAME + " runtime should create a stream wait event."); + if (event != typename Runtime::Event{}) { + ExpectSuccess(context, Runtime::EventRecord(event, stream), + INFINI_RT_TEST_BACKEND_NAME + " runtime should record a stream wait event."); + } + } + + const auto wait_status = Runtime::StreamWaitEvent(stream, event, 0); + ExpectStatus(context, wait_status, kSupportsStreamWaitEvent, + INFINI_RT_TEST_BACKEND_NAME + " runtime should report expected stream wait event support."); + ExpectSuccess(context, Runtime::StreamSynchronize(stream), + INFINI_RT_TEST_BACKEND_NAME + " runtime should synchronize after stream wait event."); + + if constexpr (kSupportsEvent) { + if (event != typename Runtime::Event{}) { + ExpectSuccess(context, Runtime::EventDestroy(event), + INFINI_RT_TEST_BACKEND_NAME + " runtime should destroy a stream wait event."); + } + } + DestroyStream(context, stream); +} + +void TestEventElapsedTime(infini::rt::test::TestContext* context) { + typename Runtime::Event start{}; + typename Runtime::Event end{}; + if constexpr (kSupportsEvent) { + ExpectSuccess(context, Runtime::EventCreate(&start), + INFINI_RT_TEST_BACKEND_NAME + " runtime should create an elapsed-time start event."); + ExpectSuccess(context, Runtime::EventCreate(&end), + INFINI_RT_TEST_BACKEND_NAME + " runtime should create an elapsed-time end event."); + } + + if constexpr (kSupportsEventElapsedTime) { + ExpectSuccess(context, + Runtime::EventRecord(start, typename Runtime::Stream{}), + INFINI_RT_TEST_BACKEND_NAME + " runtime should record an elapsed-time start event."); + ExpectSuccess(context, + Runtime::EventRecord(end, typename Runtime::Stream{}), + INFINI_RT_TEST_BACKEND_NAME + " runtime should record an elapsed-time end event."); + ExpectSuccess(context, Runtime::EventSynchronize(end), + INFINI_RT_TEST_BACKEND_NAME + " runtime should synchronize an elapsed-time event."); + } + + float elapsed_ms = 0.0f; + const auto elapsed_status = + Runtime::EventElapsedTime(&elapsed_ms, start, end); + ExpectStatus(context, elapsed_status, kSupportsEventElapsedTime, + INFINI_RT_TEST_BACKEND_NAME + " runtime should report expected event elapsed-time support."); + + if constexpr (kSupportsEvent) { + if (start != typename Runtime::Event{}) { + ExpectSuccess(context, Runtime::EventDestroy(start), + INFINI_RT_TEST_BACKEND_NAME + " runtime should destroy an elapsed-time start event."); + } + if (end != typename Runtime::Event{}) { + ExpectSuccess(context, Runtime::EventDestroy(end), + INFINI_RT_TEST_BACKEND_NAME + " runtime should destroy an elapsed-time end event."); + } + } +} + } // namespace int main() { @@ -185,9 +469,17 @@ int main() { TestDevice(&context); TestMallocAndFree(&context); + TestHostMemory(&context); + TestAsyncMemory(&context); + TestMemGetInfo(&context); TestMemcpyRoundTrip(&context); TestMemcpyAsync(&context); TestMemset(&context); + TestMemsetAsync(&context); + TestStream(&context); + TestEvent(&context); + TestStreamWaitEvent(&context); + TestEventElapsedTime(&context); return context.ExitCode(); } diff --git a/tests/test_runtime_dispatch.cc b/tests/test_runtime_dispatch.cc index 3264b98..be49b3b 100644 --- a/tests/test_runtime_dispatch.cc +++ b/tests/test_runtime_dispatch.cc @@ -13,6 +13,17 @@ namespace { namespace runtime = infini::rt::runtime; +struct RuntimeApiExpectations { + bool async_memcpy; + bool host_memory; + bool async_memory; + bool mem_get_info; + bool memset_async; + bool stream_wait_event; + bool event; + bool event_elapsed_time; +}; + void ExpectSuccess(infini::rt::test::TestContext* context, runtime::Error status, std::string_view message) { context->Expect(status == runtime::kSuccess, message); @@ -23,6 +34,15 @@ void ExpectFailure(infini::rt::test::TestContext* context, context->Expect(status != runtime::kSuccess, message); } +void ExpectStatus(infini::rt::test::TestContext* context, runtime::Error status, + bool expect_success, std::string_view message) { + if (expect_success) { + ExpectSuccess(context, status, message); + } else { + ExpectFailure(context, status, message); + } +} + std::string Message(std::string_view backend_name, std::string_view message) { return std::string{backend_name} + " dispatch should " + std::string{message} + "."; @@ -54,13 +74,22 @@ bool SelectDevice(infini::rt::test::TestContext* context, return true; } -void TestDispatch(infini::rt::test::TestContext* context, - infini::rt::Device::Type device_type, - const char* backend_name, bool expect_async_memcpy_success) { - if (!SelectDevice(context, device_type, backend_name)) { - return; - } +bool CreateStream(infini::rt::test::TestContext* context, + const char* backend_name, runtime::Stream* stream) { + const auto status = runtime::StreamCreate(stream); + ExpectSuccess(context, status, Message(backend_name, "create a stream")); + return status == runtime::kSuccess; +} +void DestroyStream(infini::rt::test::TestContext* context, + const char* backend_name, runtime::Stream stream) { + ExpectSuccess(context, runtime::StreamDestroy(stream), + Message(backend_name, "destroy a stream")); +} + +void TestCoreDispatch(infini::rt::test::TestContext* context, + const char* backend_name, + const RuntimeApiExpectations& expectations) { std::array input{1, 2, 3, 4}; std::array output{}; void* ptr = nullptr; @@ -85,14 +114,11 @@ void TestDispatch(infini::rt::test::TestContext* context, runtime::Stream stream{}; const auto async_status = runtime::MemcpyAsync( ptr, input.data(), input.size(), runtime::kMemcpyHostToDevice, stream); - if (expect_async_memcpy_success) { - ExpectSuccess(context, async_status, - Message(backend_name, "support async host-to-device copy")); + ExpectStatus(context, async_status, expectations.async_memcpy, + Message(backend_name, "report expected async memcpy support")); + if (expectations.async_memcpy) { ExpectSuccess(context, runtime::DeviceSynchronize(), Message(backend_name, "synchronize async copy")); - } else { - ExpectFailure(context, async_status, - Message(backend_name, "not report async memcpy success")); } ExpectSuccess(context, @@ -120,42 +146,298 @@ void TestDispatch(infini::rt::test::TestContext* context, Message(backend_name, "free memory")); } +void TestHostMemory(infini::rt::test::TestContext* context, + const char* backend_name, + const RuntimeApiExpectations& expectations) { + void* ptr = nullptr; + const auto malloc_status = runtime::MallocHost(&ptr, 16); + ExpectStatus(context, malloc_status, expectations.host_memory, + Message(backend_name, "report expected host memory support")); + + if (expectations.host_memory) { + context->Expect( + ptr != nullptr, + Message(backend_name, "produce a pointer from host allocation")); + if (ptr != nullptr) { + ExpectSuccess(context, runtime::FreeHost(ptr), + Message(backend_name, "free host memory")); + } + } +} + +void TestAsyncMemory(infini::rt::test::TestContext* context, + const char* backend_name, + const RuntimeApiExpectations& expectations) { + runtime::Stream stream{}; + if (!CreateStream(context, backend_name, &stream)) { + return; + } + + void* ptr = nullptr; + const auto malloc_status = runtime::MallocAsync(&ptr, 16, stream); + ExpectStatus( + context, malloc_status, expectations.async_memory, + Message(backend_name, "report expected async allocation support")); + + if (expectations.async_memory) { + context->Expect( + ptr != nullptr, + Message(backend_name, "produce a pointer from async allocation")); + if (ptr != nullptr) { + ExpectSuccess(context, runtime::FreeAsync(ptr, stream), + Message(backend_name, "free async memory")); + ExpectSuccess( + context, runtime::StreamSynchronize(stream), + Message(backend_name, "synchronize async memory operations")); + } + } else if (ptr != nullptr) { + ExpectSuccess(context, runtime::Free(ptr), + Message(backend_name, "free unexpected async allocation")); + } + + DestroyStream(context, backend_name, stream); +} + +void TestMemGetInfo(infini::rt::test::TestContext* context, + const char* backend_name, + const RuntimeApiExpectations& expectations) { + std::size_t free = 0; + std::size_t total = 0; + const auto status = runtime::MemGetInfo(&free, &total); + ExpectStatus(context, status, expectations.mem_get_info, + Message(backend_name, "report expected memory info support")); + + if (expectations.mem_get_info) { + context->Expect(total > 0, Message(backend_name, "report total memory")); + context->Expect( + free <= total, + Message(backend_name, "report free memory not exceeding total memory")); + } +} + +void TestMemsetAsync(infini::rt::test::TestContext* context, + const char* backend_name, + const RuntimeApiExpectations& expectations) { + std::array output{}; + void* ptr = nullptr; + + ExpectSuccess(context, runtime::Malloc(&ptr, output.size()), + Message(backend_name, "allocate async memset memory")); + if (ptr == nullptr) { + return; + } + + runtime::Stream stream{}; + if (!CreateStream(context, backend_name, &stream)) { + ExpectSuccess(context, runtime::Free(ptr), + Message(backend_name, "free async memset memory")); + return; + } + + const auto memset_status = + runtime::MemsetAsync(ptr, 0xA5, output.size(), stream); + ExpectStatus(context, memset_status, expectations.memset_async, + Message(backend_name, "report expected async memset support")); + + if (expectations.memset_async) { + ExpectSuccess(context, runtime::StreamSynchronize(stream), + Message(backend_name, "synchronize async memset")); + ExpectSuccess(context, + runtime::Memcpy(output.data(), ptr, output.size(), + runtime::kMemcpyDeviceToHost), + Message(backend_name, "copy async filled memory to host")); + for (const auto value : output) { + context->ExpectEqual( + value, static_cast(0xA5), + Message(backend_name, "preserve async filled bytes")); + } + } + + DestroyStream(context, backend_name, stream); + ExpectSuccess(context, runtime::Free(ptr), + Message(backend_name, "free async memset memory")); +} + +void TestStream(infini::rt::test::TestContext* context, + const char* backend_name) { + runtime::Stream stream{}; + if (!CreateStream(context, backend_name, &stream)) { + return; + } + + ExpectSuccess(context, runtime::StreamSynchronize(stream), + Message(backend_name, "synchronize a stream")); + DestroyStream(context, backend_name, stream); +} + +void TestEvent(infini::rt::test::TestContext* context, const char* backend_name, + const RuntimeApiExpectations& expectations) { + runtime::Event event{}; + const auto create_status = runtime::EventCreate(&event); + ExpectStatus(context, create_status, expectations.event, + Message(backend_name, "report expected event support")); + + if (!expectations.event) { + return; + } + + context->Expect( + event != nullptr, + Message(backend_name, "produce an event from event creation")); + if (event == nullptr) { + return; + } + + ExpectSuccess(context, runtime::EventRecord(event, runtime::Stream{}), + Message(backend_name, "record an event")); + ExpectSuccess(context, runtime::EventSynchronize(event), + Message(backend_name, "synchronize an event")); + ExpectSuccess(context, runtime::EventQuery(event), + Message(backend_name, "query a completed event")); + ExpectSuccess(context, runtime::EventDestroy(event), + Message(backend_name, "destroy an event")); + + runtime::Event flagged_event{}; + ExpectSuccess(context, runtime::EventCreateWithFlags(&flagged_event, 0), + Message(backend_name, "create an event with flags")); + if (flagged_event != nullptr) { + ExpectSuccess(context, runtime::EventDestroy(flagged_event), + Message(backend_name, "destroy a flagged event")); + } +} + +void TestStreamWaitEvent(infini::rt::test::TestContext* context, + const char* backend_name, + const RuntimeApiExpectations& expectations) { + runtime::Stream stream{}; + if (!CreateStream(context, backend_name, &stream)) { + return; + } + + runtime::Event event{}; + if (expectations.event) { + ExpectSuccess(context, runtime::EventCreate(&event), + Message(backend_name, "create a stream wait event")); + if (event != nullptr) { + ExpectSuccess(context, runtime::EventRecord(event, stream), + Message(backend_name, "record a stream wait event")); + } + } + + const auto wait_status = runtime::StreamWaitEvent(stream, event, 0); + ExpectStatus( + context, wait_status, expectations.stream_wait_event, + Message(backend_name, "report expected stream wait event support")); + ExpectSuccess(context, runtime::StreamSynchronize(stream), + Message(backend_name, "synchronize after stream wait event")); + + if (event != nullptr) { + ExpectSuccess(context, runtime::EventDestroy(event), + Message(backend_name, "destroy a stream wait event")); + } + DestroyStream(context, backend_name, stream); +} + +void TestEventElapsedTime(infini::rt::test::TestContext* context, + const char* backend_name, + const RuntimeApiExpectations& expectations) { + runtime::Event start{}; + runtime::Event end{}; + if (expectations.event) { + ExpectSuccess(context, runtime::EventCreate(&start), + Message(backend_name, "create an elapsed-time start event")); + ExpectSuccess(context, runtime::EventCreate(&end), + Message(backend_name, "create an elapsed-time end event")); + } + + if (expectations.event_elapsed_time) { + ExpectSuccess(context, runtime::EventRecord(start, runtime::Stream{}), + Message(backend_name, "record an elapsed-time start event")); + ExpectSuccess(context, runtime::EventRecord(end, runtime::Stream{}), + Message(backend_name, "record an elapsed-time end event")); + ExpectSuccess(context, runtime::EventSynchronize(end), + Message(backend_name, "synchronize an elapsed-time event")); + } + + float elapsed_ms = 0.0f; + const auto elapsed_status = + runtime::EventElapsedTime(&elapsed_ms, start, end); + ExpectStatus( + context, elapsed_status, expectations.event_elapsed_time, + Message(backend_name, "report expected event elapsed-time support")); + + if (start != nullptr) { + ExpectSuccess(context, runtime::EventDestroy(start), + Message(backend_name, "destroy an elapsed-time start event")); + } + if (end != nullptr) { + ExpectSuccess(context, runtime::EventDestroy(end), + Message(backend_name, "destroy an elapsed-time end event")); + } +} + +void TestDispatch(infini::rt::test::TestContext* context, + infini::rt::Device::Type device_type, + const char* backend_name, + const RuntimeApiExpectations& expectations) { + if (!SelectDevice(context, device_type, backend_name)) { + return; + } + + TestCoreDispatch(context, backend_name, expectations); + TestHostMemory(context, backend_name, expectations); + TestAsyncMemory(context, backend_name, expectations); + TestMemGetInfo(context, backend_name, expectations); + TestMemsetAsync(context, backend_name, expectations); + TestStream(context, backend_name); + TestEvent(context, backend_name, expectations); + TestStreamWaitEvent(context, backend_name, expectations); + TestEventElapsedTime(context, backend_name, expectations); +} + } // namespace int main() { infini::rt::test::TestContext context; #if defined(INFINI_RT_TEST_WITH_CPU) - TestDispatch(&context, infini::rt::Device::Type::kCpu, "CPU", false); + TestDispatch(&context, infini::rt::Device::Type::kCpu, "CPU", + {false, true, false, true, false, true, true, true}); #endif #if defined(INFINI_RT_TEST_WITH_NVIDIA) - TestDispatch(&context, infini::rt::Device::Type::kNvidia, "NVIDIA", true); + TestDispatch(&context, infini::rt::Device::Type::kNvidia, "NVIDIA", + {true, true, true, true, true, true, true, true}); #endif #if defined(INFINI_RT_TEST_WITH_ILUVATAR) - TestDispatch(&context, infini::rt::Device::Type::kIluvatar, "ILUVATAR", true); + TestDispatch(&context, infini::rt::Device::Type::kIluvatar, "ILUVATAR", + {true, true, true, true, true, true, true, true}); #endif #if defined(INFINI_RT_TEST_WITH_HYGON) - TestDispatch(&context, infini::rt::Device::Type::kHygon, "HYGON", true); + TestDispatch(&context, infini::rt::Device::Type::kHygon, "HYGON", + {true, true, true, true, true, true, true, true}); #endif #if defined(INFINI_RT_TEST_WITH_METAX) - TestDispatch(&context, infini::rt::Device::Type::kMetax, "METAX", true); + TestDispatch(&context, infini::rt::Device::Type::kMetax, "METAX", + {true, true, true, true, true, true, true, true}); #endif #if defined(INFINI_RT_TEST_WITH_MOORE) - TestDispatch(&context, infini::rt::Device::Type::kMoore, "MOORE", true); + TestDispatch(&context, infini::rt::Device::Type::kMoore, "MOORE", + {true, true, false, true, true, true, true, true}); #endif #if defined(INFINI_RT_TEST_WITH_CAMBRICON) TestDispatch(&context, infini::rt::Device::Type::kCambricon, "CAMBRICON", - true); + {true, false, false, false, false, false, false, false}); #endif #if defined(INFINI_RT_TEST_WITH_ASCEND) - TestDispatch(&context, infini::rt::Device::Type::kAscend, "ASCEND", true); + TestDispatch(&context, infini::rt::Device::Type::kAscend, "ASCEND", + {true, false, false, false, false, false, false, false}); #endif return context.ExitCode();