Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 106 additions & 4 deletions scripts/generate_public_headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,9 +156,31 @@ def _write_detail_headers(include_root, source_root, devices):
_write_detail_header(include_root, source_root, relative_path)


def _write_generated_header(include_root, devices):
def _write_generated_header(include_root, source_root, devices):
default_device = _default_device(devices)
default_device_type = _DEVICE_TYPES[default_device]
public_runtime_functions = _public_runtime_functions_for_devices(
devices, source_root
)
has_graph_api = any(
function.name in {"StreamBeginCapture", "GraphLaunch"}
for function in public_runtime_functions
)
graph_declarations = (
"""
using Graph = void*;

using GraphExec = void*;

enum class StreamCaptureMode {
kStreamCaptureModeGlobal = 0,
kStreamCaptureModeThreadLocal = 1,
kStreamCaptureModeRelaxed = 2,
};
"""
if has_graph_api
else ""
)
includes = [
"#include <cstddef>",
"#include <cstdint>",
Expand All @@ -177,7 +199,7 @@ def _write_generated_header(include_root, devices):
includes.append(f"#include <infini/rt/{device}/runtime_.h>")

runtime_declarations = "\n\n".join(
f"{function.signature()};" for function in _PUBLIC_RUNTIME_FUNCTIONS
f"{function.signature()};" for function in public_runtime_functions
)

path = include_root / "infini" / "rt" / "generated.h"
Expand Down Expand Up @@ -209,6 +231,7 @@ def _write_generated_header(include_root, devices):

using Event = void*;

{graph_declarations}
using MemcpyKind = std::remove_cv_t<
decltype(generated_detail::DefaultErrorRuntime::kMemcpyHostToHost)>;

Expand Down Expand Up @@ -366,6 +389,32 @@ def params_decl(self):
_Param("Event", "end"),
),
),
_Function(
"Error",
"StreamBeginCapture",
(_Param("Stream", "stream"), _Param("StreamCaptureMode", "mode")),
),
_Function(
"Error",
"StreamEndCapture",
(_Param("Stream", "stream"), _Param("Graph*", "graph")),
),
_Function("Error", "GraphDestroy", (_Param("Graph", "graph"),)),
_Function(
"Error",
"GraphInstantiate",
(_Param("GraphExec*", "graph_exec"), _Param("Graph", "graph")),
),
_Function(
"Error",
"GraphExecDestroy",
(_Param("GraphExec", "graph_exec"),),
),
_Function(
"Error",
"GraphLaunch",
(_Param("GraphExec", "graph_exec"), _Param("Stream", "stream")),
),
)


Expand Down Expand Up @@ -395,6 +444,24 @@ def _runtime_arg(param, device):
return (
f"reinterpret_cast<typename Runtime<{device_type}>::Event*>({param.name})"
)
if param.type == "Graph":
return f"reinterpret_cast<typename Runtime<{device_type}>::Graph>({param.name})"
if param.type == "Graph*":
return (
f"reinterpret_cast<typename Runtime<{device_type}>::Graph*>({param.name})"
)
if param.type == "GraphExec":
return (
f"reinterpret_cast<typename Runtime<{device_type}>::GraphExec>"
f"({param.name})"
)
if param.type == "GraphExec*":
return (
f"reinterpret_cast<typename Runtime<{device_type}>::GraphExec*>"
f"({param.name})"
)
if param.type == "StreamCaptureMode":
return f"RuntimeStreamCaptureMode<{device_type}>({param.name})"

return param.name

Expand Down Expand Up @@ -452,8 +519,42 @@ def _devices_for_function(function, devices, source_root):
)


def _public_runtime_functions_for_devices(devices, source_root):
    """Select the public runtime functions usable with the given devices.

    Walks `_PUBLIC_RUNTIME_FUNCTIONS` in declaration order and keeps each
    entry for which `_devices_for_function(function, devices, source_root)`
    returns a truthy result.

    Returns:
        A tuple of the retained function descriptors, in their original order.
    """
    supported = []
    for candidate in _PUBLIC_RUNTIME_FUNCTIONS:
        matching_devices = _devices_for_function(candidate, devices, source_root)
        if matching_devices:
            supported.append(candidate)
    return tuple(supported)


