Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 75 additions & 40 deletions scripts/generate_public_headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,24 +210,28 @@ def _parse_runtime_functions(runtime_header):
_Function(
return_type,
name,
tuple(_parse_param(param) for param in params.split(", ") if param),
tuple(
_parse_param(param)
for param in re.split(r",\s*", params.strip())
if param
),
)
for return_type, name, params in re.findall(
r"^(void) ([A-Z]\w*)\(([^()]*)\);$", text, re.MULTILINE
r"^(infini::rt::Error|Error|void) ([A-Z]\w*)\(([^()]*)\);$",
text,
re.MULTILINE,
)
)


def _abort_statement(message):
return f""" assert(false && "{message}");
std::abort();"""
def _unsupported_statement():
return " return infini::rt::kUnSuccess;"


def _dispatch_cases(devices, statements):
return "\n".join(
f""" case {_DEVICE_TYPES[device]}: {{
{statements.replace("__DEVICE_TYPE__", _DEVICE_TYPES[device])}
return;
}}"""
for device in devices
)
Expand All @@ -243,19 +247,21 @@ def _selector(function):
return "current_device.type()"


def _runtime_arg(param):
def _runtime_arg(function, param):
if param.type == "Device":
return f"{param.name}.index()"
if function.name in {"SetDevice", "GetDeviceResourceSnapshot"}:
return f"{param.name}.index()"
return None
if param.type == "Device::Type":
return None
if param.type == "MemcpyKind":
if param.type in {"MemcpyKind", "infini::rt::MemcpyKind"}:
return f"RuntimeMemcpyKind<__DEVICE_TYPE__>({param.name})"

return param.name


def _runtime_args(function):
args = (_runtime_arg(param) for param in function.params)
args = (_runtime_arg(function, param) for param in function.params)

return ", ".join(arg for arg in args if arg is not None)

