Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions src/native/ascend/runtime_.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define INFINI_RT_ASCEND_RUNTIME__H_

#include <cassert>
#include <cstddef>
#include <cstdint>

// clang-format off
Expand All @@ -20,6 +21,8 @@ struct Runtime<Device::Type::kAscend>

using Stream = aclrtStream;

using Event = void*;

static constexpr Device::Type kDeviceType = Device::Type::kAscend;

static constexpr Error kSuccess = ACL_SUCCESS;
Expand All @@ -42,8 +45,20 @@ struct Runtime<Device::Type::kAscend>
return aclrtMalloc(ptr, size, ACL_MEM_MALLOC_HUGE_FIRST);
};

static Error MallocHost(void**, std::size_t) { return Unsupported(); }

static Error MallocAsync(void**, std::size_t, Stream) {
return Unsupported();
}

static constexpr auto Free = aclrtFree;

static Error FreeHost(void*) { return Unsupported(); }

static Error FreeAsync(void*, Stream) { return Unsupported(); }

static Error MemGetInfo(std::size_t*, std::size_t*) { return Unsupported(); }

static constexpr auto Memcpy = [](void* dst, const void* src, size_t count,
aclrtMemcpyKind kind) {
return aclrtMemcpy(dst, count, src, count, kind);
Expand All @@ -66,6 +81,39 @@ struct Runtime<Device::Type::kAscend>
static constexpr auto Memset = [](void* ptr, int value, size_t count) {
return aclrtMemset(ptr, count, value, count);
};

static Error MemsetAsync(void*, int, std::size_t, Stream) {
return Unsupported();
}

static constexpr auto StreamCreate = aclrtCreateStream;

static constexpr auto StreamDestroy = aclrtDestroyStream;

static constexpr auto StreamSynchronize = aclrtSynchronizeStream;

static Error StreamWaitEvent(Stream, Event, unsigned int) {
return Unsupported();
}

static Error EventCreate(Event*) { return Unsupported(); }

static Error EventCreateWithFlags(Event*, unsigned int) {
return Unsupported();
}

static Error EventRecord(Event, Stream) { return Unsupported(); }

static Error EventQuery(Event) { return Unsupported(); }

static Error EventSynchronize(Event) { return Unsupported(); }

static Error EventDestroy(Event) { return Unsupported(); }

static Error EventElapsedTime(float*, Event, Event) { return Unsupported(); }

private:
static Error Unsupported() { return static_cast<Error>(1); }
};

static_assert(Runtime<Device::Type::kAscend>::Validate());
Expand Down
47 changes: 47 additions & 0 deletions src/native/cambricon/runtime_.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ struct Runtime<Device::Type::kCambricon>

using Stream = cnrtQueue_t;

using Event = void*;

static constexpr Device::Type kDeviceType = Device::Type::kCambricon;

#ifdef CNRT_RET_SUCCESS
Expand All @@ -43,8 +45,20 @@ struct Runtime<Device::Type::kCambricon>

static constexpr auto Malloc = cnrtMalloc;

static Error MallocHost(void**, std::size_t) { return Unsupported(); }

static Error MallocAsync(void**, std::size_t, Stream) {
return Unsupported();
}

static constexpr auto Free = cnrtFree;

static Error FreeHost(void*) { return Unsupported(); }

static Error FreeAsync(void*, Stream) { return Unsupported(); }

static Error MemGetInfo(std::size_t*, std::size_t*) { return Unsupported(); }

static constexpr auto Memcpy = [](void* dst, const void* src,
std::size_t size, auto kind) {
return cnrtMemcpy(dst, const_cast<void*>(src), size, kind);
Expand All @@ -70,6 +84,39 @@ struct Runtime<Device::Type::kCambricon>
static constexpr auto kMemcpyDeviceToDevice = cnrtMemcpyDevToDev;

static constexpr auto Memset = cnrtMemset;

static Error MemsetAsync(void*, int, std::size_t, Stream) {
return Unsupported();
}

static constexpr auto StreamCreate = cnrtQueueCreate;

static constexpr auto StreamDestroy = cnrtQueueDestroy;

static constexpr auto StreamSynchronize = cnrtQueueSync;

static Error StreamWaitEvent(Stream, Event, unsigned int) {
return Unsupported();
}

static Error EventCreate(Event*) { return Unsupported(); }

static Error EventCreateWithFlags(Event*, unsigned int) {
return Unsupported();
}

static Error EventRecord(Event, Stream) { return Unsupported(); }

static Error EventQuery(Event) { return Unsupported(); }

static Error EventSynchronize(Event) { return Unsupported(); }

static Error EventDestroy(Event) { return Unsupported(); }

static Error EventElapsedTime(float*, Event, Event) { return Unsupported(); }

private:
static Error Unsupported() { return static_cast<Error>(1); }
};

