Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 59 additions & 21 deletions src/infinicore/graph/graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include "../utils.hpp"
#include "infinicore/context/context.hpp"
#include "standalone_infinirt_graph_bridge.hpp"
#include <infinirt.h>

namespace infinicore::graph {
Expand Down Expand Up @@ -32,9 +33,11 @@ DispatchableGraphOperator::~DispatchableGraphOperator() {
* ========================= */

struct Graph::DeviceGraph {
infinirtGraph_t graph;
infinirtGraphExec_t exec;
infinirtGraphNode_t node;
infinirtGraph_t graph = nullptr;
infinirtGraphExec_t exec = nullptr;
infinirtGraphNode_t node = nullptr;
infinirtStream_t stream = nullptr;
bool standalone = false;
std::vector<char> log_buffer;

DeviceGraph() {
Expand All @@ -43,15 +46,27 @@ struct Graph::DeviceGraph {

~DeviceGraph() {
if (exec) {
infinirtGraphExecDestroy(exec);
if (standalone) {
standalone_infinirt::graph_exec_destroy(exec);
} else {
infinirtGraphExecDestroy(exec);
}
}
if (graph) {
infinirtGraphDestroy(graph);
if (standalone) {
standalone_infinirt::graph_destroy(graph);
} else {
infinirtGraphDestroy(graph);
}
}
}

void launch() {
INFINICORE_CHECK_ERROR(infinirtGraphLuanch(exec, context::getStream()));
if (standalone) {
INFINICORE_CHECK_ERROR(standalone_infinirt::graph_launch(exec, stream));
} else {
INFINICORE_CHECK_ERROR(infinirtGraphLuanch(exec, context::getStream()));
}
}
};

Expand All @@ -75,42 +90,65 @@ void Graph::add_operator(std::shared_ptr<GraphOperator> op) {
void Graph::instantiate() {
// Reset device graph
device_graph_ = std::make_unique<DeviceGraph>();
device_graph_->standalone = standalone_infinirt::available(context::getDevice());
device_graph_->stream = context::getStream();
if (device_graph_->standalone) {
auto set_device_status = standalone_infinirt::set_device(context::getDevice());
if (set_device_status != INFINI_STATUS_SUCCESS) {
spdlog::warn("Standalone InfiniRT graph bridge failed to select the current device. Falling back to eager execution.");
device_graph_.reset();
return;
}

static bool logged_once = false;
if (!logged_once) {
logged_once = true;
spdlog::info("Using standalone InfiniRT graph bridge for graph capture and replay.");
}
}

// warmup
for (size_t iter = 0; iter < 5; ++iter) {
this->run();
}
infinicore::context::syncStream();

if (infinirtStreamBeginCapture(
context::getStream(),
INFINIRT_STREAM_CAPTURE_MODE_RELAXED)
!= INFINI_STATUS_SUCCESS) {
auto begin_status = device_graph_->standalone
? standalone_infinirt::stream_begin_capture(device_graph_->stream, INFINIRT_STREAM_CAPTURE_MODE_RELAXED)
: infinirtStreamBeginCapture(context::getStream(), INFINIRT_STREAM_CAPTURE_MODE_RELAXED);
if (begin_status != INFINI_STATUS_SUCCESS) {
spdlog::warn("Fail to begin device graph capture.");
device_graph_.reset();
return;
}

// Run and record
this->run();

if (infinirtStreamEndCapture(
context::getStream(),
&device_graph_.get()->graph)
!= INFINI_STATUS_SUCCESS) {
auto end_status = device_graph_->standalone
? standalone_infinirt::stream_end_capture(device_graph_->stream, &device_graph_.get()->graph)
: infinirtStreamEndCapture(context::getStream(), &device_graph_.get()->graph);
if (end_status != INFINI_STATUS_SUCCESS) {
spdlog::warn("Fail to end device graph capture.");
device_graph_.reset();
return;
}

if (infinirtGraphInstantiate(
&device_graph_.get()->exec,
device_graph_.get()->graph,
&device_graph_.get()->node,
device_graph_.get()->log_buffer.data(),
device_graph_.get()->log_buffer.size())
!= INFINI_STATUS_SUCCESS) {
auto instantiate_status = device_graph_->standalone
? standalone_infinirt::graph_instantiate(&device_graph_.get()->exec, device_graph_.get()->graph)
: infinirtGraphInstantiate(
&device_graph_.get()->exec,
device_graph_.get()->graph,
&device_graph_.get()->node,
device_graph_.get()->log_buffer.data(),
device_graph_.get()->log_buffer.size());
if (instantiate_status != INFINI_STATUS_SUCCESS) {
static bool warned_once = false;
if (!warned_once) {
warned_once = true;
spdlog::warn("Fail to instantiate device graph: {}", std::string(device_graph_.get()->log_buffer.data()));
}
device_graph_.reset();
}
}

Expand Down
173 changes: 173 additions & 0 deletions src/infinicore/graph/standalone_infinirt_graph_bridge.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
#include "standalone_infinirt_graph_bridge.hpp"

#ifdef USE_STANDALONE_INFINIRT_GRAPH

#include <cstdlib>
#include <string>

#include <infini/rt.h>

namespace infinicore::graph::standalone_infinirt {
namespace {

using StandaloneDevice = infini::rt::Device;
using StandaloneRuntime = infini::rt::runtime::Runtime<StandaloneDevice::Type::kNvidia>;

bool truthy_env(const char *name) {
auto value = std::getenv(name);
if (value == nullptr) {
return false;
}
std::string text{value};
return text == "1" || text == "ON" || text == "on" || text == "true" || text == "TRUE";
}

bool supports_device(Device::Type type) {
return type == Device::Type::NVIDIA;
}

template <typename Status>
infiniStatus_t to_core_status(Status status) {
return status == StandaloneRuntime::kSuccess
? INFINI_STATUS_SUCCESS
: INFINI_STATUS_INTERNAL_ERROR;
}

StandaloneRuntime::Stream to_standalone_stream(infinirtStream_t stream) {
return reinterpret_cast<StandaloneRuntime::Stream>(stream);
}

StandaloneRuntime::Graph to_standalone_graph(infinirtGraph_t graph) {
return reinterpret_cast<StandaloneRuntime::Graph>(graph);
}

StandaloneRuntime::GraphExec to_standalone_graph_exec(infinirtGraphExec_t graph_exec) {
return reinterpret_cast<StandaloneRuntime::GraphExec>(graph_exec);
}

decltype(StandaloneRuntime::kStreamCaptureModeRelaxed)
to_standalone_capture_mode(infinirtStreamCaptureMode_t mode) {
switch (mode) {
case INFINIRT_STREAM_CAPTURE_MODE_GLOBAL:
return StandaloneRuntime::kStreamCaptureModeGlobal;
case INFINIRT_STREAM_CAPTURE_MODE_THREAD_LOCAL:
return StandaloneRuntime::kStreamCaptureModeThreadLocal;
case INFINIRT_STREAM_CAPTURE_MODE_RELAXED:
return StandaloneRuntime::kStreamCaptureModeRelaxed;
}
return StandaloneRuntime::kStreamCaptureModeRelaxed;
}

} // namespace

bool enabled() {
return truthy_env("INFINICORE_USE_STANDALONE_INFINIRT_GRAPH");
}

bool available(const Device &device) {
return enabled() && supports_device(device.getType());
}

infiniStatus_t set_device(const Device &device) {
if (!supports_device(device.getType())) {
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
return to_core_status(StandaloneRuntime::SetDevice(static_cast<int>(device.getIndex())));
}

infiniStatus_t stream_begin_capture(infinirtStream_t stream, infinirtStreamCaptureMode_t mode) {
if (stream == nullptr) {
return INFINI_STATUS_NULL_POINTER;
}
return to_core_status(StandaloneRuntime::StreamBeginCapture(
to_standalone_stream(stream),
to_standalone_capture_mode(mode)));
}

infiniStatus_t stream_end_capture(infinirtStream_t stream, infinirtGraph_t *graph) {
if (stream == nullptr || graph == nullptr) {
return INFINI_STATUS_NULL_POINTER;
}
return to_core_status(StandaloneRuntime::StreamEndCapture(
to_standalone_stream(stream),
reinterpret_cast<StandaloneRuntime::Graph *>(graph)));
}

infiniStatus_t graph_destroy(infinirtGraph_t graph) {
if (graph == nullptr) {
return INFINI_STATUS_NULL_POINTER;
}
return to_core_status(StandaloneRuntime::GraphDestroy(to_standalone_graph(graph)));
}

infiniStatus_t graph_instantiate(infinirtGraphExec_t *graph_exec, infinirtGraph_t graph) {
if (graph_exec == nullptr || graph == nullptr) {
return INFINI_STATUS_NULL_POINTER;
}
return to_core_status(StandaloneRuntime::GraphInstantiate(
reinterpret_cast<StandaloneRuntime::GraphExec *>(graph_exec),
to_standalone_graph(graph)));
}

infiniStatus_t graph_exec_destroy(infinirtGraphExec_t graph_exec) {
if (graph_exec == nullptr) {
return INFINI_STATUS_NULL_POINTER;
}
return to_core_status(StandaloneRuntime::GraphExecDestroy(
to_standalone_graph_exec(graph_exec)));
}

infiniStatus_t graph_launch(infinirtGraphExec_t graph_exec, infinirtStream_t stream) {
if (graph_exec == nullptr || stream == nullptr) {
return INFINI_STATUS_NULL_POINTER;
}
return to_core_status(StandaloneRuntime::GraphLaunch(
to_standalone_graph_exec(graph_exec),
to_standalone_stream(stream)));
}

} // namespace infinicore::graph::standalone_infinirt

#else

namespace infinicore::graph::standalone_infinirt {

bool enabled() {
return false;
}

bool available(const Device &) {
return false;
}

infiniStatus_t set_device(const Device &) {
return INFINI_STATUS_NOT_IMPLEMENTED;
}

infiniStatus_t stream_begin_capture(infinirtStream_t, infinirtStreamCaptureMode_t) {
return INFINI_STATUS_NOT_IMPLEMENTED;
}

infiniStatus_t stream_end_capture(infinirtStream_t, infinirtGraph_t *) {
return INFINI_STATUS_NOT_IMPLEMENTED;
}

infiniStatus_t graph_destroy(infinirtGraph_t) {
return INFINI_STATUS_NOT_IMPLEMENTED;
}

infiniStatus_t graph_instantiate(infinirtGraphExec_t *, infinirtGraph_t) {
return INFINI_STATUS_NOT_IMPLEMENTED;
}

infiniStatus_t graph_exec_destroy(infinirtGraphExec_t) {
return INFINI_STATUS_NOT_IMPLEMENTED;
}

infiniStatus_t graph_launch(infinirtGraphExec_t, infinirtStream_t) {
return INFINI_STATUS_NOT_IMPLEMENTED;
}

} // namespace infinicore::graph::standalone_infinirt

#endif
26 changes: 26 additions & 0 deletions src/infinicore/graph/standalone_infinirt_graph_bridge.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#pragma once

#include "infinicore/device.hpp"
#include <infinirt.h>

namespace infinicore::graph::standalone_infinirt {

bool enabled();

bool available(const Device &device);

infiniStatus_t set_device(const Device &device);

infiniStatus_t stream_begin_capture(infinirtStream_t stream, infinirtStreamCaptureMode_t mode);

infiniStatus_t stream_end_capture(infinirtStream_t stream, infinirtGraph_t *graph);

infiniStatus_t graph_destroy(infinirtGraph_t graph);

infiniStatus_t graph_instantiate(infinirtGraphExec_t *graph_exec, infinirtGraph_t graph);

infiniStatus_t graph_exec_destroy(infinirtGraphExec_t graph_exec);

infiniStatus_t graph_launch(infinirtGraphExec_t graph_exec, infinirtStream_t stream);

} // namespace infinicore::graph::standalone_infinirt
Loading
Loading