Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions src/runtime/vulkan/vulkan_device.cc
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ VulkanDevice::VulkanDevice(const VulkanInstance& instance, VkPhysicalDevice phy_

for (uint32_t k = 0; k < prop.memoryTypeCount; ++k) {
VkMemoryType ty = prop.memoryTypes[k];
size_t heap_size = prop.memoryHeaps[ty.heapIndex].size;
int64_t heap_size = static_cast<int64_t>(prop.memoryHeaps[ty.heapIndex].size);
// host visible
if (!(ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT)) continue;
// match copy requirment
Expand All @@ -312,7 +312,7 @@ VulkanDevice::VulkanDevice(const VulkanInstance& instance, VkPhysicalDevice phy_
win_rank = -1;
for (uint32_t k = 0; k < prop.memoryTypeCount; ++k) {
VkMemoryType ty = prop.memoryTypes[k];
size_t heap_size = prop.memoryHeaps[ty.heapIndex].size;
int64_t heap_size = static_cast<int64_t>(prop.memoryHeaps[ty.heapIndex].size);
// host visible
if (!(ty.propertyFlags & VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT)) continue;
// match copy requirment
Expand All @@ -324,8 +324,10 @@ VulkanDevice::VulkanDevice(const VulkanInstance& instance, VkPhysicalDevice phy_
if (rank > win_rank) {
win_rank = rank;
compute_mtype_index = k;
compute_memory_size = heap_size;
}
}

ICHECK_GE(win_rank, 0) << "Cannot find suitable local memory on device.";

if (device_properties.supports_push_descriptor) {
Expand Down Expand Up @@ -383,6 +385,7 @@ void VulkanDevice::do_swap(VulkanDevice&& other) {
std::swap(queue_insert_debug_utils_label_functions,
other.queue_insert_debug_utils_label_functions);
std::swap(compute_mtype_index, other.compute_mtype_index);
std::swap(compute_memory_size, other.compute_memory_size);
std::swap(queue, other.queue);
std::swap(queue_family_index, other.queue_family_index);
std::swap(physical_device_, other.physical_device_);
Expand Down
2 changes: 2 additions & 0 deletions src/runtime/vulkan/vulkan_device.h
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,8 @@ class VulkanDevice {
queue_insert_debug_utils_label_functions{nullptr};
// Memory type index for compute
uint32_t compute_mtype_index{0};
// maximum memory size for compute
int64_t compute_memory_size{0};

// queue family_index;
uint32_t queue_family_index{uint32_t(-1)};
Expand Down
1 change: 1 addition & 0 deletions src/runtime/vulkan/vulkan_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ void VulkanDeviceAPI::GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv)
break;

case kTotalGlobalMemory: {
*rv = device(index).compute_memory_size;
return;
}
}
Expand Down