static_assert(Runtime<Device::Type::kCambricon>::Validate());
Expand Down
74 changes: 72 additions & 2 deletions src/native/cuda/hygon/runtime_.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ struct Runtime<Device::Type::kHygon>

using Stream = cudaStream_t;

using Event = cudaEvent_t;

static constexpr Device::Type kDeviceType = Device::Type::kHygon;

static constexpr Error kSuccess = cudaSuccess;
Expand All @@ -35,14 +37,34 @@ struct Runtime<Device::Type::kHygon>
return cudaMalloc(std::forward<decltype(args)>(args)...);
};

static constexpr auto Memcpy = cudaMemcpy;
static constexpr auto MallocHost = [](auto&&... args) {
return cudaMallocHost(std::forward<decltype(args)>(args)...);
};

static constexpr auto MemcpyAsync = cudaMemcpyAsync;
static constexpr auto MallocAsync = [](auto&&... args) {
return cudaMallocAsync(std::forward<decltype(args)>(args)...);
};

static constexpr auto Free = [](auto&&... args) {
return cudaFree(std::forward<decltype(args)>(args)...);
};

static constexpr auto FreeHost = [](auto&&... args) {
return cudaFreeHost(std::forward<decltype(args)>(args)...);
};

static constexpr auto FreeAsync = [](auto&&... args) {
return cudaFreeAsync(std::forward<decltype(args)>(args)...);
};

static constexpr auto MemGetInfo = [](auto&&... args) {
return cudaMemGetInfo(std::forward<decltype(args)>(args)...);
};

static constexpr auto Memcpy = cudaMemcpy;

static constexpr auto MemcpyAsync = cudaMemcpyAsync;

static constexpr auto kMemcpyHostToHost = cudaMemcpyHostToHost;

static constexpr auto kMemcpyHostToDevice = cudaMemcpyHostToDevice;
Expand All @@ -52,6 +74,54 @@ struct Runtime<Device::Type::kHygon>
static constexpr auto kMemcpyDeviceToDevice = cudaMemcpyDeviceToDevice;

static constexpr auto Memset = cudaMemset;

static constexpr auto MemsetAsync = [](auto&&... args) {
return cudaMemsetAsync(std::forward<decltype(args)>(args)...);
};

static constexpr auto StreamCreate = [](auto&&... args) {
return cudaStreamCreate(std::forward<decltype(args)>(args)...);
};

static constexpr auto StreamDestroy = [](auto&&... args) {
return cudaStreamDestroy(std::forward<decltype(args)>(args)...);
};

static constexpr auto StreamSynchronize = [](auto&&... args) {
return cudaStreamSynchronize(std::forward<decltype(args)>(args)...);
};

static constexpr auto StreamWaitEvent = [](auto&&... args) {
return cudaStreamWaitEvent(std::forward<decltype(args)>(args)...);
};

static constexpr auto EventCreate = [](auto&&... args) {
return cudaEventCreate(std::forward<decltype(args)>(args)...);
};

static constexpr auto EventCreateWithFlags = [](auto&&... args) {
return cudaEventCreateWithFlags(std::forward<decltype(args)>(args)...);
};

static constexpr auto EventRecord = [](auto&&... args) {
return cudaEventRecord(std::forward<decltype(args)>(args)...);
};

static constexpr auto EventQuery = [](auto&&... args) {
return cudaEventQuery(std::forward<decltype(args)>(args)...);
};

static constexpr auto EventSynchronize = [](auto&&... args) {
return cudaEventSynchronize(std::forward<decltype(args)>(args)...);
};