def _write_runtime_dispatch(source_path, source_root, devices):
functions = _PUBLIC_RUNTIME_FUNCTIONS
functions = _public_runtime_functions_for_devices(devices, source_root)
stream_capture_mode_helper = (
"""
template <Device::Type device_type>
auto RuntimeStreamCaptureMode(StreamCaptureMode mode) {
using DeviceRuntime = Runtime<device_type>;

switch (mode) {
case StreamCaptureMode::kStreamCaptureModeGlobal:
return DeviceRuntime::kStreamCaptureModeGlobal;
case StreamCaptureMode::kStreamCaptureModeThreadLocal:
return DeviceRuntime::kStreamCaptureModeThreadLocal;
case StreamCaptureMode::kStreamCaptureModeRelaxed:
return DeviceRuntime::kStreamCaptureModeRelaxed;
}

assert(false && "unsupported stream capture mode");
return DeviceRuntime::kStreamCaptureModeRelaxed;
}
"""
if any(
param.type == "StreamCaptureMode"
for function in functions
for param in function.params
)
else ""
)
dispatch_functions = "\n".join(
_write_runtime_dispatch_function(
function,
Expand Down Expand Up @@ -535,6 +636,7 @@ def _write_runtime_dispatch(source_path, source_root, devices):
return DeviceRuntime::kMemcpyHostToHost;
}}

{stream_capture_mode_helper}
}} // namespace

{dispatch_functions}
Expand Down Expand Up @@ -566,7 +668,7 @@ def main():
for wrapper_device, header_name, target in _DEVICE_HEADERS[device]:
_write_wrapper(include_root, wrapper_device, header_name, target)

_write_generated_header(include_root, devices)
_write_generated_header(include_root, source_root, devices)
_write_runtime_dispatch(pathlib.Path(args.source_output), source_root, devices)


Expand Down
39 changes: 38 additions & 1 deletion src/native/cuda/iluvatar/runtime_.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ struct Runtime<Device::Type::kIluvatar>

using Stream = cudaStream_t;

using Graph = cudaGraph_t;

using GraphExec = cudaGraphExec_t;

using Event = cudaEvent_t;

static constexpr Device::Type kDeviceType = Device::Type::kIluvatar;
Expand Down Expand Up @@ -78,7 +82,8 @@ struct Runtime<Device::Type::kIluvatar>
};

// Creates a stream by forwarding the caller's arguments to the CUDA runtime.
// NOTE(review): this body carries two `return` statements, so only the first
// (`cudaStreamCreate`) is reachable and the `cudaStreamCreateWithFlags(...,
// cudaStreamNonBlocking)` call below is dead code. This looks like the old and
// new lines of a change were both kept -- confirm which call is intended and
// delete the other.
static constexpr auto StreamCreate = [](auto&&... args) {
return cudaStreamCreate(std::forward<decltype(args)>(args)...);
return cudaStreamCreateWithFlags(std::forward<decltype(args)>(args)...,
cudaStreamNonBlocking);
};

