From eaf3bf95a764d88fa06724903c449b18d19b164f Mon Sep 17 00:00:00 2001 From: Vivek Trivedi Date: Tue, 22 Oct 2024 10:26:53 -0700 Subject: [PATCH] Calculating axis's local wg size based on global workload and making it as close as possible to warp size of 32. (#6409) Summary: This diff changes the local workgroup size calculation logic in the Vulkan backend of Executorch. The workgroup size of the largest axis is kept largest so workgroups are better occupied. The workgroup size is calculated based on the warp size of 32. When the kernel is 2-dimensional, the largest axis is kept close to the warp size, so threads in the same warp read/write consecutive memory locations, thus improving performance. Reviewed By: SS-JIA Differential Revision: D64418632 --- .../vulkan/runtime/graph/ComputeGraph.cpp | 41 ++++++++++++++----- 1 file changed, 31 insertions(+), 10 deletions(-) diff --git a/backends/vulkan/runtime/graph/ComputeGraph.cpp b/backends/vulkan/runtime/graph/ComputeGraph.cpp index 52194ea82e3..fbef45d8641 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.cpp +++ b/backends/vulkan/runtime/graph/ComputeGraph.cpp @@ -485,24 +485,45 @@ utils::uvec3 ComputeGraph::create_local_wg_size( return config_.local_wg_size_override; } - utils::uvec3 local_group_size = {4, 4, 4}; + // array containing axis index and global workgroup size + std::pair global_wg_size_desc[] = { + {0u, global_wg_size[0]}, + {1u, global_wg_size[1]}, + {2u, global_wg_size[2]}}; + + // sort the global workgroup size in descending order + if (global_wg_size_desc[0].second < global_wg_size_desc[1].second) { + std::swap(global_wg_size_desc[0], global_wg_size_desc[1]); + } + if (global_wg_size_desc[1].second < global_wg_size_desc[2].second) { + std::swap(global_wg_size_desc[1], global_wg_size_desc[2]); + } + if (global_wg_size_desc[0].second < global_wg_size_desc[1].second) { + std::swap(global_wg_size_desc[0], global_wg_size_desc[1]); + } - if (global_wg_size[2u] == 1) { - if (global_wg_size[1u] == 1) { + 
utils::uvec3 local_group_size = { + 8, + std::max(1u, std::min(4u, global_wg_size_desc[1].second)), + std::max(1u, std::min(2u, global_wg_size_desc[2].second))}; + + if (global_wg_size_desc[2u].second == 1) { + if (global_wg_size_desc[1u].second == 1) { local_group_size[0u] = 64; local_group_size[1u] = 1; - local_group_size[2u] = 1; - } else if (global_wg_size[1u] < 8) { + } else if (global_wg_size_desc[1u].second % 4 == 0) { local_group_size[0u] = 16; local_group_size[1u] = 4; - local_group_size[2u] = 1; } else { - local_group_size[0u] = 8; - local_group_size[1u] = 8; - local_group_size[2u] = 1; + local_group_size[0u] = 32; + local_group_size[1u] = 2; } } - return local_group_size; + + return { + local_group_size[global_wg_size_desc[0].first], + local_group_size[global_wg_size_desc[1].first], + local_group_size[global_wg_size_desc[2].first]}; } utils::uvec3 ComputeGraph::create_local_wg_size(const ValueRef idx) {