diff --git a/backends/vulkan/runtime/graph/ComputeGraph.cpp b/backends/vulkan/runtime/graph/ComputeGraph.cpp index 52194ea82e3..fbef45d8641 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.cpp +++ b/backends/vulkan/runtime/graph/ComputeGraph.cpp @@ -485,24 +485,45 @@ utils::uvec3 ComputeGraph::create_local_wg_size( return config_.local_wg_size_override; } - utils::uvec3 local_group_size = {4, 4, 4}; + // array containing axis index and global workgroup size + std::pair global_wg_size_desc[] = { + {0u, global_wg_size[0]}, + {1u, global_wg_size[1]}, + {2u, global_wg_size[2]}}; + + // sort the global workgroup size in descending order + if (global_wg_size_desc[0].second < global_wg_size_desc[1].second) { + std::swap(global_wg_size_desc[0], global_wg_size_desc[1]); + } + if (global_wg_size_desc[1].second < global_wg_size_desc[2].second) { + std::swap(global_wg_size_desc[1], global_wg_size_desc[2]); + } + if (global_wg_size_desc[0].second < global_wg_size_desc[1].second) { + std::swap(global_wg_size_desc[0], global_wg_size_desc[1]); + } - if (global_wg_size[2u] == 1) { - if (global_wg_size[1u] == 1) { + utils::uvec3 local_group_size = { + 8, + std::max(1u, std::min(4u, global_wg_size_desc[1].second)), + std::max(1u, std::min(2u, global_wg_size_desc[2].second))}; + + if (global_wg_size_desc[2u].second == 1) { + if (global_wg_size_desc[1u].second == 1) { local_group_size[0u] = 64; local_group_size[1u] = 1; - local_group_size[2u] = 1; - } else if (global_wg_size[1u] < 8) { + } else if (global_wg_size_desc[1u].second % 4 == 0) { local_group_size[0u] = 16; local_group_size[1u] = 4; - local_group_size[2u] = 1; } else { - local_group_size[0u] = 8; - local_group_size[1u] = 8; - local_group_size[2u] = 1; + local_group_size[0u] = 32; + local_group_size[1u] = 2; } } - return local_group_size; + + return { + local_group_size[global_wg_size_desc[0].first], + local_group_size[global_wg_size_desc[1].first], + local_group_size[global_wg_size_desc[2].first]}; } utils::uvec3 ComputeGraph::create_local_wg_size(const ValueRef idx) {