From e0db27e6f5a869dcb503cf2dacb445a9216c182a Mon Sep 17 00:00:00 2001 From: Justin Yip Date: Mon, 15 Apr 2024 20:01:17 -0700 Subject: [PATCH] aten.select.int (#3033) Summary: Port over the `select.int` shaders to ET. 1. Since in ET, tensor-shape reasoning happens in AOT, therefore we can simplify the c++ caller code by a lot. 2. In this diff, we also try to use the same buffer object for passing arguments to all shaders. Not worry about perf cost, since cost difference between passing int and ivec4 is very minor. Reviewed By: SS-JIA Differential Revision: D56082483 --- backends/vulkan/runtime/api/Tensor.h | 8 ++ backends/vulkan/runtime/graph/Logging.h | 6 + .../graph/ops/glsl/select_batch_4d.glsl | 53 +++++++ .../graph/ops/glsl/select_batch_4d.yaml | 10 ++ .../graph/ops/glsl/select_channel_3d.glsl | 50 +++++++ .../graph/ops/glsl/select_channel_3d.yaml | 10 ++ .../graph/ops/glsl/select_channel_4d.glsl | 64 +++++++++ .../graph/ops/glsl/select_channel_4d.yaml | 10 ++ .../graph/ops/glsl/select_height_3d.glsl | 62 ++++++++ .../graph/ops/glsl/select_height_3d.yaml | 10 ++ .../graph/ops/glsl/select_height_4d.glsl | 62 ++++++++ .../graph/ops/glsl/select_height_4d.yaml | 10 ++ .../graph/ops/glsl/select_width_3d.glsl | 61 ++++++++ .../graph/ops/glsl/select_width_3d.yaml | 10 ++ .../graph/ops/glsl/select_width_4d.glsl | 65 +++++++++ .../graph/ops/glsl/select_width_4d.yaml | 10 ++ .../vulkan/runtime/graph/ops/impl/Select.cpp | 132 ++++++++++++++++++ backends/vulkan/test/op_tests/cases.py | 24 ++++ 18 files changed, 657 insertions(+) create mode 100644 backends/vulkan/runtime/graph/ops/glsl/select_batch_4d.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/select_batch_4d.yaml create mode 100644 backends/vulkan/runtime/graph/ops/glsl/select_channel_3d.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/select_channel_3d.yaml create mode 100644 backends/vulkan/runtime/graph/ops/glsl/select_channel_4d.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/select_channel_4d.yaml create mode 100644 backends/vulkan/runtime/graph/ops/glsl/select_height_3d.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/select_height_3d.yaml create mode 100644 backends/vulkan/runtime/graph/ops/glsl/select_height_4d.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/select_height_4d.yaml create mode 100644 backends/vulkan/runtime/graph/ops/glsl/select_width_3d.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/select_width_3d.yaml create mode 100644 backends/vulkan/runtime/graph/ops/glsl/select_width_4d.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/select_width_4d.yaml create mode 100644 backends/vulkan/runtime/graph/ops/impl/Select.cpp diff --git a/backends/vulkan/runtime/api/Tensor.h b/backends/vulkan/runtime/api/Tensor.h index a6bd8b0531c..e1e28bf582c 100644 --- a/backends/vulkan/runtime/api/Tensor.h +++ b/backends/vulkan/runtime/api/Tensor.h @@ -255,6 +255,14 @@ class vTensor final { return sizes_; } + inline const int64_t size(size_t dim) const { + return sizes().at(dim); + } + + inline const int64_t dim() const { + return sizes_.size(); + } + inline const std::vector& strides() const { return strides_; } diff --git a/backends/vulkan/runtime/graph/Logging.h b/backends/vulkan/runtime/graph/Logging.h index fd455321952..f2684081332 100644 --- a/backends/vulkan/runtime/graph/Logging.h +++ b/backends/vulkan/runtime/graph/Logging.h @@ -8,6 +8,8 @@ #pragma once +#include + #include #include @@ -23,4 +25,8 @@ inline std::ostream& operator<<(std::ostream& os, const std::vector& vec) { return os; // Return the ostream to allow chaining } +inline std::ostream& operator<<(std::ostream& os, const api::utils::uvec3& v) { + return api::utils::operator<<(os, v); +} + } // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/glsl/select_batch_4d.glsl b/backends/vulkan/runtime/graph/ops/glsl/select_batch_4d.glsl new file mode 100644 index 00000000000..39f3681ceec --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/select_batch_4d.glsl @@ -0,0 +1,53 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +layout(std430) buffer; + +#include "indexing_utils.h" + +layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out; +layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in; + +layout(set = 0, binding = 2) uniform PRECISION restrict OutSizes { + uvec4 data; +} +out_sizes; + +layout(set = 0, binding = 3) uniform PRECISION restrict SelectVal { + // data.x: index along batch dim to select + // data.y: number of batches + // data.z: number of texels per batch + // data.w: unused + ivec4 data; +} +select_info; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +void main() { + const int num_batches = select_info.data.y; + const int num_texel_per_batch = select_info.data.z; + const int index = select_info.data.x; + + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + const ivec4 idx = to_tensor_idx_C_packed(pos, out_sizes.data); + + if (any(greaterThanEqual(idx, out_sizes.data))) { + return; + } + + const uint src_pos_z = (num_texel_per_batch * index) + pos.z; + imageStore( + image_out, pos, texelFetch(image_in, ivec3(pos.x, pos.y, src_pos_z), 0)); +} + diff --git a/backends/vulkan/runtime/graph/ops/glsl/select_batch_4d.yaml b/backends/vulkan/runtime/graph/ops/glsl/select_batch_4d.yaml new file mode 100644 index 00000000000..9c7d54c8f69 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/select_batch_4d.yaml @@ -0,0 +1,10 @@ +select_batch_4d: + parameter_names_with_default_values: + DTYPE: float + NDIM: 3 + generate_variant_forall: + DTYPE: + - VALUE: half + - VALUE: float + shader_variants: + - NAME: select_batch_4d diff --git a/backends/vulkan/runtime/graph/ops/glsl/select_channel_3d.glsl b/backends/vulkan/runtime/graph/ops/glsl/select_channel_3d.glsl new file mode 100644 index 00000000000..dab728ef346 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/select_channel_3d.glsl @@ -0,0 +1,50 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#define VEC4_T ${texel_type(DTYPE)} +#define T ${texel_component_type(DTYPE)} + +layout(std430) buffer; + +#include "indexing_utils.h" + +layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out; +layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in; + +layout(set = 0, binding = 2) uniform PRECISION restrict OutSizes { + uvec4 data; +} +out_sizes; + +// index to select +layout(set = 0, binding = 3) uniform PRECISION restrict IndexVal { + int data; +} +index; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + const ivec4 idx = to_tensor_idx_C_packed(pos, out_sizes.data); + + if (any(greaterThanEqual(idx, out_sizes.data))) { + return; + } + + const int tex = index.data / 4; + const int ind = index.data % 4; + const T v = VEC4_T(texelFetch(image_in, ivec3(pos.x, pos.y, tex), 0))[ind]; + + imageStore(image_out, ivec3(pos.x, pos.y, 0), VEC4_T(v, 0, 0, 0)); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/select_channel_3d.yaml b/backends/vulkan/runtime/graph/ops/glsl/select_channel_3d.yaml new file mode 100644 index 00000000000..1c5c4e34b06 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/select_channel_3d.yaml @@ -0,0 +1,10 @@ +select_channel_3d: + parameter_names_with_default_values: + DTYPE: float + NDIM: 3 + generate_variant_forall: + DTYPE: + - VALUE: half + - VALUE: float + shader_variants: + - NAME: select_channel_3d diff --git a/backends/vulkan/runtime/graph/ops/glsl/select_channel_4d.glsl b/backends/vulkan/runtime/graph/ops/glsl/select_channel_4d.glsl new file mode 100644 index 00000000000..6979e7fed21 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/select_channel_4d.glsl @@ -0,0 +1,64 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} +#define VEC4_T ${texel_type(DTYPE)} + +layout(std430) buffer; + +#include "indexing_utils.h" + +layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out; +layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in; + +layout(set = 0, binding = 2) uniform PRECISION restrict OutSizes { + uvec4 data; +} +out_sizes; + +layout(set = 0, binding = 3) uniform PRECISION restrict SelectVal { + // data.x: index along channel dim to select + // data.y: number of batches + // data.z: number of texels per batch + // data.w: unused + ivec4 data; +} +select_info; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + const ivec4 idx = to_tensor_idx_C_packed(pos, out_sizes.data); + + if (any(greaterThanEqual(idx, out_sizes.data))) { + return; + } + + const int num_batches = select_info.data.y; + const int num_texel_per_batch = select_info.data.z; + const int index = select_info.data.x; + + // read in the same channel from 4 separate batches + VEC4_T out_texel = VEC4_T(0, 0, 0, 0); + for (int k = 0; k < 4; k++) { + if ((k + pos.z * 4) >= + num_batches) { + break; + } + const uint src_pos_z = (4 * num_texel_per_batch * pos.z) + + (k * num_texel_per_batch) + (index / 4); + const uint src_pos_t = index % 4; + out_texel[k] = + VEC4_T(texelFetch(image_in, ivec3(pos.x, pos.y, src_pos_z), 0))[src_pos_t]; + } + + imageStore(image_out, pos, out_texel); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/select_channel_4d.yaml b/backends/vulkan/runtime/graph/ops/glsl/select_channel_4d.yaml new file mode 100644 index 00000000000..6236555f5dd --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/select_channel_4d.yaml @@ -0,0 +1,10 @@ +select_channel_4d: + parameter_names_with_default_values: + DTYPE: float + NDIM: 3 + generate_variant_forall: + DTYPE: + - VALUE: half + - VALUE: float + shader_variants: + - NAME: select_channel_4d diff --git a/backends/vulkan/runtime/graph/ops/glsl/select_height_3d.glsl b/backends/vulkan/runtime/graph/ops/glsl/select_height_3d.glsl new file mode 100644 index 00000000000..3ca92d3dcd4 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/select_height_3d.glsl @@ -0,0 +1,62 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} +#define VEC4_T ${texel_type(DTYPE)} + +layout(std430) buffer; + +#include "indexing_utils.h" + +layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out; +layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in; + +layout(set = 0, binding = 2) uniform PRECISION restrict OutSizes { + uvec4 data; +} +out_sizes; + +// index to select +layout(set = 0, binding = 3) uniform PRECISION restrict IndexVal { + int data; +} +index; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + const ivec4 idx = to_tensor_idx_C_packed(pos, out_sizes.data); + + if (any(greaterThanEqual(idx, out_sizes.data))) { + return; + } + + // w + const int src_x = pos.x; + // h + const int src_y = index.data; + // c + const int src_z = pos.y; + + const VEC4_T v = VEC4_T(texelFetch(image_in, ivec3(src_x, src_y, src_z), 0)); + + for (int i = 0; i < 4; i++) { + ivec3 new_pos = ivec3(pos.x, pos.y * 4 + i, 0); + + // When the C-channel exceeds original block size, exit early + if (new_pos.y >= out_sizes.data.y) { + return; + } + + imageStore(image_out, new_pos, VEC4_T(v[i], 0, 0, 0)); + } +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/select_height_3d.yaml b/backends/vulkan/runtime/graph/ops/glsl/select_height_3d.yaml new file mode 100644 index 00000000000..a373f1decd9 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/select_height_3d.yaml @@ -0,0 +1,10 @@ +select_height_3d: + parameter_names_with_default_values: + DTYPE: float + NDIM: 3 + generate_variant_forall: + DTYPE: + - VALUE: half + - VALUE: float + shader_variants: + - NAME: select_height_3d diff --git a/backends/vulkan/runtime/graph/ops/glsl/select_height_4d.glsl b/backends/vulkan/runtime/graph/ops/glsl/select_height_4d.glsl new file mode 100644 index 00000000000..1381c3c5fc4 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/select_height_4d.glsl @@ -0,0 +1,62 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} +#define VEC4_T ${texel_type(DTYPE)} + +layout(std430) buffer; + +#include "indexing_utils.h" + +layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out; +layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in; + +layout(set = 0, binding = 2) uniform PRECISION restrict OutSizes { + uvec4 data; +} +out_sizes; + +// index to select +layout(set = 0, binding = 3) uniform PRECISION restrict IndexVal { + // data.x: index along height dim to select + // data.y: number of batches + // data.z: number of texels per batch + // data.w: unused + ivec4 data; +} +select_info; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + const ivec4 idx = to_tensor_idx_C_packed(pos, out_sizes.data); + if (any(greaterThanEqual(idx, out_sizes.data))) { + return; + } + + const int num_batches = select_info.data.y; + const int num_texel_per_batch = select_info.data.z; + const int index = select_info.data.x; + + VEC4_T out_texel = VEC4_T(0, 0, 0, 0); + // read in the same channel from 4 separate batches + for (int k = 0; k < 4; k++) { + if ((k + pos.z * 4) >= num_batches + ) { // < 4 batches for this texel, exit early + break; + } + const uint src_pos_z = (pos.z * num_texel_per_batch * 4) + + k * num_texel_per_batch + (pos.y / 4); + out_texel[k] = VEC4_T(texelFetch( + image_in, ivec3(pos.x, index, src_pos_z), 0))[pos.y % 4]; + } + imageStore(image_out, pos, out_texel); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/select_height_4d.yaml b/backends/vulkan/runtime/graph/ops/glsl/select_height_4d.yaml new file mode 100644 index 00000000000..c3724f1157a --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/select_height_4d.yaml @@ -0,0 +1,10 @@ +select_height_4d: + parameter_names_with_default_values: + DTYPE: float + NDIM: 3 + generate_variant_forall: + DTYPE: + - VALUE: half + - VALUE: float + shader_variants: + - NAME: select_height_4d diff --git a/backends/vulkan/runtime/graph/ops/glsl/select_width_3d.glsl b/backends/vulkan/runtime/graph/ops/glsl/select_width_3d.glsl new file mode 100644 index 00000000000..6f1ffcfe826 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/select_width_3d.glsl @@ -0,0 +1,61 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#define VEC4_T ${texel_type(DTYPE)} + +layout(std430) buffer; + +#include "indexing_utils.h" + +layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out; +layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in; + +layout(set = 0, binding = 2) uniform PRECISION restrict OutSizes { + uvec4 data; +} +out_sizes; + +// index to select +layout(set = 0, binding = 3) uniform PRECISION restrict IndexVal { + int data; +} +index; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + const ivec4 idx = to_tensor_idx_C_packed(pos, out_sizes.data); + if (any(greaterThanEqual(idx, out_sizes.data))) { + return; + } + + // w + const int src_x = index.data; + // h + const int src_y = pos.x; + // c + const int src_z = pos.y; + + const VEC4_T v = VEC4_T(texelFetch(image_in, ivec3(src_x, src_y, src_z), 0)); + + for (int i = 0; i < 4; i++) { + ivec3 new_pos = ivec3(pos.x, pos.y * 4 + i, 0); + + // When the C-channel exceeds original block size, exit early + if (new_pos.y >= out_sizes.data.y) { + return; + } + + imageStore(image_out, new_pos, VEC4_T(v[i], 0, 0, 0)); + } +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/select_width_3d.yaml b/backends/vulkan/runtime/graph/ops/glsl/select_width_3d.yaml new file mode 100644 index 00000000000..a3070bf6ca3 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/select_width_3d.yaml @@ -0,0 +1,10 @@ +select_width_3d: + parameter_names_with_default_values: + DTYPE: float + NDIM: 3 + generate_variant_forall: + DTYPE: + - VALUE: half + - VALUE: float + shader_variants: + - NAME: select_width_3d diff --git a/backends/vulkan/runtime/graph/ops/glsl/select_width_4d.glsl b/backends/vulkan/runtime/graph/ops/glsl/select_width_4d.glsl new file mode 100644 index 00000000000..6f9b3771823 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/select_width_4d.glsl @@ -0,0 +1,65 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#define VEC4_T ${texel_type(DTYPE)} + +layout(std430) buffer; + +#include "indexing_utils.h" + +layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out; +layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in; + +layout(set = 0, binding = 2) uniform PRECISION restrict OutSizes { + uvec4 data; +} +out_sizes; + +// index to select +layout(set = 0, binding = 3) uniform PRECISION restrict SelectVal { + // data.x: index along width dim to select + // data.y: number of batches + // data.z: number of texels per batch + // data.w: unused + ivec4 data; +} +select_info; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + const ivec4 idx = to_tensor_idx_C_packed(pos, out_sizes.data); + if (any(greaterThanEqual(idx, out_sizes.data))) { + return; + } + + const int num_batches = select_info.data.y; + const int num_texel_per_batch = select_info.data.z; + const int index = select_info.data.x; + + //vec4 out_texel = vec4(0, 0, 0, 0); + VEC4_T out_texel = VEC4_T(0, 0, 0, 0); + // read in the same channel from 4 separate batches + for (int k = 0; k < 4; k++) { + if ((k + pos.z * 4) >= + num_batches) { // < 4 batches for this texel, exit early + break; + } + const uint src_pos_z = (pos.z * num_texel_per_batch * 4) + + k * num_texel_per_batch + (pos.y / 4); + + out_texel[k] = VEC4_T(texelFetch( + image_in, ivec3(index, pos.x, src_pos_z), 0))[pos.y % 4]; + } + imageStore(image_out, pos, out_texel); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/select_width_4d.yaml b/backends/vulkan/runtime/graph/ops/glsl/select_width_4d.yaml new file mode 100644 index 00000000000..f1131d77395 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/select_width_4d.yaml @@ -0,0 +1,10 @@ +select_width_4d: + parameter_names_with_default_values: + DTYPE: float + NDIM: 3 + generate_variant_forall: + DTYPE: + - VALUE: half + - VALUE: float + shader_variants: + - NAME: select_width_4d diff --git a/backends/vulkan/runtime/graph/ops/impl/Select.cpp b/backends/vulkan/runtime/graph/ops/impl/Select.cpp new file mode 100644 index 00000000000..c66d98af4a2 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/Select.cpp @@ -0,0 +1,132 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include + +#include +#include +#include + +namespace vkcompute { + +void check_args( + const vTensor& t_in, + int64_t dim, + int64_t index, + const vTensor& t_out) { + VK_CHECK_COND(check_memory_layout_is(t_in, api::kChannelsPacked)); + VK_CHECK_COND(check_memory_layout_is(t_out, api::kChannelsPacked)); + + const int64_t in_dim = t_in.dim(); + VK_CHECK_COND( + in_dim == 3 || in_dim == 4, + "Vulkan select only support 3d or 4d tensors!"); + + const int64_t in_size = t_in.size(dim); + + if (index < -in_size || index >= in_size) { + VK_CHECK_COND( + false, + "select(): index ", + index, + " t_outof range for tensor of size ", + in_size, + " at dimension ", + dim); + } +} + +void add_select_int_node( + ComputeGraph& graph, + const ValueRef in, + const ValueRef dim_ref, + const ValueRef index_ref, + const ValueRef out) { + vTensorPtr t_in = graph.get_tensor(in); + vTensorPtr t_out = graph.get_tensor(out); + int64_t dim = graph.extract_scalar(dim_ref); + int64_t index = graph.extract_scalar(index_ref); + + check_args(*t_in, dim, index, *t_out); + + const int64_t in_size = t_in->size(dim); + + if (index < 0) { + index += in_size; + } + + std::string kernel_name; + + // for 3d tensors, these values are not used by the shader. + int32_t num_texel_per_batch = 1; + int32_t num_batches = 1; + + int64_t in_dim = t_in->dim(); + if (in_dim == 3) { + if (dim == 0) { + kernel_name = "select_channel_3d"; + } else if (dim == 1) { + kernel_name = "select_height_3d"; + } else if (dim == 2) { + kernel_name = "select_width_3d"; + } else { + VK_CHECK_COND( + false, "Unexpected dim value=", dim, "for the input 3d tensor"); + } + } else { // self.dim() == 4 + num_texel_per_batch = + static_cast(std::ceil(static_cast(t_in->size(1)) / 4)); + num_batches = t_in->size(0); + if (dim == 0) { + kernel_name = "select_batch_4d"; + } else if (dim == 1) { + kernel_name = "select_channel_4d"; + } else if (dim == 2) { + kernel_name = "select_height_4d"; + } else if (dim == 3) { + kernel_name = "select_width_4d"; + } else { + VK_CHECK_COND( + false, "Unexpected dim value=", dim, "for the input 4d tensor"); + } + } + + kernel_name.reserve(kShaderNameReserve); + add_dtype_suffix(kernel_name, *t_out); + + api::utils::uvec3 global_size = t_out->virtual_extents(); + api::utils::uvec3 local_size = adaptive_work_group_size(global_size); + + // TODO: add resizing to support dynamic shapes. + graph.execute_nodes().emplace_back(new ExecuteNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + global_size, + local_size, + {{out, api::MemoryAccessType::WRITE}, {in, api::MemoryAccessType::READ}}, + {t_out->gpu_sizes_ubo(), + // TODO: num_batches and num_texel_per_batch are provided by + // t_out->gpu_sizes. Can change the following to reduce params + // created. + + graph.create_params_buffer(api::utils::make_ivec4( + {index, num_batches, num_texel_per_batch, 0}))})); +} + +void select_int(ComputeGraph& graph, const std::vector& args) { + return add_select_int_node(graph, args[0], args[1], args[2], args[3]); +} + +REGISTER_OPERATORS { + VK_REGISTER_OP(aten.select.int, select_int); +} + +} // namespace vkcompute diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py index 56baa60a9f6..31961f3e449 100644 --- a/backends/vulkan/test/op_tests/cases.py +++ b/backends/vulkan/test/op_tests/cases.py @@ -142,6 +142,29 @@ def get_full_inputs(): return test_suite +def get_select_int_inputs(): + test_suite = VkTestSuite( + [ + ((6, 2, 7), 0, 3), + ((6, 2, 7), 1, 0), + ((6, 2, 7), 2, 3), + ((6, 10, 7), 0, 3), + ((6, 10, 7), 1, 0), + ((6, 10, 7), 1, 9), + ((6, 10, 7), 2, 6), + ((9, 2, 9, 4), 0, 8), + ((9, 2, 9, 4), 1, 1), + ((9, 2, 9, 4), 2, 0), + ((9, 2, 9, 4), 2, 8), + ((9, 2, 9, 4), 3, 3), + ((8, 6, 1, 1), 0, 4), + ((8, 6, 1, 1), 1, 4), + ] + ) + test_suite.supports["layouts"] = ["api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED"] + return test_suite + + test_suites = { "aten.add.Tensor": get_binary_elementwise_inputs(), "aten.sub.Tensor": get_binary_elementwise_inputs(), @@ -152,6 +175,7 @@ def get_full_inputs(): "aten.convolution.default": get_conv2d_inputs(), "aten.native_layer_norm.default": get_native_layer_norm_inputs(), "aten.full.default": get_full_inputs(), + "aten.select.int": get_select_int_inputs(), } prepacked_args = {"aten.mm.default": {"mat2"}}