static constexpr auto EventDestroy = [](auto&&... args) {
return cudaEventDestroy(std::forward<decltype(args)>(args)...);
};

static constexpr auto EventElapsedTime = [](auto&&... args) {
return cudaEventElapsedTime(std::forward<decltype(args)>(args)...);
};
};

static_assert(Runtime<Device::Type::kHygon>::Validate());
Expand Down
74 changes: 72 additions & 2 deletions src/native/cuda/iluvatar/runtime_.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ struct Runtime<Device::Type::kIluvatar>

using Stream = cudaStream_t;

using Event = cudaEvent_t;

static constexpr Device::Type kDeviceType = Device::Type::kIluvatar;

static constexpr Error kSuccess = cudaSuccess;
Expand All @@ -35,12 +37,32 @@ struct Runtime<Device::Type::kIluvatar>
return cudaMalloc(std::forward<decltype(args)>(args)...);
};

static constexpr auto Memcpy = cudaMemcpy;
static constexpr auto MallocHost = [](auto&&... args) {
return cudaMallocHost(std::forward<decltype(args)>(args)...);
};

static constexpr auto MemcpyAsync = cudaMemcpyAsync;
static constexpr auto MallocAsync = [](auto&&... args) {
return cudaMallocAsync(std::forward<decltype(args)>(args)...);
};

static constexpr auto Free = cudaFree;

static constexpr auto FreeHost = [](auto&&... args) {
return cudaFreeHost(std::forward<decltype(args)>(args)...);
};

static constexpr auto FreeAsync = [](auto&&... args) {
return cudaFreeAsync(std::forward<decltype(args)>(args)...);
};

static constexpr auto MemGetInfo = [](auto&&... args) {
return cudaMemGetInfo(std::forward<decltype(args)>(args)...);
};

static constexpr auto Memcpy = cudaMemcpy;

static constexpr auto MemcpyAsync = cudaMemcpyAsync;

static constexpr auto kMemcpyHostToHost = cudaMemcpyHostToHost;

static constexpr auto kMemcpyHostToDevice = cudaMemcpyHostToDevice;
Expand All @@ -50,6 +72,54 @@ struct Runtime<Device::Type::kIluvatar>
static constexpr auto kMemcpyDeviceToDevice = cudaMemcpyDeviceToDevice;

static constexpr auto Memset = cudaMemset;

static constexpr auto MemsetAsync = [](auto&&... args) {
return cudaMemsetAsync(std::forward<decltype(args)>(args)...);
};

static constexpr auto StreamCreate = [](auto&&... args) {
return cudaStreamCreate(std::forward<decltype(args)>(args)...);
};

static constexpr auto StreamDestroy = [](auto&&... args) {
return cudaStreamDestroy(std::forward<decltype(args)>(args)...);
};

static constexpr auto StreamSynchronize = [](auto&&... args) {
return cudaStreamSynchronize(std::forward<decltype(args)>(args)...);
};

static constexpr auto StreamWaitEvent = [](auto&&... args) {
return cudaStreamWaitEvent(std::forward<decltype(args)>(args)...);
};

static constexpr auto EventCreate = [](auto&&... args) {
return cudaEventCreate(std::forward<decltype(args)>(args)...);
};

static constexpr auto EventCreateWithFlags = [](auto&&... args) {
return cudaEventCreateWithFlags(std::forward<decltype(args)>(args)...);
};

static constexpr auto EventRecord = [](auto&&... args) {
return cudaEventRecord(std::forward<decltype(args)>(args)...);
};

static constexpr auto EventQuery = [](auto&&... args) {
return cudaEventQuery(std::forward<decltype(args)>(args)...);
};

static constexpr auto EventSynchronize = [](auto&&... args) {
return cudaEventSynchronize(std::forward<decltype(args)>(args)...);
};

static constexpr auto EventDestroy = [](auto&&... args) {
return cudaEventDestroy(std::forward<decltype(args)>(args)...);
};

static constexpr auto EventElapsedTime = [](auto&&... args) {
return cudaEventElapsedTime(std::forward<decltype(args)>(args)...);
};
};

static_assert(Runtime<Device::Type::kIluvatar>::Validate());
Expand Down
Loading
Loading