Expand All @@ -267,17 +273,21 @@ def _preconditions(function):
}
checks = []
for param in function.params:
if param.type.endswith("**") or param.name in required_pointer_names.get(
function.name, set()
if (
param.type.endswith("**")
or param.type.endswith("*")
or param.name in required_pointer_names.get(function.name, set())
):
checks.append(f" assert({param.name} != nullptr);")
checks.append(f" if ({param.name} == nullptr) {{")
checks.append(" return infini::rt::kUnSuccess;")
checks.append(" }")

return "\n".join(checks)


def _post_dispatch(function):
if function.name == "SetDevice":
return "\n current_device = device;"
return "\n if (rt_status == infini::rt::kSuccess) {\n current_device = Device{current_device.type(), device};\n }"

return ""

Expand All @@ -294,19 +304,23 @@ def _write_get_device(function, devices):
device_param = function.params[0].name
cases = _dispatch_cases(
devices,
f""" int index = current_device.index();
CheckCall([&] {{ return Runtime<__DEVICE_TYPE__>::GetDevice(&index); }});
current_device = Device{{current_device.type(), index}};
*{device_param} = current_device;""",
f""" infini::rt::Error status = CheckCall([&] {{ return Runtime<__DEVICE_TYPE__>::GetDevice({device_param}); }});
if (status != infini::rt::kSuccess) {{
return status;
}}
current_device = Device{{current_device.type(), *{device_param}}};
return infini::rt::kSuccess;""",
)

return f"""void GetDevice(Device* {device_param}) {{
assert({device_param} != nullptr);
return f"""{function.return_type} GetDevice(int* {device_param}) {{
if ({device_param} == nullptr) {{
return infini::rt::kUnSuccess;
}}

switch (current_device.type()) {{
{cases}
default:
{_abort_statement("runtime device is not enabled")}
{_unsupported_statement()}
}}
}}
"""
Expand All @@ -318,7 +332,8 @@ def _write_dispatch_function(function, devices):

cases = _dispatch_cases(
devices,
f""" CheckCall([&] {{ return {_runtime_call(function)}; }});{_post_dispatch(function)}""",
f""" infini::rt::Error rt_status = CheckCall([&] {{ return {_runtime_call(function)}; }});{_post_dispatch(function)}
return rt_status;""",
)
preconditions = _preconditions(function)
if preconditions:
Expand All @@ -328,25 +343,42 @@ def _write_dispatch_function(function, devices):
{preconditions} switch ({_selector(function)}) {{
{cases}
default:
{_abort_statement("runtime device is not enabled")}
{_unsupported_statement()}
}}
}}
"""


def _runtime_header_for_device(source_root, device):
return source_root / _RUNTIME_HEADERS[device]


def _devices_for_function(function, devices, source_root):
enabled = []
pattern = re.compile(r"\b" + re.escape(function.name) + r"\b")
for device in devices:
runtime_header = _runtime_header_for_device(source_root, device)
if runtime_header.exists() and pattern.search(runtime_header.read_text()):
enabled.append(device)
return enabled


def _write_runtime_dispatch(source_path, runtime_header, devices):
first_device_type = _DEVICE_TYPES[devices[0]]
includes = ['#include "runtime.h"']
includes.extend(f'#include "{_RUNTIME_HEADERS[device]}"' for device in devices)
functions = _parse_runtime_functions(runtime_header)
source_root = pathlib.Path(runtime_header).parent
dispatch_functions = "\n".join(
_write_dispatch_function(function, devices) for function in functions
_write_dispatch_function(
function, _devices_for_function(function, devices, source_root)
)
for function in functions
)

source_path.parent.mkdir(parents=True, exist_ok=True)
source_path.write_text(
f"""#include <cassert>
#include <cstdlib>
#include <type_traits>
#include <utility>

Expand All @@ -358,36 +390,39 @@ def _write_runtime_dispatch(source_path, runtime_header, devices):
thread_local Device current_device{{{first_device_type}, 0}};

template <typename Func>
void CheckCall(Func&& func) {{
infini::rt::Error CheckCall(Func&& func) {{
using ReturnType = decltype(std::forward<Func>(func)());

if constexpr (std::is_void_v<ReturnType>) {{
std::forward<Func>(func)();
return infini::rt::kSuccess;
}} else {{
ReturnType status = std::forward<Func>(func)();
if (status != ReturnType{{}}) {{
assert(false && "runtime call failed");
std::abort();
if constexpr (std::is_same_v<ReturnType, infini::rt::Error>) {{
return status == infini::rt::kSuccess ? infini::rt::kSuccess
: infini::rt::kUnSuccess;
}} else {{
return status == ReturnType{{}} ? infini::rt::kSuccess
: infini::rt::kUnSuccess;
}}
}}
}}

template <Device::Type kDev>
auto RuntimeMemcpyKind(MemcpyKind kind) {{
auto RuntimeMemcpyKind(infini::rt::MemcpyKind kind) {{
switch (kind) {{
case MemcpyKind::kHostToHost:
return Runtime<kDev>::MemcpyHostToHost;
case MemcpyKind::kHostToDevice:
return Runtime<kDev>::MemcpyHostToDevice;
case MemcpyKind::kDeviceToHost:
return Runtime<kDev>::MemcpyDeviceToHost;
case MemcpyKind::kDeviceToDevice:
return Runtime<kDev>::MemcpyDeviceToDevice;
case infini::rt::MemcpyKind::kMemcpyHostToHost:
return Runtime<kDev>::kMemcpyHostToHost;
case infini::rt::MemcpyKind::kMemcpyHostToDevice:
return Runtime<kDev>::kMemcpyHostToDevice;
case infini::rt::MemcpyKind::kMemcpyDeviceToHost:
return Runtime<kDev>::kMemcpyDeviceToHost;
case infini::rt::MemcpyKind::kMemcpyDeviceToDevice:
return Runtime<kDev>::kMemcpyDeviceToDevice;
}}

assert(false && "unsupported memcpy kind");
std::abort();
return Runtime<kDev>::MemcpyHostToHost;
return Runtime<kDev>::kMemcpyHostToHost;
}}

}} // namespace
Expand Down
Loading
Loading