From 27c90defe9bff047b41bf32651cc090650643b09 Mon Sep 17 00:00:00 2001 From: Justin Yip Date: Thu, 18 Apr 2024 09:59:19 -0700 Subject: [PATCH 1/2] [ET-VK][6/n] aten.view_copy aten.view_copy, supporting all packing. Using @ssjia's idea to do a direct lookup. Differential Revision: [D56281400](https://our.internmc.facebook.com/intern/diff/D56281400/) [ghstack-poisoned] --- .../runtime/graph/ops/glsl/indexing_utils.h | 34 +++++++- .../vulkan/runtime/graph/ops/glsl/view.glsl | 80 +++++++++++++++++++ .../vulkan/runtime/graph/ops/glsl/view.yaml | 15 ++++ .../vulkan/runtime/graph/ops/impl/View.cpp | 61 ++++++++++++++ backends/vulkan/test/op_tests/cases.py | 28 +++++++ .../test/op_tests/utils/codegen_base.py | 6 +- 6 files changed, 221 insertions(+), 3 deletions(-) create mode 100644 backends/vulkan/runtime/graph/ops/glsl/view.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/view.yaml create mode 100644 backends/vulkan/runtime/graph/ops/impl/View.cpp diff --git a/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h b/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h index b3195ee7511..24cd9bef9fe 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h +++ b/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h @@ -8,9 +8,21 @@ #define divup4(x) ((x + 3) / 4) +// Input: idx is a ivec4 user-level coordinate, sizes is the tensor shape +// Output: buffer_idx in the continous nchw-buffer. #define to_buffer_i(idx, sizes) \ - idx.x + idx.y* sizes.x + idx.z* sizes.y* sizes.x + \ - idx.w* sizes.z* sizes.y* sizes.x; + (idx.x + idx.y* sizes.x + idx.z* sizes.y* sizes.x + \ + idx.w* sizes.z* sizes.y* sizes.x) + +// Inverse of to_buffer_i +// Input: buffer_idx in the continous nchw-buffer, sizes is the tensor shape +// Output: ivec4 user-level coorindate +#define from_buffer_i(buf_i, sizes) \ + ivec4( \ + buf_i % sizes.x, \ + (buf_i / (sizes.x)) % sizes.y, \ + (buf_i / (sizes.x * sizes.y)) % sizes.z, \ + (buf_i / (sizes.x * sizes.y * sizes.z))) #define get_packed_dim_C_packed(vec) vec.z #define get_packed_dim_W_packed(vec) vec.x @@ -20,6 +32,8 @@ #define get_packed_stride_W_packed(vec) (1) #define get_packed_stride_H_packed(vec) (vec.x) +// Input: pos is a texture position, sizes is a pack-aligned size. +// Output: a user-level (w, h, c, n) coordinate #define to_tensor_idx_C_packed(pos, sizes) \ ivec4(pos.x, pos.y, (pos.z * 4) % sizes.z, (pos.z * 4) / sizes.z) @@ -29,6 +43,9 @@ #define to_tensor_idx_H_packed(pos, sizes) \ ivec4(pos.x, (pos.y * 4), pos.z % sizes.z, pos.z / sizes.z) +// Input: idx is a user-level (w, h, c, n) coordinate. size is a pack-aligned +// size. +// Output: texture location #define to_texture_pos_C_packed(idx, sizes) \ ivec3(idx.x, idx.y, (idx.z + idx.w * sizes.z) / 4) @@ -38,6 +55,19 @@ #define to_texture_pos_H_packed(idx, sizes) \ ivec3(idx.x, idx.y / 4, (idx.z + idx.w * sizes.z)) +// Input: idx is a user-level (w, h, c, n) coordinate. size is a pack-aligned +// size with the index in the texel. +// Output: ivec4, xyz is the texture position, w is the element index in the +// texel. +#define to_texture_pos_elem_C_packed(idx, sizes) \ + ivec4(idx.x, idx.y, (idx.z + idx.w * sizes.z) / 4, idx.z % 4) + +#define to_texture_pos_elem_W_packed(idx, sizes) \ + ivec4(idx.x / 4, idx.y, (idx.z + idx.w * sizes.z), idx.x % 4) + +#define to_texture_pos_elem_H_packed(idx, sizes) \ + ivec4(idx.x, idx.y / 4, (idx.z + idx.w * sizes.z), idx.y % 4) + // Given a buffer(1-D) index cur, compute a new index where the corresponding // tensor(N-D)'s adjacent dimensions are swapped. The parameters x,y and plane // describe sizes. As an example, let's say we want to swap dimensions 0,1 for a diff --git a/backends/vulkan/runtime/graph/ops/glsl/view.glsl b/backends/vulkan/runtime/graph/ops/glsl/view.glsl new file mode 100644 index 00000000000..cc1aa46ea10 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/view.glsl @@ -0,0 +1,80 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#define VEC4_T ${texel_type(DTYPE)} + +layout(std430) buffer; + +#include "indexing_utils.h" + +layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out; +layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in; + +#define VEC4_T ${texel_type(DTYPE)} + +#define to_tensor_idx to_tensor_idx_${PACKING} +#define to_texture_pos_elem to_texture_pos_elem_${PACKING} +#define get_packed_stride get_packed_stride_${PACKING} + +layout(set = 0, binding = 2) uniform PRECISION restrict OutGpuSizes { + uvec4 data; +} +out_gpu_sizes; + +layout(set = 0, binding = 3) uniform PRECISION restrict OutCpuSizes { + uvec4 data; +} +out_cpu_sizes; + +layout(set = 0, binding = 4) uniform PRECISION restrict InGpuSizes { + uvec4 data; +} +in_gpu_sizes; + +layout(set = 0, binding = 5) uniform PRECISION restrict InCpuSizes { + uvec4 data; +} +in_cpu_sizes; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + + +void main() { + const ivec3 out_pos = ivec3(gl_GlobalInvocationID); + const ivec4 out_tensor_idx = to_tensor_idx(out_pos, out_gpu_sizes.data); + + if (all(greaterThanEqual(out_tensor_idx, out_gpu_sizes.data))) { + return; + } + + // Assume there is a virtual continous buffer in nchw format. From the output + // pos, we first calculate the index in the virual buffer, and then calculate + // the input position from the indx. + + const uint base_index = to_buffer_i(out_tensor_idx, out_cpu_sizes.data); + const uvec4 buf_indices = + base_index + ivec4(0, 1, 2, 3) * get_packed_stride(out_cpu_sizes.data); + + VEC4_T value; + // Need to look up the 4 values in the output texel separately. + for (int i=0; i<4; i++) { + ivec4 user_coor = from_buffer_i(buf_indices[i], in_cpu_sizes.data); + + ivec4 in_pos_elem = to_texture_pos_elem(user_coor, in_gpu_sizes.data); + + VEC4_T intex = VEC4_T(texelFetch(image_in, in_pos_elem.xyz, 0)); + + value[i] = intex[in_pos_elem.w]; + } + + imageStore(image_out, out_pos, value); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/view.yaml b/backends/vulkan/runtime/graph/ops/glsl/view.yaml new file mode 100644 index 00000000000..8c91f7d76f0 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/view.yaml @@ -0,0 +1,15 @@ +view: + parameter_names_with_default_values: + DTYPE: float + NDIM: 3 + generate_variant_forall: + DTYPE: + - VALUE: half + - VALUE: float + PACKING: + - VALUE: C_packed + - VALUE: W_packed + - VALUE: H_packed + shader_variants: + - NAME: view + diff --git a/backends/vulkan/runtime/graph/ops/impl/View.cpp b/backends/vulkan/runtime/graph/ops/impl/View.cpp new file mode 100644 index 00000000000..2b8d5984878 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/View.cpp @@ -0,0 +1,61 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +#include +#include +#include + +namespace vkcompute { + +void add_view_node( + ComputeGraph& graph, + ValueRef in, + ValueRef size_ref, + ValueRef out) { + // Note: size_ref is not used here. Since the output tensor's size have been + // determined during compilation. + vTensorPtr t_in = graph.get_tensor(in); + vTensorPtr t_out = graph.get_tensor(out); + + std::string kernel_name = "view"; + kernel_name.reserve(kShaderNameReserve); + add_dtype_suffix(kernel_name, *t_out); + add_memory_layout_suffix(kernel_name, *t_out); + + api::utils::uvec3 global_size = t_out->virtual_extents(); + api::utils::uvec3 local_size = adaptive_work_group_size(global_size); + + graph.execute_nodes().emplace_back(new ExecuteNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + global_size, + local_size, + {{out, api::MemoryAccessType::WRITE}, {in, api::MemoryAccessType::READ}}, + { + t_out->gpu_sizes_ubo(), + t_out->cpu_sizes_ubo(), + t_in->gpu_sizes_ubo(), + t_in->cpu_sizes_ubo()})); +} + + +void view(ComputeGraph& graph, const std::vector& args) { + return add_view_node(graph, args[0], args[1], args[2]); +} + +REGISTER_OPERATORS { + VK_REGISTER_OP(aten.view_copy.default, view); +} + + +} // namespace vkcompute + diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py index 4e556ce6fd5..31cbed07989 100644 --- a/backends/vulkan/test/op_tests/cases.py +++ b/backends/vulkan/test/op_tests/cases.py @@ -194,6 +194,33 @@ def get_permute_inputs(): return test_suite +def get_view_inputs(): + test_suite = VkTestSuite( + [ + ((3,4,5), [1, 1, -1]), + ((3,4,5), [1, -1, 1]), + ((3,4,5), [-1, 1, 1]), + ((8, 7, 2, 3), [4, 3, 7, 4]), + ((8, 7, 2, 3), [7, -1, 2, 1]), + ((8, 7, 2, 3), [1, 1, 1, -1]), + ((8, 7, 2, 3), [-1]), + ((2, 3, 3, 7), [2, -1, 1, 1]), + ((3, 5, 2, 7), [7, -1, 2, 1]), + ((2, 2, 8, 6), [2, 6, -1, 1]), + ((2, 2, 8, 6), [6, -1, 1]), + ((S1, S2, S1, S2), [S2, -1, 1, S1]), + ((S1, S2, S1, S2), [S1, 1, -1, S2]), + ((S1, S2, S1, S2), [-1, 1, S1, S2]), + ] + ) + test_suite.layouts = [ + "api::kWidthPacked", + "api::kHeightPacked", + "api::kChannelsPacked", + ] + return test_suite + + test_suites = { "aten.add.Tensor": get_binary_elementwise_inputs(), "aten.sub.Tensor": get_binary_elementwise_inputs(), @@ -208,4 +235,5 @@ def get_permute_inputs(): "aten.select_copy.int": get_select_int_inputs(), "aten.permute.default": get_permute_inputs(), "aten.permute_copy.default": get_permute_inputs(), + "aten.view_copy.default": get_view_inputs(), } diff --git a/backends/vulkan/test/op_tests/utils/codegen_base.py b/backends/vulkan/test/op_tests/utils/codegen_base.py index f3b13335932..a0f29ce53cd 100644 --- a/backends/vulkan/test/op_tests/utils/codegen_base.py +++ b/backends/vulkan/test/op_tests/utils/codegen_base.py @@ -105,10 +105,14 @@ def gen_case_name(self, inputs: List[Any], prepack: bool = False) -> str: for size in arg_sizes_or_val: name_str += str(size) + "x" name_str = name_str[:-1] + # minus sign is a invalid char for test case. change to "n". + name_str = name_str.replace('-', 'n') elif isinstance(arg_sizes_or_val, list): for size in arg_sizes_or_val: name_str += str(size) + "c" name_str = name_str[:-1] + # minus sign is a invalid char for test case. change to "n". + name_str = name_str.replace('-', 'n') else: name_str += str(arg_sizes_or_val).replace(".", "p") return name_str @@ -234,7 +238,7 @@ def generate_suite_cpp(self) -> str: // from_blob doesn't take ownership of data. Hence must create a copy as // "values" will go out of scope. - return at::from_blob(values.data(), sizes, dtype).detach().clone(); + return at::from_blob(values.data(), sizes, at::kFloat).toType(dtype).detach().clone(); }} {test_suites_cpp} From b4ed969dff04718310f9f415ac0aea5960b657a5 Mon Sep 17 00:00:00 2001 From: Justin Yip Date: Thu, 18 Apr 2024 15:07:18 -0700 Subject: [PATCH 2/2] Update on "[ET-VK][6/n] aten.view_copy" aten.view_copy, supporting all packing. Using ssjia's idea to do a direct lookup. Differential Revision: [D56281400](https://our.internmc.facebook.com/intern/diff/D56281400/) [ghstack-poisoned] --- .../runtime/graph/ops/glsl/indexing_utils.h | 18 +++++----- .../vulkan/runtime/graph/ops/glsl/view.glsl | 32 ++++++++--------- .../vulkan/runtime/graph/ops/glsl/view.yaml | 1 - .../vulkan/runtime/graph/ops/impl/View.cpp | 34 +++++++------------ backends/vulkan/test/op_tests/cases.py | 6 ++-- .../test/op_tests/utils/codegen_base.py | 6 ++-- 6 files changed, 42 insertions(+), 55 deletions(-) diff --git a/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h b/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h index 24cd9bef9fe..7231003c51b 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h +++ b/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h @@ -9,18 +9,18 @@ #define divup4(x) ((x + 3) / 4) // Input: idx is a ivec4 user-level coordinate, sizes is the tensor shape -// Output: buffer_idx in the continous nchw-buffer. -#define to_buffer_i(idx, sizes) \ - (idx.x + idx.y* sizes.x + idx.z* sizes.y* sizes.x + \ - idx.w* sizes.z* sizes.y* sizes.x) +// Output: buffer_idx in the continuous nchw-buffer. +#define to_buffer_i(idx, sizes) \ + (idx.x + idx.y * sizes.x + idx.z * sizes.y * sizes.x + \ + idx.w * sizes.z * sizes.y * sizes.x) // Inverse of to_buffer_i -// Input: buffer_idx in the continous nchw-buffer, sizes is the tensor shape +// Input: buffer_idx in the continuous nchw-buffer, sizes is the tensor shape // Output: ivec4 user-level coorindate -#define from_buffer_i(buf_i, sizes) \ - ivec4( \ - buf_i % sizes.x, \ - (buf_i / (sizes.x)) % sizes.y, \ +#define from_buffer_i(buf_i, sizes) \ + ivec4( \ + buf_i % sizes.x, \ + (buf_i / (sizes.x)) % sizes.y, \ (buf_i / (sizes.x * sizes.y)) % sizes.z, \ (buf_i / (sizes.x * sizes.y * sizes.z))) diff --git a/backends/vulkan/runtime/graph/ops/glsl/view.glsl b/backends/vulkan/runtime/graph/ops/glsl/view.glsl index cc1aa46ea10..f7664ce5127 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/view.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/view.glsl @@ -26,33 +26,29 @@ layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in; #define get_packed_stride get_packed_stride_${PACKING} layout(set = 0, binding = 2) uniform PRECISION restrict OutGpuSizes { - uvec4 data; -} -out_gpu_sizes; + uvec4 out_gpu_sizes; +}; layout(set = 0, binding = 3) uniform PRECISION restrict OutCpuSizes { - uvec4 data; -} -out_cpu_sizes; + uvec4 out_cpu_sizes; +}; layout(set = 0, binding = 4) uniform PRECISION restrict InGpuSizes { - uvec4 data; -} -in_gpu_sizes; + uvec4 in_gpu_sizes; +}; layout(set = 0, binding = 5) uniform PRECISION restrict InCpuSizes { - uvec4 data; -} -in_cpu_sizes; + uvec4 in_cpu_sizes; +}; layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; void main() { const ivec3 out_pos = ivec3(gl_GlobalInvocationID); - const ivec4 out_tensor_idx = to_tensor_idx(out_pos, out_gpu_sizes.data); + const ivec4 out_tensor_idx = to_tensor_idx(out_pos, out_gpu_sizes); - if (all(greaterThanEqual(out_tensor_idx, out_gpu_sizes.data))) { + if (all(greaterThanEqual(out_tensor_idx, out_gpu_sizes))) { return; } @@ -60,16 +56,16 @@ void main() { // pos, we first calculate the index in the virual buffer, and then calculate // the input position from the indx. - const uint base_index = to_buffer_i(out_tensor_idx, out_cpu_sizes.data); + const uint base_index = to_buffer_i(out_tensor_idx, out_cpu_sizes); const uvec4 buf_indices = - base_index + ivec4(0, 1, 2, 3) * get_packed_stride(out_cpu_sizes.data); + base_index + ivec4(0, 1, 2, 3) * get_packed_stride(out_cpu_sizes); VEC4_T value; // Need to look up the 4 values in the output texel separately. for (int i=0; i<4; i++) { - ivec4 user_coor = from_buffer_i(buf_indices[i], in_cpu_sizes.data); + ivec4 user_coor = from_buffer_i(buf_indices[i], in_cpu_sizes); - ivec4 in_pos_elem = to_texture_pos_elem(user_coor, in_gpu_sizes.data); + ivec4 in_pos_elem = to_texture_pos_elem(user_coor, in_gpu_sizes); VEC4_T intex = VEC4_T(texelFetch(image_in, in_pos_elem.xyz, 0)); diff --git a/backends/vulkan/runtime/graph/ops/glsl/view.yaml b/backends/vulkan/runtime/graph/ops/glsl/view.yaml index 8c91f7d76f0..7d337028c9e 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/view.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/view.yaml @@ -12,4 +12,3 @@ view: - VALUE: H_packed shader_variants: - NAME: view - diff --git a/backends/vulkan/runtime/graph/ops/impl/View.cpp b/backends/vulkan/runtime/graph/ops/impl/View.cpp index 2b8d5984878..8b5175038b0 100644 --- a/backends/vulkan/runtime/graph/ops/impl/View.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/View.cpp @@ -8,30 +8,22 @@ #include -#include - #include #include #include namespace vkcompute { -void add_view_node( - ComputeGraph& graph, - ValueRef in, - ValueRef size_ref, - ValueRef out) { - // Note: size_ref is not used here. Since the output tensor's size have been - // determined during compilation. +void add_view_node(ComputeGraph& graph, ValueRef in, ValueRef out) { vTensorPtr t_in = graph.get_tensor(in); vTensorPtr t_out = graph.get_tensor(out); - + std::string kernel_name = "view"; kernel_name.reserve(kShaderNameReserve); add_dtype_suffix(kernel_name, *t_out); add_memory_layout_suffix(kernel_name, *t_out); - - api::utils::uvec3 global_size = t_out->virtual_extents(); + + api::utils::uvec3 global_size = t_out->extents(); api::utils::uvec3 local_size = adaptive_work_group_size(global_size); graph.execute_nodes().emplace_back(new ExecuteNode( @@ -40,22 +32,20 @@ void add_view_node( global_size, local_size, {{out, api::MemoryAccessType::WRITE}, {in, api::MemoryAccessType::READ}}, - { - t_out->gpu_sizes_ubo(), - t_out->cpu_sizes_ubo(), - t_in->gpu_sizes_ubo(), - t_in->cpu_sizes_ubo()})); + {t_out->gpu_sizes_ubo(), + t_out->cpu_sizes_ubo(), + t_in->gpu_sizes_ubo(), + t_in->cpu_sizes_ubo()})); } - void view(ComputeGraph& graph, const std::vector& args) { - return add_view_node(graph, args[0], args[1], args[2]); + // Note: The second argument size_ref is not used here. Since the output + // tensor's size have been determined during compilation. + return add_view_node(graph, args[0], args[2]); } REGISTER_OPERATORS { VK_REGISTER_OP(aten.view_copy.default, view); } - -} // namespace vkcompute - +} // namespace vkcompute diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py index 31cbed07989..d5d0c5a6e56 100644 --- a/backends/vulkan/test/op_tests/cases.py +++ b/backends/vulkan/test/op_tests/cases.py @@ -197,9 +197,9 @@ def get_permute_inputs(): def get_view_inputs(): test_suite = VkTestSuite( [ - ((3,4,5), [1, 1, -1]), - ((3,4,5), [1, -1, 1]), - ((3,4,5), [-1, 1, 1]), + ((3, 4, 5), [1, 1, -1]), + ((3, 4, 5), [1, -1, 1]), + ((3, 4, 5), [-1, 1, 1]), ((8, 7, 2, 3), [4, 3, 7, 4]), ((8, 7, 2, 3), [7, -1, 2, 1]), ((8, 7, 2, 3), [1, 1, 1, -1]), diff --git a/backends/vulkan/test/op_tests/utils/codegen_base.py b/backends/vulkan/test/op_tests/utils/codegen_base.py index a0f29ce53cd..ff3509db5ca 100644 --- a/backends/vulkan/test/op_tests/utils/codegen_base.py +++ b/backends/vulkan/test/op_tests/utils/codegen_base.py @@ -106,13 +106,15 @@ def gen_case_name(self, inputs: List[Any], prepack: bool = False) -> str: name_str += str(size) + "x" name_str = name_str[:-1] # minus sign is a invalid char for test case. change to "n". - name_str = name_str.replace('-', 'n') + name_str = name_str.replace("-", "n") + elif isinstance(arg_sizes_or_val, list): for size in arg_sizes_or_val: name_str += str(size) + "c" name_str = name_str[:-1] # minus sign is a invalid char for test case. change to "n". - name_str = name_str.replace('-', 'n') + name_str = name_str.replace("-", "n") + else: name_str += str(arg_sizes_or_val).replace(".", "p") return name_str