diff --git a/backends/vulkan/runtime/api/Tensor.h b/backends/vulkan/runtime/api/Tensor.h index a6bd8b0531c..e1e28bf582c 100644 --- a/backends/vulkan/runtime/api/Tensor.h +++ b/backends/vulkan/runtime/api/Tensor.h @@ -255,6 +255,14 @@ class vTensor final { return sizes_; } + inline const int64_t size(size_t dim) const { + return sizes().at(dim); + } + + inline const int64_t dim() const { + return sizes_.size(); + } + inline const std::vector& strides() const { return strides_; } diff --git a/backends/vulkan/runtime/graph/Logging.h b/backends/vulkan/runtime/graph/Logging.h index fd455321952..f2684081332 100644 --- a/backends/vulkan/runtime/graph/Logging.h +++ b/backends/vulkan/runtime/graph/Logging.h @@ -8,6 +8,8 @@ #pragma once +#include + #include #include @@ -23,4 +25,8 @@ inline std::ostream& operator<<(std::ostream& os, const std::vector& vec) { return os; // Return the ostream to allow chaining } +inline std::ostream& operator<<(std::ostream& os, const api::utils::uvec3& v) { + return api::utils::operator<<(os, v); +} + } // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/glsl/select_batch_4d.glsl b/backends/vulkan/runtime/graph/ops/glsl/select_batch_4d.glsl new file mode 100644 index 00000000000..39f3681ceec --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/select_batch_4d.glsl @@ -0,0 +1,53 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +layout(std430) buffer; + +#include "indexing_utils.h" + +layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out; +layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in; + +layout(set = 0, binding = 2) uniform PRECISION restrict OutSizes { + uvec4 data; +} +out_sizes; + +layout(set = 0, binding = 3) uniform PRECISION restrict SelectVal { + // data.x: index along batch dim to select + // data.y: number of batches + // data.z: number of texels per batch + // data.w: unused + ivec4 data; +} +select_info; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +void main() { + const int num_batches = select_info.data.y; + const int num_texel_per_batch = select_info.data.z; + const int index = select_info.data.x; + + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + const ivec4 idx = to_tensor_idx_C_packed(pos, out_sizes.data); + + if (any(greaterThanEqual(idx, out_sizes.data))) { + return; + } + + const uint src_pos_z = (num_texel_per_batch * index) + pos.z; + imageStore( + image_out, pos, texelFetch(image_in, ivec3(pos.x, pos.y, src_pos_z), 0)); +} + diff --git a/backends/vulkan/runtime/graph/ops/glsl/select_batch_4d.yaml b/backends/vulkan/runtime/graph/ops/glsl/select_batch_4d.yaml new file mode 100644 index 00000000000..9c7d54c8f69 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/select_batch_4d.yaml @@ -0,0 +1,10 @@ +select_batch_4d: + parameter_names_with_default_values: + DTYPE: float + NDIM: 3 + generate_variant_forall: + DTYPE: + - VALUE: half + - VALUE: float + shader_variants: + - NAME: select_batch_4d diff --git a/backends/vulkan/runtime/graph/ops/glsl/select_channel_3d.glsl b/backends/vulkan/runtime/graph/ops/glsl/select_channel_3d.glsl new file mode 100644 index 00000000000..dab728ef346 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/select_channel_3d.glsl @@ -0,0 +1,50 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#define VEC4_T ${texel_type(DTYPE)} +#define T ${texel_component_type(DTYPE)} + +layout(std430) buffer; + +#include "indexing_utils.h" + +layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out; +layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in; + +layout(set = 0, binding = 2) uniform PRECISION restrict OutSizes { + uvec4 data; +} +out_sizes; + +// index to select +layout(set = 0, binding = 3) uniform PRECISION restrict IndexVal { + int data; +} +index; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + const ivec4 idx = to_tensor_idx_C_packed(pos, out_sizes.data); + + if (any(greaterThanEqual(idx, out_sizes.data))) { + return; + } + + const int tex = index.data / 4; + const int ind = index.data % 4; + const T v = VEC4_T(texelFetch(image_in, ivec3(pos.x, pos.y, tex), 0))[ind]; + + imageStore(image_out, ivec3(pos.x, pos.y, 0), VEC4_T(v, 0, 0, 0)); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/select_channel_3d.yaml b/backends/vulkan/runtime/graph/ops/glsl/select_channel_3d.yaml new file mode 100644 index 00000000000..1c5c4e34b06 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/select_channel_3d.yaml @@ -0,0 +1,10 @@ +select_channel_3d: + parameter_names_with_default_values: + DTYPE: float + NDIM: 3 + generate_variant_forall: + DTYPE: + - VALUE: half + - VALUE: float + shader_variants: + - NAME: select_channel_3d diff --git a/backends/vulkan/runtime/graph/ops/glsl/select_channel_4d.glsl b/backends/vulkan/runtime/graph/ops/glsl/select_channel_4d.glsl new file mode 100644 index 00000000000..6979e7fed21 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/select_channel_4d.glsl @@ -0,0 +1,64 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} +#define VEC4_T ${texel_type(DTYPE)} + +layout(std430) buffer; + +#include "indexing_utils.h" + +layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out; +layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in; + +layout(set = 0, binding = 2) uniform PRECISION restrict OutSizes { + uvec4 data; +} +out_sizes; + +layout(set = 0, binding = 3) uniform PRECISION restrict SelectVal { + // data.x: index along channel dim to select + // data.y: number of batches + // data.z: number of texels per batch + // data.w: unused + ivec4 data; +} +select_info; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + const ivec4 idx = to_tensor_idx_C_packed(pos, out_sizes.data); + + if (any(greaterThanEqual(idx, out_sizes.data))) { + return; + } + + const int num_batches = select_info.data.y; + const int num_texel_per_batch = select_info.data.z; + const int index = select_info.data.x; + + // read in the same channel from 4 separate batches + VEC4_T out_texel = VEC4_T(0, 0, 0, 0); + for (int k = 0; k < 4; k++) { + if ((k + pos.z * 4) >= + num_batches) { + break; + } + const uint src_pos_z = (4 * num_texel_per_batch * pos.z) + + (k * num_texel_per_batch) + (index / 4); + const uint src_pos_t = index % 4; + out_texel[k] = + VEC4_T(texelFetch(image_in, ivec3(pos.x, pos.y, src_pos_z), 0))[src_pos_t]; + } + + imageStore(image_out, pos, out_texel); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/select_channel_4d.yaml b/backends/vulkan/runtime/graph/ops/glsl/select_channel_4d.yaml new file mode 100644 index 00000000000..6236555f5dd --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/select_channel_4d.yaml @@ -0,0 +1,10 @@ +select_channel_4d: + parameter_names_with_default_values: + DTYPE: float + NDIM: 3 + generate_variant_forall: + DTYPE: + - VALUE: half + - VALUE: float + shader_variants: + - NAME: select_channel_4d diff --git a/backends/vulkan/runtime/graph/ops/glsl/select_height_3d.glsl b/backends/vulkan/runtime/graph/ops/glsl/select_height_3d.glsl new file mode 100644 index 00000000000..3ca92d3dcd4 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/select_height_3d.glsl @@ -0,0 +1,62 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} +#define VEC4_T ${texel_type(DTYPE)} + +layout(std430) buffer; + +#include "indexing_utils.h" + +layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out; +layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in; + +layout(set = 0, binding = 2) uniform PRECISION restrict OutSizes { + uvec4 data; +} +out_sizes; + +// index to select +layout(set = 0, binding = 3) uniform PRECISION restrict IndexVal { + int data; +} +index; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + const ivec4 idx = to_tensor_idx_C_packed(pos, out_sizes.data); + + if (any(greaterThanEqual(idx, out_sizes.data))) { + return; + } + + // w + const int src_x = pos.x; + // h + const int src_y = index.data; + // c + const int src_z = pos.y; + + const VEC4_T v = VEC4_T(texelFetch(image_in, ivec3(src_x, src_y, src_z), 0)); + + for (int i = 0; i < 4; i++) { + ivec3 new_pos = ivec3(pos.x, pos.y * 4 + i, 0); + + // When the C-channel exceeds original block size, exit early + if (new_pos.y >= out_sizes.data.y) { + return; + } + + imageStore(image_out, new_pos, VEC4_T(v[i], 0, 0, 0)); + } +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/select_height_3d.yaml b/backends/vulkan/runtime/graph/ops/glsl/select_height_3d.yaml new file mode 100644 index 00000000000..a373f1decd9 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/select_height_3d.yaml @@ -0,0 +1,10 @@ +select_height_3d: + parameter_names_with_default_values: + DTYPE: float + NDIM: 3 + generate_variant_forall: + DTYPE: + - VALUE: half + - VALUE: float + shader_variants: + - NAME: select_height_3d diff --git a/backends/vulkan/runtime/graph/ops/glsl/select_height_4d.glsl b/backends/vulkan/runtime/graph/ops/glsl/select_height_4d.glsl new file mode 100644 index 00000000000..1381c3c5fc4 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/select_height_4d.glsl @@ -0,0 +1,62 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} +#define VEC4_T ${texel_type(DTYPE)} + +layout(std430) buffer; + +#include "indexing_utils.h" + +layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out; +layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in; + +layout(set = 0, binding = 2) uniform PRECISION restrict OutSizes { + uvec4 data; +} +out_sizes; + +// index to select +layout(set = 0, binding = 3) uniform PRECISION restrict IndexVal { + // data.x: index along height dim to select + // data.y: number of batches + // data.z: number of texels per batch + // data.w: unused + ivec4 data; +} +select_info; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + const ivec4 idx = to_tensor_idx_C_packed(pos, out_sizes.data); + if (any(greaterThanEqual(idx, out_sizes.data))) { + return; + } + + const int num_batches = select_info.data.y; + const int num_texel_per_batch = select_info.data.z; + const int index = select_info.data.x; + + VEC4_T out_texel = VEC4_T(0, 0, 0, 0); + // read in the same channel from 4 separate batches + for (int k = 0; k < 4; k++) { + if ((k + pos.z * 4) >= num_batches + ) { // < 4 batches for this texel, exit early + break; + } + const uint src_pos_z = (pos.z * num_texel_per_batch * 4) + + k * num_texel_per_batch + (pos.y / 4); + out_texel[k] = VEC4_T(texelFetch( + image_in, ivec3(pos.x, index, src_pos_z), 0))[pos.y % 4]; + } + imageStore(image_out, pos, out_texel); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/select_height_4d.yaml b/backends/vulkan/runtime/graph/ops/glsl/select_height_4d.yaml new file mode 100644 index 00000000000..c3724f1157a --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/select_height_4d.yaml @@ -0,0 +1,10 @@ +select_height_4d: + parameter_names_with_default_values: + DTYPE: float + NDIM: 3 + generate_variant_forall: + DTYPE: + - VALUE: half + - VALUE: float + shader_variants: + - NAME: select_height_4d diff --git a/backends/vulkan/runtime/graph/ops/glsl/select_width_3d.glsl b/backends/vulkan/runtime/graph/ops/glsl/select_width_3d.glsl new file mode 100644 index 00000000000..6f1ffcfe826 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/select_width_3d.glsl @@ -0,0 +1,61 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#define VEC4_T ${texel_type(DTYPE)} + +layout(std430) buffer; + +#include "indexing_utils.h" + +layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out; +layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in; + +layout(set = 0, binding = 2) uniform PRECISION restrict OutSizes { + uvec4 data; +} +out_sizes; + +// index to select +layout(set = 0, binding = 3) uniform PRECISION restrict IndexVal { + int data; +} +index; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + const ivec4 idx = to_tensor_idx_C_packed(pos, out_sizes.data); + if (any(greaterThanEqual(idx, out_sizes.data))) { + return; + } + + // w + const int src_x = index.data; + // h + const int src_y = pos.x; + // c + const int src_z = pos.y; + + const VEC4_T v = VEC4_T(texelFetch(image_in, ivec3(src_x, src_y, src_z), 0)); + + for (int i = 0; i < 4; i++) { + ivec3 new_pos = ivec3(pos.x, pos.y * 4 + i, 0); + + // When the C-channel exceeds original block size, exit early + if (new_pos.y >= out_sizes.data.y) { + return; + } + + imageStore(image_out, new_pos, VEC4_T(v[i], 0, 0, 0)); + } +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/select_width_3d.yaml b/backends/vulkan/runtime/graph/ops/glsl/select_width_3d.yaml new file mode 100644 index 00000000000..a3070bf6ca3 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/select_width_3d.yaml @@ -0,0 +1,10 @@ +select_width_3d: + parameter_names_with_default_values: + DTYPE: float + NDIM: 3 + generate_variant_forall: + DTYPE: + - VALUE: half + - VALUE: float + shader_variants: + - NAME: select_width_3d diff --git a/backends/vulkan/runtime/graph/ops/glsl/select_width_4d.glsl b/backends/vulkan/runtime/graph/ops/glsl/select_width_4d.glsl new file mode 100644 index 00000000000..6f9b3771823 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/select_width_4d.glsl @@ -0,0 +1,65 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#define VEC4_T ${texel_type(DTYPE)} + +layout(std430) buffer; + +#include "indexing_utils.h" + +layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out; +layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in; + +layout(set = 0, binding = 2) uniform PRECISION restrict OutSizes { + uvec4 data; +} +out_sizes; + +// index to select +layout(set = 0, binding = 3) uniform PRECISION restrict SelectVal { + // data.x: index along width dim to select + // data.y: number of batches + // data.z: number of texels per batch + // data.w: unused + ivec4 data; +} +select_info; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + const ivec4 idx = to_tensor_idx_C_packed(pos, out_sizes.data); + if (any(greaterThanEqual(idx, out_sizes.data))) { + return; + } + + const int num_batches = select_info.data.y; + const int num_texel_per_batch = select_info.data.z; + const int index = select_info.data.x; + + //vec4 out_texel = vec4(0, 0, 0, 0); + VEC4_T out_texel = VEC4_T(0, 0, 0, 0); + // read in the same channel from 4 separate batches + for (int k = 0; k < 4; k++) { + if ((k + pos.z * 4) >= + num_batches) { // < 4 batches for this texel, exit early + break; + } + const uint src_pos_z = (pos.z * num_texel_per_batch * 4) + + k * num_texel_per_batch + (pos.y / 4); + + out_texel[k] = VEC4_T(texelFetch( + image_in, ivec3(index, pos.x, src_pos_z), 0))[pos.y % 4]; + } + imageStore(image_out, pos, out_texel); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/select_width_4d.yaml b/backends/vulkan/runtime/graph/ops/glsl/select_width_4d.yaml new file mode 100644 index 00000000000..f1131d77395 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/select_width_4d.yaml @@ -0,0 +1,10 @@ +select_width_4d: + parameter_names_with_default_values: + DTYPE: float + NDIM: 3 + generate_variant_forall: + DTYPE: + - VALUE: half + - VALUE: float + shader_variants: + - NAME: select_width_4d diff --git a/backends/vulkan/runtime/graph/ops/impl/Select.cpp b/backends/vulkan/runtime/graph/ops/impl/Select.cpp new file mode 100644 index 00000000000..c66d98af4a2 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/Select.cpp @@ -0,0 +1,132 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include + +#include +#include +#include + +namespace vkcompute { + +void check_args( + const vTensor& t_in, + int64_t dim, + int64_t index, + const vTensor& t_out) { + VK_CHECK_COND(check_memory_layout_is(t_in, api::kChannelsPacked)); + VK_CHECK_COND(check_memory_layout_is(t_out, api::kChannelsPacked)); + + const int64_t in_dim = t_in.dim(); + VK_CHECK_COND( + in_dim == 3 || in_dim == 4, + "Vulkan select only support 3d or 4d tensors!"); + + const int64_t in_size = t_in.size(dim); + + if (index < -in_size || index >= in_size) { + VK_CHECK_COND( + false, + "select(): index ", + index, + " t_outof range for tensor of size ", + in_size, + " at dimension ", + dim); + } +} + +void add_select_int_node( + ComputeGraph& graph, + const ValueRef in, + const ValueRef dim_ref, + const ValueRef index_ref, + const ValueRef out) { + vTensorPtr t_in = graph.get_tensor(in); + vTensorPtr t_out = graph.get_tensor(out); + int64_t dim = graph.extract_scalar(dim_ref); + int64_t index = graph.extract_scalar(index_ref); + + check_args(*t_in, dim, index, *t_out); + + const int64_t in_size = t_in->size(dim); + + if (index < 0) { + index += in_size; + } + + std::string kernel_name; + + // for 3d tensors, these values are not used by the shader. + int32_t num_texel_per_batch = 1; + int32_t num_batches = 1; + + int64_t in_dim = t_in->dim(); + if (in_dim == 3) { + if (dim == 0) { + kernel_name = "select_channel_3d"; + } else if (dim == 1) { + kernel_name = "select_height_3d"; + } else if (dim == 2) { + kernel_name = "select_width_3d"; + } else { + VK_CHECK_COND( + false, "Unexpected dim value=", dim, "for the input 3d tensor"); + } + } else { // self.dim() == 4 + num_texel_per_batch = + static_cast(std::ceil(static_cast(t_in->size(1)) / 4)); + num_batches = t_in->size(0); + if (dim == 0) { + kernel_name = "select_batch_4d"; + } else if (dim == 1) { + kernel_name = "select_channel_4d"; + } else if (dim == 2) { + kernel_name = "select_height_4d"; + } else if (dim == 3) { + kernel_name = "select_width_4d"; + } else { + VK_CHECK_COND( + false, "Unexpected dim value=", dim, "for the input 4d tensor"); + } + } + + kernel_name.reserve(kShaderNameReserve); + add_dtype_suffix(kernel_name, *t_out); + + api::utils::uvec3 global_size = t_out->virtual_extents(); + api::utils::uvec3 local_size = adaptive_work_group_size(global_size); + + // TODO: add resizing to support dynamic shapes. + graph.execute_nodes().emplace_back(new ExecuteNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + global_size, + local_size, + {{out, api::MemoryAccessType::WRITE}, {in, api::MemoryAccessType::READ}}, + {t_out->gpu_sizes_ubo(), + // TODO: num_batches and num_texel_per_batch are provided by + // t_out->gpu_sizes. Can change the following to reduce params + // created. + + graph.create_params_buffer(api::utils::make_ivec4( + {index, num_batches, num_texel_per_batch, 0}))})); +} + +void select_int(ComputeGraph& graph, const std::vector& args) { + return add_select_int_node(graph, args[0], args[1], args[2], args[3]); +} + +REGISTER_OPERATORS { + VK_REGISTER_OP(aten.select.int, select_int); +} + +} // namespace vkcompute diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py index 56baa60a9f6..31961f3e449 100644 --- a/backends/vulkan/test/op_tests/cases.py +++ b/backends/vulkan/test/op_tests/cases.py @@ -142,6 +142,29 @@ def get_full_inputs(): return test_suite +def get_select_int_inputs(): + test_suite = VkTestSuite( + [ + ((6, 2, 7), 0, 3), + ((6, 2, 7), 1, 0), + ((6, 2, 7), 2, 3), + ((6, 10, 7), 0, 3), + ((6, 10, 7), 1, 0), + ((6, 10, 7), 1, 9), + ((6, 10, 7), 2, 6), + ((9, 2, 9, 4), 0, 8), + ((9, 2, 9, 4), 1, 1), + ((9, 2, 9, 4), 2, 0), + ((9, 2, 9, 4), 2, 8), + ((9, 2, 9, 4), 3, 3), + ((8, 6, 1, 1), 0, 4), + ((8, 6, 1, 1), 1, 4), + ] + ) + test_suite.supports["layouts"] = ["api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED"] + return test_suite + + test_suites = { "aten.add.Tensor": get_binary_elementwise_inputs(), "aten.sub.Tensor": get_binary_elementwise_inputs(), @@ -152,6 +175,7 @@ def get_full_inputs(): "aten.convolution.default": get_conv2d_inputs(), "aten.native_layer_norm.default": get_native_layer_norm_inputs(), "aten.full.default": get_full_inputs(), + "aten.select.int": get_select_int_inputs(), } prepacked_args = {"aten.mm.default": {"mat2"}}