static constexpr auto StreamDestroy = [](auto&&... args) {
Expand Down Expand Up @@ -120,6 +125,38 @@ struct Runtime<Device::Type::kIluvatar>
static constexpr auto EventElapsedTime = [](auto&&... args) {
return cudaEventElapsedTime(std::forward<decltype(args)>(args)...);
};

// Stream-capture mode constants for this runtime. Each one re-exports the
// CUDA enumerator of the same meaning under the `k`-prefixed name; `auto`
// keeps whatever type the CUDA headers give those enumerators.
static constexpr auto kStreamCaptureModeGlobal = cudaStreamCaptureModeGlobal;

static constexpr auto kStreamCaptureModeThreadLocal =
cudaStreamCaptureModeThreadLocal;

static constexpr auto kStreamCaptureModeRelaxed =
cudaStreamCaptureModeRelaxed;

// Perfect-forwards every argument to `cudaStreamBeginCapture` and returns its
// result unchanged.
static constexpr auto StreamBeginCapture = [](auto&&... args) {
return cudaStreamBeginCapture(std::forward<decltype(args)>(args)...);
};

// Perfect-forwards every argument to `cudaStreamEndCapture` and returns its
// result unchanged.
static constexpr auto StreamEndCapture = [](auto&&... args) {
return cudaStreamEndCapture(std::forward<decltype(args)>(args)...);
};

// Perfect-forwards every argument to `cudaGraphDestroy` and returns its
// result unchanged.
static constexpr auto GraphDestroy = [](auto&&... args) {
return cudaGraphDestroy(std::forward<decltype(args)>(args)...);
};

// Perfect-forwards every argument to `cudaGraphInstantiate` and returns its
// result unchanged.
// NOTE(review): the parameter list of `cudaGraphInstantiate` differs between
// CUDA toolkit generations (older toolkits take error-node/log-buffer
// arguments, newer ones a flags argument). Confirm that the argument list the
// callers forward here matches the signature in the Iluvatar SDK headers.
static constexpr auto GraphInstantiate = [](auto&&... args) {
return cudaGraphInstantiate(std::forward<decltype(args)>(args)...);
};

// Perfect-forwards every argument to `cudaGraphExecDestroy` and returns its
// result unchanged.
static constexpr auto GraphExecDestroy = [](auto&&... args) {
return cudaGraphExecDestroy(std::forward<decltype(args)>(args)...);
};

// Perfect-forwards every argument to `cudaGraphLaunch` and returns its result
// unchanged.
static constexpr auto GraphLaunch = [](auto&&... args) {
return cudaGraphLaunch(std::forward<decltype(args)>(args)...);
};
};

static_assert(Runtime<Device::Type::kIluvatar>::Validate());
Expand Down
27 changes: 27 additions & 0 deletions src/native/cuda/metax/runtime_.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#include <mcr/mc_runtime.h>

#include <cstddef>
#include <utility>

#include "native/cuda/metax/device_.h"
Expand All @@ -17,6 +18,10 @@ struct Runtime<Device::Type::kMetax>

using Stream = mcStream_t;

using Graph = void*;

using GraphExec = void*;

using Event = mcEvent_t;

static constexpr Device::Type kDeviceType = Device::Type::kMetax;
Expand Down Expand Up @@ -124,6 +129,28 @@ struct Runtime<Device::Type::kMetax>
static constexpr auto EventElapsedTime = [](auto&&... args) {
return mcEventElapsedTime(std::forward<decltype(args)>(args)...);
};

// Stream-capture mode constants for this runtime, expressed as plain `int`
// values (0, 1, 2) rather than vendor enumerators.
// NOTE(review): presumably placeholders because no native capture-mode enum is
// wired up for this backend -- confirm.
static constexpr int kStreamCaptureModeGlobal = 0;

static constexpr int kStreamCaptureModeThreadLocal = 1;

static constexpr int kStreamCaptureModeRelaxed = 2;

// Stub: ignores both arguments and always returns `Error` value 1.
// NOTE(review): presumably signals that stream capture is unavailable on this
// backend -- confirm that 1 is the intended error code.
static Error StreamBeginCapture(Stream, int) { return static_cast<Error>(1); }

// Stub: ignores both arguments (the `Graph*` output is never written) and
// always returns `Error` value 1.
// NOTE(review): presumably signals that stream capture is unavailable on this
// backend -- confirm that 1 is the intended error code.
static Error StreamEndCapture(Stream, Graph*) {
return static_cast<Error>(1);
}

