diff --git a/scripts/generate_public_headers.py b/scripts/generate_public_headers.py index 6620b42..3363741 100644 --- a/scripts/generate_public_headers.py +++ b/scripts/generate_public_headers.py @@ -156,9 +156,31 @@ def _write_detail_headers(include_root, source_root, devices): _write_detail_header(include_root, source_root, relative_path) -def _write_generated_header(include_root, devices): +def _write_generated_header(include_root, source_root, devices): default_device = _default_device(devices) default_device_type = _DEVICE_TYPES[default_device] + public_runtime_functions = _public_runtime_functions_for_devices( + devices, source_root + ) + has_graph_api = any( + function.name in {"StreamBeginCapture", "GraphLaunch"} + for function in public_runtime_functions + ) + graph_declarations = ( + """ +using Graph = void*; + +using GraphExec = void*; + +enum class StreamCaptureMode { + kStreamCaptureModeGlobal = 0, + kStreamCaptureModeThreadLocal = 1, + kStreamCaptureModeRelaxed = 2, +}; +""" + if has_graph_api + else "" + ) includes = [ "#include ", "#include ", @@ -177,7 +199,7 @@ def _write_generated_header(include_root, devices): includes.append(f"#include ") runtime_declarations = "\n\n".join( - f"{function.signature()};" for function in _PUBLIC_RUNTIME_FUNCTIONS + f"{function.signature()};" for function in public_runtime_functions ) path = include_root / "infini" / "rt" / "generated.h" @@ -209,6 +231,7 @@ def _write_generated_header(include_root, devices): using Event = void*; +{graph_declarations} using MemcpyKind = std::remove_cv_t< decltype(generated_detail::DefaultErrorRuntime::kMemcpyHostToHost)>; @@ -366,6 +389,32 @@ def params_decl(self): _Param("Event", "end"), ), ), + _Function( + "Error", + "StreamBeginCapture", + (_Param("Stream", "stream"), _Param("StreamCaptureMode", "mode")), + ), + _Function( + "Error", + "StreamEndCapture", + (_Param("Stream", "stream"), _Param("Graph*", "graph")), + ), + _Function("Error", "GraphDestroy", (_Param("Graph", "graph"),)), + _Function( + "Error", + "GraphInstantiate", + (_Param("GraphExec*", "graph_exec"), _Param("Graph", "graph")), + ), + _Function( + "Error", + "GraphExecDestroy", + (_Param("GraphExec", "graph_exec"),), + ), + _Function( + "Error", + "GraphLaunch", + (_Param("GraphExec", "graph_exec"), _Param("Stream", "stream")), + ), ) @@ -395,6 +444,24 @@ def _runtime_arg(param, device): return ( f"reinterpret_cast::Event*>({param.name})" ) + if param.type == "Graph": + return f"reinterpret_cast::Graph>({param.name})" + if param.type == "Graph*": + return ( + f"reinterpret_cast::Graph*>({param.name})" + ) + if param.type == "GraphExec": + return ( + f"reinterpret_cast::GraphExec>" + f"({param.name})" + ) + if param.type == "GraphExec*": + return ( + f"reinterpret_cast::GraphExec*>" + f"({param.name})" + ) + if param.type == "StreamCaptureMode": + return f"RuntimeStreamCaptureMode<{device_type}>({param.name})" return param.name @@ -452,8 +519,42 @@ def _devices_for_function(function, devices, source_root): ) +def _public_runtime_functions_for_devices(devices, source_root): + return tuple( + function + for function in _PUBLIC_RUNTIME_FUNCTIONS + if _devices_for_function(function, devices, source_root) + ) + + def _write_runtime_dispatch(source_path, source_root, devices): - functions = _PUBLIC_RUNTIME_FUNCTIONS + functions = _public_runtime_functions_for_devices(devices, source_root) + stream_capture_mode_helper = ( + """ +template +auto RuntimeStreamCaptureMode(StreamCaptureMode mode) { + using DeviceRuntime = Runtime; + + switch (mode) { + case StreamCaptureMode::kStreamCaptureModeGlobal: + return DeviceRuntime::kStreamCaptureModeGlobal; + case StreamCaptureMode::kStreamCaptureModeThreadLocal: + return DeviceRuntime::kStreamCaptureModeThreadLocal; + case StreamCaptureMode::kStreamCaptureModeRelaxed: + return DeviceRuntime::kStreamCaptureModeRelaxed; + } + + assert(false && "unsupported stream capture mode"); + return DeviceRuntime::kStreamCaptureModeRelaxed; +} +""" + if any( + param.type == "StreamCaptureMode" + for function in functions + for param in function.params + ) + else "" + ) dispatch_functions = "\n".join( _write_runtime_dispatch_function( function, @@ -535,6 +636,7 @@ def _write_runtime_dispatch(source_path, source_root, devices): return DeviceRuntime::kMemcpyHostToHost; }} +{stream_capture_mode_helper} }} // namespace {dispatch_functions} @@ -566,7 +668,7 @@ def main(): for wrapper_device, header_name, target in _DEVICE_HEADERS[device]: _write_wrapper(include_root, wrapper_device, header_name, target) - _write_generated_header(include_root, devices) + _write_generated_header(include_root, source_root, devices) _write_runtime_dispatch(pathlib.Path(args.source_output), source_root, devices) diff --git a/src/native/cuda/iluvatar/runtime_.h b/src/native/cuda/iluvatar/runtime_.h index 9425559..09feb65 100644 --- a/src/native/cuda/iluvatar/runtime_.h +++ b/src/native/cuda/iluvatar/runtime_.h @@ -19,6 +19,10 @@ struct Runtime using Stream = cudaStream_t; + using Graph = cudaGraph_t; + + using GraphExec = cudaGraphExec_t; + using Event = cudaEvent_t; static constexpr Device::Type kDeviceType = Device::Type::kIluvatar; @@ -78,7 +82,8 @@ struct Runtime }; static constexpr auto StreamCreate = [](auto&&... args) { - return cudaStreamCreate(std::forward(args)...); + return cudaStreamCreateWithFlags(std::forward(args)..., + cudaStreamNonBlocking); }; static constexpr auto StreamDestroy = [](auto&&... args) { @@ -120,6 +125,38 @@ struct Runtime static constexpr auto EventElapsedTime = [](auto&&... args) { return cudaEventElapsedTime(std::forward(args)...); }; + + static constexpr auto kStreamCaptureModeGlobal = cudaStreamCaptureModeGlobal; + + static constexpr auto kStreamCaptureModeThreadLocal = + cudaStreamCaptureModeThreadLocal; + + static constexpr auto kStreamCaptureModeRelaxed = + cudaStreamCaptureModeRelaxed; + + static constexpr auto StreamBeginCapture = [](auto&&... args) { + return cudaStreamBeginCapture(std::forward(args)...); + }; + + static constexpr auto StreamEndCapture = [](auto&&... args) { + return cudaStreamEndCapture(std::forward(args)...); + }; + + static constexpr auto GraphDestroy = [](auto&&... args) { + return cudaGraphDestroy(std::forward(args)...); + }; + + static constexpr auto GraphInstantiate = [](auto&&... args) { + return cudaGraphInstantiate(std::forward(args)...); + }; + + static constexpr auto GraphExecDestroy = [](auto&&... args) { + return cudaGraphExecDestroy(std::forward(args)...); + }; + + static constexpr auto GraphLaunch = [](auto&&... args) { + return cudaGraphLaunch(std::forward(args)...); + }; }; static_assert(Runtime::Validate()); diff --git a/src/native/cuda/metax/runtime_.h b/src/native/cuda/metax/runtime_.h index 462f5f1..201f9b0 100644 --- a/src/native/cuda/metax/runtime_.h +++ b/src/native/cuda/metax/runtime_.h @@ -3,6 +3,7 @@ #include +#include #include #include "native/cuda/metax/device_.h" @@ -17,6 +18,10 @@ struct Runtime using Stream = mcStream_t; + using Graph = void*; + + using GraphExec = void*; + using Event = mcEvent_t; static constexpr Device::Type kDeviceType = Device::Type::kMetax; @@ -124,6 +129,28 @@ struct Runtime static constexpr auto EventElapsedTime = [](auto&&... args) { return mcEventElapsedTime(std::forward(args)...); }; + + static constexpr int kStreamCaptureModeGlobal = 0; + + static constexpr int kStreamCaptureModeThreadLocal = 1; + + static constexpr int kStreamCaptureModeRelaxed = 2; + + static Error StreamBeginCapture(Stream, int) { return static_cast(1); } + + static Error StreamEndCapture(Stream, Graph*) { + return static_cast(1); + } + + static Error GraphDestroy(Graph) { return static_cast(1); } + + static Error GraphInstantiate(GraphExec*, Graph) { + return static_cast(1); + } + + static Error GraphExecDestroy(GraphExec) { return static_cast(1); } + + static Error GraphLaunch(GraphExec, Stream) { return static_cast(1); } }; static_assert(Runtime::Validate()); diff --git a/src/native/cuda/moore/runtime_.h b/src/native/cuda/moore/runtime_.h index 81ccd15..37eab28 100644 --- a/src/native/cuda/moore/runtime_.h +++ b/src/native/cuda/moore/runtime_.h @@ -18,6 +18,10 @@ struct Runtime using Stream = musaStream_t; + using Graph = void*; + + using GraphExec = void*; + using Event = musaEvent_t; static constexpr Device::Type kDeviceType = Device::Type::kMoore; @@ -131,6 +135,28 @@ struct Runtime static constexpr auto EventElapsedTime = [](auto&&... args) { return musaEventElapsedTime(std::forward(args)...); }; + + static constexpr int kStreamCaptureModeGlobal = 0; + + static constexpr int kStreamCaptureModeThreadLocal = 1; + + static constexpr int kStreamCaptureModeRelaxed = 2; + + static Error StreamBeginCapture(Stream, int) { return static_cast(1); } + + static Error StreamEndCapture(Stream, Graph*) { + return static_cast(1); + } + + static Error GraphDestroy(Graph) { return static_cast(1); } + + static Error GraphInstantiate(GraphExec*, Graph) { + return static_cast(1); + } + + static Error GraphExecDestroy(GraphExec) { return static_cast(1); } + + static Error GraphLaunch(GraphExec, Stream) { return static_cast(1); } }; static_assert(Runtime::Validate()); diff --git a/src/native/cuda/nvidia/runtime_.h b/src/native/cuda/nvidia/runtime_.h index 840f0bd..c9d2649 100644 --- a/src/native/cuda/nvidia/runtime_.h +++ b/src/native/cuda/nvidia/runtime_.h @@ -19,6 +19,10 @@ struct Runtime using Stream = cudaStream_t; + using Graph = cudaGraph_t; + + using GraphExec = cudaGraphExec_t; + using Event = cudaEvent_t; static constexpr Device::Type kDeviceType = Device::Type::kNvidia; @@ -78,7 +82,8 @@ struct Runtime }; static constexpr auto StreamCreate = [](auto&&... args) { - return cudaStreamCreate(std::forward(args)...); + return cudaStreamCreateWithFlags(std::forward(args)..., + cudaStreamNonBlocking); }; static constexpr auto StreamDestroy = [](auto&&... args) { @@ -120,6 +125,78 @@ struct Runtime static constexpr auto EventElapsedTime = [](auto&&... args) { return cudaEventElapsedTime(std::forward(args)...); }; + + static constexpr auto kStreamCaptureModeGlobal = cudaStreamCaptureModeGlobal; + + static constexpr auto kStreamCaptureModeThreadLocal = + cudaStreamCaptureModeThreadLocal; + + static constexpr auto kStreamCaptureModeRelaxed = + cudaStreamCaptureModeRelaxed; + + static constexpr auto StreamBeginCapture = [](auto&&... args) { + return cudaStreamBeginCapture(std::forward(args)...); + }; + + static constexpr auto StreamEndCapture = [](auto&&... args) { + return cudaStreamEndCapture(std::forward(args)...); + }; + + static constexpr auto GraphDestroy = [](auto&&... args) { + return cudaGraphDestroy(std::forward(args)...); + }; + + static constexpr auto GraphInstantiate = [](auto&&... args) { + return cudaGraphInstantiate(std::forward(args)...); + }; + + static constexpr auto GraphExecDestroy = [](auto&&... args) { + return cudaGraphExecDestroy(std::forward(args)...); + }; + + static constexpr auto GraphLaunch = [](auto&&... args) { + return cudaGraphLaunch(std::forward(args)...); + }; + + static constexpr bool Validate() { + CudaRuntime>::Validate(); + static_assert(sizeof(Graph) > 0, + "`Runtime` must define a `Graph` type alias."); + static_assert(sizeof(GraphExec) > 0, + "`Runtime` must define a `GraphExec` type alias."); + static_assert(std::is_invocable_v, + "`Runtime::StreamCreate` must be callable with `(Stream*)`."); + static_assert(std::is_invocable_v, + "`Runtime::StreamDestroy` must be callable with `(Stream)`."); + static_assert( + std::is_invocable_v, + "`Runtime::StreamSynchronize` must be callable with `(Stream)`."); + static_assert(std::is_invocable_v, + "`Runtime::MemcpyAsync` must be callable with " + "`(void*, const void*, size_t, cudaMemcpyKind, Stream)`."); + static_assert(std::is_invocable_v, + "`Runtime::StreamBeginCapture` must be callable with " + "`(Stream, cudaStreamCaptureMode)`."); + static_assert( + std::is_invocable_v, + "`Runtime::StreamEndCapture` must be callable with " + "`(Stream, Graph*)`."); + static_assert(std::is_invocable_v, + "`Runtime::GraphDestroy` must be callable with `(Graph)`."); + static_assert( + std::is_invocable_v, + "`Runtime::GraphInstantiate` must be callable with " + "`(GraphExec*, Graph)`."); + static_assert( + std::is_invocable_v, + "`Runtime::GraphExecDestroy` must be callable with `(GraphExec)`."); + static_assert( + std::is_invocable_v, + "`Runtime::GraphLaunch` must be callable with `(GraphExec, Stream)`."); + return true; + } }; static_assert(Runtime::Validate()); diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 2fcf5b4..4da61b3 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -26,6 +26,17 @@ function(add_infini_rt_backend_runtime_test backend device_type runtime_header "INFINI_RT_TEST_SUPPORTS_EVENT_ELAPSED_TIME=${supports_event_elapsed_time}") endfunction() +function(add_infini_rt_backend_graph_test backend device_type supports_graph_capture) + string(TOLOWER "${backend}" backend_lower) + set(target "test_${backend_lower}_graph") + add_infini_rt_test(${target} test_native_graph.cc) + target_compile_definitions(${target} + PRIVATE + "INFINI_RT_TEST_BACKEND_NAME=\"${backend}\"" + "INFINI_RT_TEST_DEVICE_TYPE=${device_type}" + "INFINI_RT_TEST_SUPPORTS_GRAPH_CAPTURE=${supports_graph_capture}") +endfunction() + add_infini_rt_test(test_smoke test_smoke.cc) add_infini_rt_test(test_core test_core.cc) @@ -44,6 +55,8 @@ if(WITH_NVIDIA) NVIDIA infini::rt::Device::Type::kNvidia infini/rt/nvidia/runtime_.h 1 1 1 1 1 1 1 1) + add_infini_rt_backend_graph_test( + NVIDIA infini::rt::Device::Type::kNvidia 1) endif() if(WITH_ILUVATAR) @@ -52,6 +65,8 @@ if(WITH_ILUVATAR) ILUVATAR infini::rt::Device::Type::kIluvatar infini/rt/iluvatar/runtime_.h 1 1 1 1 1 1 1 1) + add_infini_rt_backend_graph_test( + ILUVATAR infini::rt::Device::Type::kIluvatar 1) endif() if(WITH_HYGON) @@ -68,6 +83,8 @@ if(WITH_METAX) METAX infini::rt::Device::Type::kMetax infini/rt/metax/runtime_.h 1 1 1 1 1 1 1 1) + add_infini_rt_backend_graph_test( + METAX infini::rt::Device::Type::kMetax 0) endif() if(WITH_MOORE) @@ -76,6 +93,8 @@ if(WITH_MOORE) MOORE infini::rt::Device::Type::kMoore infini/rt/moore/runtime_.h 1 1 0 1 1 1 1 1) + add_infini_rt_backend_graph_test( + MOORE infini::rt::Device::Type::kMoore 0) endif() if(WITH_CAMBRICON) diff --git a/tests/test_native_graph.cc b/tests/test_native_graph.cc new file mode 100644 index 0000000..17a51c8 --- /dev/null +++ b/tests/test_native_graph.cc @@ -0,0 +1,196 @@ +#include + +#include +#include +#include +#include +#include + +#include "test_helper.h" + +namespace { + +namespace runtime = infini::rt::runtime; + +void ExpectSuccess(infini::rt::test::TestContext* context, + runtime::Error status, std::string_view message) { + context->Expect(status == runtime::kSuccess, message); +} + +void FillPattern(std::array* input, std::uint8_t salt) { + for (std::size_t i = 0; i < input->size(); ++i) { + (*input)[i] = static_cast(i * 13 + salt); + } +} + +bool CopyDeviceToHostAndValidate(infini::rt::test::TestContext* context, + void* device_ptr, + const std::array& expected, + std::string_view message) { + std::array output{}; + ExpectSuccess(context, + runtime::Memcpy(output.data(), device_ptr, output.size(), + runtime::kMemcpyDeviceToHost), + "Failed to copy graph output to host."); + return context->ExpectEqual(output, expected, message); +} + +bool SelectTestDevice(infini::rt::test::TestContext* context) { + infini::rt::set_runtime_device_type(INFINI_RT_TEST_DEVICE_TYPE); + + int count = 0; + ExpectSuccess(context, runtime::GetDeviceCount(&count), + "Failed to query device count."); + if (context->ExitCode() != 0) { + return false; + } + + if (count == 0) { + std::cout << "Skipping " << INFINI_RT_TEST_BACKEND_NAME + << " graph test because no device is available." << std::endl; + return false; + } + + ExpectSuccess(context, runtime::SetDevice(0), + "Failed to set graph test device."); + return context->ExitCode() == 0; +} + +void RunUnsupportedGraphSmoke(infini::rt::test::TestContext* context) { + runtime::Stream stream = nullptr; + ExpectSuccess(context, runtime::StreamCreate(&stream), + "Failed to create stream for unsupported graph smoke."); + if (context->ExitCode() != 0) { + return; + } + + const auto status = runtime::StreamBeginCapture( + stream, runtime::StreamCaptureMode::kStreamCaptureModeRelaxed); + context->Expect(status != runtime::kSuccess, + "Unsupported graph capture should return an error."); + + ExpectSuccess(context, runtime::StreamDestroy(stream), + "Failed to destroy stream for unsupported graph smoke."); +} + +void RunGraphReplayTest(infini::rt::test::TestContext* context) { + void* src = nullptr; + void* dst = nullptr; + runtime::Stream stream = nullptr; + runtime::Graph graph = nullptr; + runtime::GraphExec graph_exec = nullptr; + + std::array capture_input{}; + FillPattern(&capture_input, 7); + + ExpectSuccess(context, runtime::Malloc(&src, capture_input.size()), + "Failed to allocate source buffer."); + ExpectSuccess(context, runtime::Malloc(&dst, capture_input.size()), + "Failed to allocate destination buffer."); + ExpectSuccess(context, runtime::StreamCreate(&stream), + "Failed to create stream."); + + if (context->ExitCode() == 0) { + ExpectSuccess( + context, + runtime::Memcpy(src, capture_input.data(), capture_input.size(), + runtime::kMemcpyHostToDevice), + "Failed to initialize source buffer."); + ExpectSuccess(context, runtime::Memset(dst, 0, capture_input.size()), + "Failed to initialize destination buffer."); + + ExpectSuccess( + context, + runtime::StreamBeginCapture( + stream, runtime::StreamCaptureMode::kStreamCaptureModeRelaxed), + "Failed to begin stream capture."); + ExpectSuccess(context, + runtime::MemcpyAsync(dst, src, capture_input.size(), + runtime::kMemcpyDeviceToDevice, stream), + "Failed to record device-to-device copy."); + ExpectSuccess(context, runtime::StreamEndCapture(stream, &graph), + "Failed to end stream capture."); + + ExpectSuccess(context, runtime::GraphInstantiate(&graph_exec, graph), + "Failed to instantiate graph."); + } + + std::array replay_input_1{}; + std::array replay_input_2{}; + FillPattern(&replay_input_1, 31); + FillPattern(&replay_input_2, 53); + + if (context->ExitCode() == 0) { + ExpectSuccess( + context, + runtime::Memcpy(src, replay_input_1.data(), replay_input_1.size(), + runtime::kMemcpyHostToDevice), + "Failed to refresh first source buffer."); + ExpectSuccess(context, runtime::Memset(dst, 0, replay_input_1.size()), + "Failed to clear first destination buffer."); + ExpectSuccess(context, runtime::DeviceSynchronize(), + "Failed to synchronize first replay inputs."); + ExpectSuccess(context, runtime::GraphLaunch(graph_exec, stream), + "Failed to launch first graph replay."); + ExpectSuccess(context, runtime::StreamSynchronize(stream), + "Failed to synchronize first graph replay."); + CopyDeviceToHostAndValidate(context, dst, replay_input_1, + "First graph replay should copy D2D data."); + } + + if (context->ExitCode() == 0) { + ExpectSuccess( + context, + runtime::Memcpy(src, replay_input_2.data(), replay_input_2.size(), + runtime::kMemcpyHostToDevice), + "Failed to refresh second source buffer."); + ExpectSuccess(context, runtime::Memset(dst, 0, replay_input_2.size()), + "Failed to clear second destination buffer."); + ExpectSuccess(context, runtime::DeviceSynchronize(), + "Failed to synchronize second replay inputs."); + ExpectSuccess(context, runtime::GraphLaunch(graph_exec, stream), + "Failed to launch second graph replay."); + ExpectSuccess(context, runtime::StreamSynchronize(stream), + "Failed to synchronize second graph replay."); + CopyDeviceToHostAndValidate(context, dst, replay_input_2, + "Second graph replay should copy D2D data."); + } + + if (graph_exec != nullptr) { + ExpectSuccess(context, runtime::GraphExecDestroy(graph_exec), + "Failed to destroy graph exec."); + } + if (graph != nullptr) { + ExpectSuccess(context, runtime::GraphDestroy(graph), + "Failed to destroy graph."); + } + if (stream != nullptr) { + ExpectSuccess(context, runtime::StreamDestroy(stream), + "Failed to destroy stream."); + } + if (dst != nullptr) { + ExpectSuccess(context, runtime::Free(dst), + "Failed to free destination buffer."); + } + if (src != nullptr) { + ExpectSuccess(context, runtime::Free(src), "Failed to free source buffer."); + } +} + +} // namespace + +int main() { + infini::rt::test::TestContext context; + + if (!SelectTestDevice(&context)) { + return context.ExitCode(); + } + + if constexpr (INFINI_RT_TEST_SUPPORTS_GRAPH_CAPTURE) { + RunGraphReplayTest(&context); + } else { + RunUnsupportedGraphSmoke(&context); + } + + return context.ExitCode(); +}