From 2f0c85aae98835946bb161a98539d1261be2243e Mon Sep 17 00:00:00 2001 From: tqchen Date: Tue, 16 Apr 2024 10:42:32 -0400 Subject: [PATCH] [RUNTIME][VULKAN] Support total_global_memory This PR supports total_global_memory query for vulkan devices. --- src/runtime/vulkan/vulkan_device.cc | 7 +++++-- src/runtime/vulkan/vulkan_device.h | 2 ++ src/runtime/vulkan/vulkan_device_api.cc | 1 + 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/src/runtime/vulkan/vulkan_device.cc b/src/runtime/vulkan/vulkan_device.cc index 7c5ac55f0b4b..cc39972432a3 100644 --- a/src/runtime/vulkan/vulkan_device.cc +++ b/src/runtime/vulkan/vulkan_device.cc @@ -293,7 +293,7 @@ VulkanDevice::VulkanDevice(const VulkanInstance& instance, VkPhysicalDevice phy_ for (uint32_t k = 0; k < prop.memoryTypeCount; ++k) { VkMemoryType ty = prop.memoryTypes[k]; - size_t heap_size = prop.memoryHeaps[ty.heapIndex].size; + int64_t heap_size = static_cast(prop.memoryHeaps[ty.heapIndex].size); // host visible if (!(ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT)) continue; // match copy requirment @@ -312,7 +312,7 @@ VulkanDevice::VulkanDevice(const VulkanInstance& instance, VkPhysicalDevice phy_ win_rank = -1; for (uint32_t k = 0; k < prop.memoryTypeCount; ++k) { VkMemoryType ty = prop.memoryTypes[k]; - size_t heap_size = prop.memoryHeaps[ty.heapIndex].size; + int64_t heap_size = static_cast(prop.memoryHeaps[ty.heapIndex].size); // host visible if (!(ty.propertyFlags & VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT)) continue; // match copy requirment @@ -324,8 +324,10 @@ VulkanDevice::VulkanDevice(const VulkanInstance& instance, VkPhysicalDevice phy_ if (rank > win_rank) { win_rank = rank; compute_mtype_index = k; + compute_memory_size = heap_size; } } + ICHECK_GE(win_rank, 0) << "Cannot find suitable local memory on device."; if (device_properties.supports_push_descriptor) { @@ -383,6 +385,7 @@ void VulkanDevice::do_swap(VulkanDevice&& other) { std::swap(queue_insert_debug_utils_label_functions, other.queue_insert_debug_utils_label_functions); std::swap(compute_mtype_index, other.compute_mtype_index); + std::swap(compute_memory_size, other.compute_memory_size); std::swap(queue, other.queue); std::swap(queue_family_index, other.queue_family_index); std::swap(physical_device_, other.physical_device_); diff --git a/src/runtime/vulkan/vulkan_device.h b/src/runtime/vulkan/vulkan_device.h index 296483a6b104..0573a00e5c9e 100644 --- a/src/runtime/vulkan/vulkan_device.h +++ b/src/runtime/vulkan/vulkan_device.h @@ -223,6 +223,8 @@ class VulkanDevice { queue_insert_debug_utils_label_functions{nullptr}; // Memory type index for compute uint32_t compute_mtype_index{0}; + // maximum memory size for compute + int64_t compute_memory_size{0}; // queue family_index; uint32_t queue_family_index{uint32_t(-1)}; diff --git a/src/runtime/vulkan/vulkan_device_api.cc b/src/runtime/vulkan/vulkan_device_api.cc index 18a40bf54ffd..4b337dd52455 100644 --- a/src/runtime/vulkan/vulkan_device_api.cc +++ b/src/runtime/vulkan/vulkan_device_api.cc @@ -165,6 +165,7 @@ void VulkanDeviceAPI::GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) break; case kTotalGlobalMemory: { + *rv = device(index).compute_memory_size; return; } }