// Stub: ignores its argument and always returns `Error` value 1.
// NOTE(review): presumably signals that graphs are unavailable on this
// backend -- confirm that 1 is the intended error code.
static Error GraphDestroy(Graph) { return static_cast<Error>(1); }

// Stub: ignores both arguments (the `GraphExec*` output is never written) and
// always returns `Error` value 1.
// NOTE(review): presumably signals that graphs are unavailable on this
// backend -- confirm that 1 is the intended error code.
static Error GraphInstantiate(GraphExec*, Graph) {
return static_cast<Error>(1);
}

// Stub: ignores its argument and always returns `Error` value 1.
// NOTE(review): presumably signals that graphs are unavailable on this
// backend -- confirm that 1 is the intended error code.
static Error GraphExecDestroy(GraphExec) { return static_cast<Error>(1); }

// Stub: ignores both arguments and always returns `Error` value 1.
// NOTE(review): presumably signals that graphs are unavailable on this
// backend -- confirm that 1 is the intended error code.
static Error GraphLaunch(GraphExec, Stream) { return static_cast<Error>(1); }
};

static_assert(Runtime<Device::Type::kMetax>::Validate());
Expand Down
26 changes: 26 additions & 0 deletions src/native/cuda/moore/runtime_.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ struct Runtime<Device::Type::kMoore>

using Stream = musaStream_t;

using Graph = void*;

using GraphExec = void*;

using Event = musaEvent_t;

static constexpr Device::Type kDeviceType = Device::Type::kMoore;
Expand Down Expand Up @@ -131,6 +135,28 @@ struct Runtime<Device::Type::kMoore>
static constexpr auto EventElapsedTime = [](auto&&... args) {
return musaEventElapsedTime(std::forward<decltype(args)>(args)...);
};

// Stream-capture mode constants for this runtime, expressed as plain `int`
// values (0, 1, 2) rather than vendor enumerators.
// NOTE(review): presumably placeholders because no native capture-mode enum is
// wired up for this backend -- confirm.
static constexpr int kStreamCaptureModeGlobal = 0;

static constexpr int kStreamCaptureModeThreadLocal = 1;

static constexpr int kStreamCaptureModeRelaxed = 2;

// Stub: ignores both arguments and always returns `Error` value 1.
// NOTE(review): presumably signals that stream capture is unavailable on this
// backend -- confirm that 1 is the intended error code.
static Error StreamBeginCapture(Stream, int) { return static_cast<Error>(1); }

// Stub: ignores both arguments (the `Graph*` output is never written) and
// always returns `Error` value 1.
// NOTE(review): presumably signals that stream capture is unavailable on this
// backend -- confirm that 1 is the intended error code.
static Error StreamEndCapture(Stream, Graph*) {
return static_cast<Error>(1);
}

// Stub: ignores its argument and always returns `Error` value 1.
// NOTE(review): presumably signals that graphs are unavailable on this
// backend -- confirm that 1 is the intended error code.
static Error GraphDestroy(Graph) { return static_cast<Error>(1); }

// Stub: ignores both arguments (the `GraphExec*` output is never written) and
// always returns `Error` value 1.
// NOTE(review): presumably signals that graphs are unavailable on this
// backend -- confirm that 1 is the intended error code.
static Error GraphInstantiate(GraphExec*, Graph) {
return static_cast<Error>(1);
}

// Stub: ignores its argument and always returns `Error` value 1.
// NOTE(review): presumably signals that graphs are unavailable on this
// backend -- confirm that 1 is the intended error code.
static Error GraphExecDestroy(GraphExec) { return static_cast<Error>(1); }

// Stub: ignores both arguments and always returns `Error` value 1.
// NOTE(review): presumably signals that graphs are unavailable on this
// backend -- confirm that 1 is the intended error code.
static Error GraphLaunch(GraphExec, Stream) { return static_cast<Error>(1); }
};

static_assert(Runtime<Device::Type::kMoore>::Validate());
Expand Down
Loading
Loading