From c1561600b06ea800272e21ac433acfd613fe713a Mon Sep 17 00:00:00 2001 From: Jorge Pineda Date: Tue, 2 Apr 2024 12:06:30 -0700 Subject: [PATCH] [ET-VK][Ops] aten.convolution (SlidingWindow) ## The Operator `nn.Module` invocations of [`nn.Conv2d`](https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html#torch.nn.Conv2d) and [`nn.ConvTranspose2d`](https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html#torch.nn.ConvTranspose2d) get compiled to `aten.convolution.default` in the Edge Dialect, which carries the signature ``` - func: convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, SymInt[] padding, int[] dilation, bool transposed, SymInt[] output_padding, int groups) -> Tensor ``` ## Summary (cases handled) We introduce support for the convolution cases covered by [ATen-VK's default SlidingWindow implementation](https://github.com/pytorch/pytorch/blob/09c72eaa3f69f90402c86a30abf4fc621298578c/aten/src/ATen/native/vulkan/ops/Convolution.cpp#L73). This is achieved by - reusing the [existing `conv2d.glsl`](https://github.com/pytorch/pytorch/blob/09c72eaa3f69f90402c86a30abf4fc621298578c/aten/src/ATen/native/vulkan/glsl/conv2d.glsl), and - [moving special weights prepacking from CPU](https://github.com/pytorch/pytorch/blob/09c72eaa3f69f90402c86a30abf4fc621298578c/aten/src/ATen/native/vulkan/ops/Convolution.cpp#L134-L235) to the GPU in `conv2d_prepack_weights.glsl`. We also include resizing support for dynamic shapes. Note that only height and width of the input can vary. ## Cases not handled The implementation is on-par with ATen-VK's SlidingWindow. This means the following cases are missing: 1. **Groups G > 1.** Largely not covered by ATen-VK. `G = in_channels` is covered by ATen-VK's Depthwise impl and will be added soon. 2. **Batch (input) N > 1.** Not covered by ATen-VK. 3. **Padding > 0 while Dilation, Kernel > 1.** Not covered by ATen-VK. ## Coming soon For our CUNET model, the first two are required and the third is useful. 1. Transpose convolution 2. Depthwise convolution (for completeness) 3. Pointwise convolution (for optimization) 4. Null bias Differential Revision: [D55346778](https://our.internmc.facebook.com/intern/diff/D55346778/) [ghstack-poisoned] --- .../vulkan/partitioner/vulkan_partitioner.py | 2 + backends/vulkan/runtime/graph/ComputeGraph.h | 2 +- .../vulkan/runtime/graph/ops/glsl/conv2d.glsl | 136 +++++++++ .../vulkan/runtime/graph/ops/glsl/conv2d.yaml | 30 ++ .../ops/glsl/conv2d_prepack_weights.glsl | 139 ++++++++++ .../runtime/graph/ops/glsl/indexing_utils.h | 9 + .../vulkan/runtime/graph/ops/impl/Conv2d.cpp | 260 ++++++++++++++++++ .../graph/ops/impl/utils/KernelUtils.h | 2 +- backends/vulkan/test/test_vulkan_delegate.py | 27 ++ backends/vulkan/test/utils/test_utils.cpp | 35 +++ backends/vulkan/test/utils/test_utils.h | 7 + .../vulkan/test/vulkan_compute_api_test.cpp | 53 ++++ 12 files changed, 700 insertions(+), 2 deletions(-) create mode 100644 backends/vulkan/runtime/graph/ops/glsl/conv2d.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/conv2d.yaml create mode 100644 backends/vulkan/runtime/graph/ops/glsl/conv2d_prepack_weights.glsl create mode 100644 backends/vulkan/runtime/graph/ops/impl/Conv2d.cpp diff --git a/backends/vulkan/partitioner/vulkan_partitioner.py b/backends/vulkan/partitioner/vulkan_partitioner.py index b5df34b08cd..a4cf74097c4 100644 --- a/backends/vulkan/partitioner/vulkan_partitioner.py +++ b/backends/vulkan/partitioner/vulkan_partitioner.py @@ -48,6 +48,8 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: exir_ops.edge.aten.max_pool2d_with_indices.default, # Sum exir_ops.edge.aten.sum.dim_IntList, + # Convolution operators + exir_ops.edge.aten.convolution.default, # Other operator.getitem, ] diff --git a/backends/vulkan/runtime/graph/ComputeGraph.h b/backends/vulkan/runtime/graph/ComputeGraph.h index 00aa60020f3..0f66d411c81 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.h +++ b/backends/vulkan/runtime/graph/ComputeGraph.h @@ -172,7 +172,7 @@ class ComputeGraph final { const api::ScalarType dtype, const api::StorageType storage_type, const api::GPUMemoryLayout memory_layout, - const int64_t shared_object_idx); + const int64_t shared_object_idx = -1); /* * Add a `vTensor` value to the graph with the specified properties. The diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d.glsl new file mode 100644 index 00000000000..3bf469a4047 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d.glsl @@ -0,0 +1,136 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#include "indexing_utils.h" + +layout(std430) buffer; + +layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out; +layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in; +layout(set = 0, binding = 2) uniform PRECISION sampler2D kernel_in; +layout(set = 0, binding = 3) uniform PRECISION sampler2D bias_in; + +layout(set = 0, binding = 4) uniform PRECISION restrict OutExtents { + uvec4 data; +} +out_extents; + +layout(set = 0, binding = 5) uniform PRECISION restrict InExtents { + uvec4 data; +} +in_extents; + +layout(set = 0, binding = 6) uniform PRECISION restrict Params { + ivec2 kernel_size; + ivec2 stride; + ivec2 padding; + ivec2 dilation; +} +params; + +// If fields are separated, SwiftShader cannot identify in_group_size. +layout(set = 0, binding = 7) uniform PRECISION restrict ExtraParams { + ivec2 overlay_region; + int in_group_size; +} +extra_params; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +/* + * Computes a 2D convolution. Each shader invocation calculates the output at + * a single output location. + */ +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + if (any(greaterThanEqual(pos, out_extents.data.xyz))) { + return; + } + + // Compute the index of the top-left element of the overlay region. Negative + // indices indicate that the top-left element is in a region added by padding. + const ivec2 ipos = pos.xy * params.stride - params.padding; + + // Compute the start and end of the input indices to load. Padding is assumed + // to be constant 0 padding, so reads from the padding region are skipped. + const ivec2 start = max(ivec2(0), ipos); + const ivec2 end = min(ipos + extra_params.overlay_region.xy, ivec2(in_extents.data.xy)); + // Compute the start of the kernel based on how far we are skipping ahead when + // reading the input. Note that these are "canonical" indices. + ivec2 kstart = (start - ipos) / params.dilation; + // During prepacking, the weight tensor was rearranged in order to optimize + // for data access linearity in this shader. Therefore we need to adjust the + // canonical coordinates to the corresponding index in the rearranged weight + // tensor. The x-coordinate is multipled by 4 since each group of 4 channels + // is folded into the X axis. The y-coordinate is offset based on the z- + // coordinate because the 2D planes were stacked atop each other vertically. + kstart.x *= 4; + kstart.y += pos.z * params.kernel_size.y; + + // Perform the convolution by iterating over the overlay region. + vec4 sum = texelFetch(bias_in, ivec2(pos.z, 0), 0); + const int ic4 = extra_params.in_group_size / 4; + for (int z4 = 0; z4 < ic4; ++z4, kstart.x += params.kernel_size.x * 4) { + for (int y = start.y, ky = kstart.y; y < end.y; y += params.dilation.y, ++ky) { + for (int x = start.x, kx = kstart.x; x < end.x; x += params.dilation.x, kx += 4) { + const vec4 in_texel = texelFetch(image_in, ivec3(x, y, z4), 0); + + // To explain the calculation below, the contents of in_texel and the + // group of 4 texels loaded from kernel_in are shown: + // + // in_texel kernel_in + // -x-> ---x---> + // +---+ +----+----+----+----+ + // ^ | w | ^ | D0 | D1 | D2 | D3 | + // | +---+ | +----+----+----+----+ + // | | z | | | C0 | C1 | C2 | C3 | + // z +---+ z +----+----+----+----+ + // | | y | | | B0 | B1 | B2 | B3 | + // | +---+ | +----+----+----+----+ + // | x | | A0 | A1 | A2 | A3 | + // +---+ +----+----+----+----+ + // + // In the kernel_in graphic, cells sharing the same letter are from + // the same batch/output channel index, and the number denotes a unique + // channel index. To calculate the output texel, the following + // calculation is performed: + // + // +---+ +----+ +---+ +----+ +---+ +----+ +---+ +----+ + // | x | | D0 | | y | | D1 | | z | | D2 | | w | | D3 | + // +---+ +----+ +---+ +----+ +---+ +----+ +---+ +----+ + // | x | | C0 | | y | | C1 | | z | | C2 | | w | | C3 | + // +---+X+----+ + +---+X+----+ + +---+X+----+ + +---+X+----+ + // | x | | B0 | | y | | B1 | | z | | B2 | | w | | B3 | + // +---+ +----+ +---+ +----+ +---+ +----+ +---+ +----+ + // | x | | A0 | | y | | A1 | | z | | A2 | | w | | A3 | + // +---+ +----+ +---+ +----+ +---+ +----+ +---+ +----+ + // + // which is what is expressed in the following calculations. + + const vec4 ktex_0 = texelFetch(kernel_in, ivec2(kx + 0, ky), 0); + sum = fma(in_texel.xxxx, ktex_0, sum); + + const vec4 ktex_1 = texelFetch(kernel_in, ivec2(kx + 1, ky), 0); + sum = fma(in_texel.yyyy, ktex_1, sum); + + const vec4 ktex_2 = texelFetch(kernel_in, ivec2(kx + 2, ky), 0); + sum = fma(in_texel.zzzz, ktex_2, sum); + + const vec4 ktex_3 = texelFetch(kernel_in, ivec2(kx + 3, ky), 0); + sum = fma(in_texel.wwww, ktex_3, sum); + } + } + } + + imageStore(image_out, pos, sum); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d.yaml b/backends/vulkan/runtime/graph/ops/glsl/conv2d.yaml new file mode 100644 index 00000000000..53323a3e77c --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d.yaml @@ -0,0 +1,30 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +conv2d: + parameter_names_with_default_values: + NDIM: 3 + DTYPE: float + generate_variant_forall: + DTYPE: + - VALUE: half + SUFFIX: half + - VALUE: float + SUFFIX: float + shader_variants: + - NAME: conv2d + +conv2d_prepack_weights: + parameter_names_with_default_values: + DTYPE: float + generate_variant_forall: + DTYPE: + - VALUE: half + SUFFIX: half + - VALUE: float + SUFFIX: float + shader_variants: + - NAME: conv2d_prepack_weights diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_prepack_weights.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d_prepack_weights.glsl new file mode 100644 index 00000000000..62ed84fa7a4 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_prepack_weights.glsl @@ -0,0 +1,139 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#include "indexing_utils.h" + +layout(std430) buffer; + +layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[2][DTYPE]} image_out; +layout(set = 0, binding = 1) buffer PRECISION restrict readonly Buffer { + ${T[DTYPE]} data[]; +} +buffer_in; + +// Corresponds to {1,4,9,24} in the example below. +layout(set = 0, binding = 2) uniform PRECISION restrict GpuSizes { + ivec4 data; +} +gpu_sizes; + +// Corresponds to {3,3,7,10} in the example below. +layout(set = 0, binding = 3) uniform PRECISION restrict OriginalSizes { + ivec4 data; +} +original_sizes; + +// Corresponds to {3,3,8,12} in the example below. +layout(set = 0, binding = 4) uniform PRECISION restrict AlignedSizes { + ivec4 data; +} +padded_sizes; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +/* + * Computes special prepacking for a 2D convolution. Each shader invocation + * calculates the input buffer location to read into the desired texel. This + * packing was originally developed on CPU and that approach is described in the + * rest of this comment. Refer to the code-level comments, for how we translate + * it to GPU by reversing the steps. + * + * Consider example weight tensor of size {10,7,3,3}. The following + * transformations will be applied. + * + * 1. Pad the N and C dims so that both are a multiple of 4. In this case, 2 + * batches and 1 channel of padding are added, producing a tensor of size + * {12,8,3,3}. + * at::pad(x, {0,0,0,0,0,2,0,1}, "constant", 0); + * + * 2. Split the tensor along the C dim so that each split has 4 channels. + * x.reshape({12,2,4,3,3}); + * + * 3. For each split, "fold" the C dim into the W dim. Suppose the first rows + * at H=0 of the split have values + * 0,1,2 | 10,11,12 | 20,21,22 | 30,31,32 + * + * where | denotes a channel boundary. Then, the goal is to combine those rows + * into one row with the values + * 0, 10, 20, 30, 1, 11, 21, 31, 2, 12, 22, 32 + * + * x.permute({0,1,3,4,2}).reshape({12,2,3,12}); + * + * 4. Stack the splits belonging to the same batch horizontally by swapping the + * C and H dims. + * x.permute({0,2,1,3}).reshape({12,3,24}); + * + * 5. Repeat a similar process to "fold" the N dim into the C dim. Split along + * the N dim so that each split has 4 batches. + * x.reshape({3,4,3,24}); + * + * 6. Stack the batches on each other vertically by swapping the N and C dims. + * x.permute({1,0,2,3}).reshape({4,9,24}); + */ +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + const ivec4 coord = POS_TO_COORD_CHANNELS_PACKED(pos, gpu_sizes.data); + + if (any(greaterThanEqual(coord, gpu_sizes.data))) { + return; + } + + // As in usual staging shaders, map from GPU texel position to normal CPU + // buffer indices: (24,9) -> (4,9,24) + const int base_index = COORD_TO_BUFFER_IDX(coord, gpu_sizes.data); + const ivec4 p0 = + base_index + ivec4(0, 1, 2, 3) * STRIDE_CHANNELS_PACKED(gpu_sizes.data); + + // Re-map the normal CPU buffer indices to special indices, through a series + // of permutations: reshape is a no-op to the underlying indices, and permute + // is one of the hardest math problems I've ever solved. + // + // Undo step 6 premute: (4,3,3,24) -> (3,4,3,24) + // Undo step 4 permute: (12,3,2,12) -> (12,2,3,12) + // Undo step 3 permute, part 1: (12,2,3h,3w,4) -> (12,2,3h,4,3w) + // Undo step 3 permute, part 2: (12,2,3h,4,3w) -> (12,2,4,3h,3w) + const ivec4 p1 = SWAP_DIMS( + p0, + 4, + (padded_sizes.data.w / 4), + (padded_sizes.data.y * padded_sizes.data.z * padded_sizes.data.x)); + const ivec4 p2 = SWAP_DIMS( + p1, + padded_sizes.data.y, + (padded_sizes.data.z / 4), + (padded_sizes.data.x * 4)); + const ivec4 p3 = SWAP_DIMS(p2, padded_sizes.data.x, 4, 1); + const ivec4 p4 = SWAP_DIMS(p3, padded_sizes.data.y, 4, padded_sizes.data.x); + + // For values in the padded region, write zero instead of buffer data. + // + // Undo step 1 pad: (12,8,3,3) -> (10,7,3,3) + const ivec4 c = p4 % + (padded_sizes.data.z * padded_sizes.data.y * padded_sizes.data.x) / + (padded_sizes.data.y * padded_sizes.data.x); + const ivec4 n = + p4 / (padded_sizes.data.z * padded_sizes.data.y * padded_sizes.data.x); + const ivec4 p5 = p4 - + n * (padded_sizes.data.z - original_sizes.data.z) * padded_sizes.data.y * + padded_sizes.data.x; + const ivec4 mask = ivec4(greaterThanEqual(c, original_sizes.data.zzzz)) | + ivec4(greaterThanEqual(n, original_sizes.data.wwww)); + + ${T[DTYPE]} val_x = mix(buffer_in.data[p5.x], 0, mask.x); + ${T[DTYPE]} val_y = mix(buffer_in.data[p5.y], 0, mask.y); + ${T[DTYPE]} val_z = mix(buffer_in.data[p5.z], 0, mask.z); + ${T[DTYPE]} val_w = mix(buffer_in.data[p5.w], 0, mask.w); + + ${VEC4_T[DTYPE]} texel = ${VEC4_T[DTYPE]}(val_x, val_y, val_z, val_w); + + imageStore(image_out, pos.xy, texel); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h b/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h index c76f054ec67..100e69bd824 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h +++ b/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h @@ -44,3 +44,12 @@ #define STRIDE_WIDTH_PACKED(vec) (1) #define STRIDE_HEIGHT_PACKED(vec) (vec.x) + +// Given a buffer(1-D) index cur, compute a new index where the corresponding +// tensor(N-D)'s x and y dimensions are swapped, and size is of the M-D plane of +// dimensions lower than x and y. +#define SWAP_DIMS(cur, x, y, size) \ + cur + \ + size*( \ + (1 - y) * ((cur % (x * y * size)) / (y * size)) + \ + (x - 1) * ((cur % (y * size)) / size)) diff --git a/backends/vulkan/runtime/graph/ops/impl/Conv2d.cpp b/backends/vulkan/runtime/graph/ops/impl/Conv2d.cpp new file mode 100644 index 00000000000..0453c345685 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/Conv2d.cpp @@ -0,0 +1,260 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +#include + +#include +#include + +#include + +namespace vkcompute { + +struct Conv2dParams final { + api::utils::ivec2 overlay_region; + int in_group_size; +}; + +void resize_conv2d_node( + ComputeGraph* graph, + const std::vector& args, + const std::vector& extra_args) { + vTensor& out = graph->get_val(args[0].refs[0]).toTensor(); + vTensor& self = graph->get_val(args[1].refs[0]).toTensor(); + + size_t ndim = self.sizes().size(); + std::vector new_out_sizes(ndim); + + // Batch, Channel + if (ndim == 4) { + new_out_sizes.at(ndim - 4) = self.sizes().at(ndim - 4); + } + const auto weight_sizes = graph->get_val(extra_args[0]).toTensorRef().sizes; + new_out_sizes.at(ndim - 3) = weight_sizes.at(ndim - 4); + + const auto kernel_size = + api::utils::make_ivec2({weight_sizes.at(3), weight_sizes.at(2)}); + const auto stride = reverse(*graph, extra_args[1]); + const auto padding = reverse(*graph, extra_args[2]); + const auto dilation = reverse(*graph, extra_args[3]); + + // Height, Width + new_out_sizes.at(ndim - 2) = calc_out_size( + self.sizes().at(ndim - 2), + kernel_size.data[1], + stride.data[1], + padding.data[1], + dilation.data[1]); + new_out_sizes.at(ndim - 1) = calc_out_size( + self.sizes().at(ndim - 1), + kernel_size.data[0], + stride.data[0], + padding.data[0], + dilation.data[0]); + + VK_CHECK_COND(new_out_sizes.at(ndim - 2) >= 1); + VK_CHECK_COND(new_out_sizes.at(ndim - 1) >= 1); + + out.virtual_resize(new_out_sizes); +} + +ValueRef prepack_biases(ComputeGraph& graph, const ValueRef vref) { + if (graph.get_val(vref).isNone()) { + VK_THROW("aten.convolution.default: Null bias is not supported yet!"); + } + + TensorRef& tref = graph.get_val(vref).toTensorRef(); + ValueRef v = graph.add_tensor( + tref.sizes, + tref.dtype, + api::StorageType::TEXTURE_2D, + api::GPUMemoryLayout::TENSOR_WIDTH_PACKED); + vTensor t = graph.get_val(v).toTensor(); + + api::ShaderInfo shader = get_nchw_to_image_shader(t); + + api::utils::uvec3 global_size = t.extents(); + api::utils::uvec3 local_size = adaptive_work_group_size(global_size); + + graph.prepack_nodes().emplace_back(new PrepackNode( + graph, + shader, + global_size, + local_size, + vref, + v, + {t.gpu_sizes_ubo(), t.cpu_sizes_ubo()})); + + return v; +} + +ValueRef prepack_weights(ComputeGraph& graph, const ValueRef vref) { + TensorRef& tref = graph.get_val(vref).toTensorRef(); + + int64_t batch_padded = + api::utils::align_up(api::utils::val_at(-4, tref.sizes), INT64_C(4)); + int64_t channels_padded = + api::utils::align_up(api::utils::val_at(-3, tref.sizes), INT64_C(4)); + int64_t height = api::utils::val_at(-2, tref.sizes); + int64_t width = api::utils::val_at(-1, tref.sizes); + + const auto final_sizes = std::vector{ + 4, batch_padded * height / 4, channels_padded * width}; + + ValueRef v = graph.add_tensor( + final_sizes, + tref.dtype, + api::StorageType::TEXTURE_2D, + api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED); + vTensor t = graph.get_val(v).toTensor(); + + api::utils::uvec3 global_size = t.extents(); + api::utils::uvec3 local_size = adaptive_work_group_size(global_size); + + std::stringstream kernel_name; + kernel_name << "conv2d_prepack_weights"; + apply_dtype_suffix(kernel_name, t); + api::ShaderInfo shader = VK_KERNEL_FROM_STR(kernel_name.str()); + + const auto original_sizes = + api::utils::make_ivec4(tref.sizes, /*reverse=*/true); + const auto padded_sizes = api::utils::make_ivec4( + {batch_padded, channels_padded, height, width}, /*reverse=*/true); + + graph.prepack_nodes().emplace_back(new PrepackNode( + graph, + shader, + global_size, + local_size, + vref, + v, + {t.gpu_sizes_ubo(), + graph.create_params_buffer(original_sizes), + graph.create_params_buffer(padded_sizes)})); + + return v; +} + +void check_conv2d_args(const vTensor& in, const vTensor& out) { + VK_CHECK_COND( + check_memory_layout_is(in, api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED)); + VK_CHECK_COND(check_memory_layout_is( + out, api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED)); +} + +void add_conv2d_node( + ComputeGraph& graph, + const ValueRef in, + const ValueRef weight, + const ValueRef bias, + const ValueRef stride, + const ValueRef padding, + const ValueRef dilation, + const ValueRef out) { + ValueRef arg_in = prepack_if_tensor_ref(graph, in); + vTensor& t_in = graph.get_val(arg_in).toTensor(); + vTensor& t_out = graph.get_val(out).toTensor(); + + check_conv2d_args(t_in, t_out); + + if (t_in.sizes().at(0) > 1) { + VK_THROW( + "aten.convolution.default: input batch size > 1 is not supported yet!"); + } + + ValueRef arg_weight = prepack_weights(graph, weight); + ValueRef arg_bias = prepack_biases(graph, bias); + + api::utils::uvec3 global_size = t_out.virtual_extents(); + api::utils::uvec3 local_size = adaptive_work_group_size(global_size); + + const auto weight_sizes = graph.get_val(weight).toTensorRef().sizes; + const int64_t k_height = weight_sizes.at(2); + const int64_t k_width = weight_sizes.at(3); + const auto kernel_size = api::utils::make_ivec2({k_width, k_height}); + + const auto stride_vec = reverse(graph, stride); + const auto padding_vec = reverse(graph, padding); + const auto dilation_vec = reverse(graph, dilation); + + KernelParams kernel_params{ + kernel_size, + stride_vec, + padding_vec, + dilation_vec, + }; + + const int64_t d_height = dilation_vec.data[1]; + const int64_t d_width = dilation_vec.data[0]; + const int64_t p_height = padding_vec.data[1]; + const int64_t p_width = padding_vec.data[0]; + + if ((p_width > 0 && k_width > 1 && d_width > 1) || + (p_height > 0 && k_height > 1 && d_height > 1)) { + VK_THROW( + "aten.convolution.default: padding > 0 while dilation, kernel_size > 1 is not supported yet!"); + } + + const auto overlay_region = api::utils::make_ivec2({ + k_width + (k_width - 1) * (d_width - 1), + k_height + (k_height - 1) * (d_height - 1), + }); + const int32_t in_group_size = api::utils::safe_downcast( + api::utils::align_up(weight_sizes.at(1), INT64_C(4))); + + Conv2dParams extra_params{ + overlay_region, + in_group_size, + }; + + std::stringstream kernel_name; + kernel_name << "conv2d"; + apply_dtype_suffix(kernel_name, t_out); + + graph.execute_nodes().emplace_back(new ExecuteNode( + graph, + VK_KERNEL_FROM_STR(kernel_name.str()), + global_size, + local_size, + // Inputs and Outputs + {{out, api::MemoryAccessType::WRITE}, + {{arg_in, arg_weight, arg_bias}, api::MemoryAccessType::READ}}, + // Shader params buffers + { + t_out.extents_ubo(), + t_in.extents_ubo(), + graph.create_params_buffer(kernel_params), + graph.create_params_buffer(extra_params), + }, + // Resizing + resize_conv2d_node, + {weight, stride, padding, dilation})); +} + +void conv2d(ComputeGraph& graph, const std::vector& args) { + const bool transposed = graph.get_val(args[6]).toBool(); + if (transposed) { + VK_THROW("aten.convolution.default: transpose is not supported yet!"); + } + const int64_t groups = graph.get_val(args[8]).toInt(); + if (groups > 1) { + VK_THROW("aten.convolution.default: groups > 1 is not supported yet!"); + } + return add_conv2d_node( + graph, args[0], args[1], args[2], args[3], args[4], args[5], args[9]); +} + +REGISTER_OPERATORS { + VK_REGISTER_OP(aten.convolution.default, conv2d); +} + +} // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h b/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h index 6e6763dc574..ed7aa873e07 100644 --- a/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h +++ b/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h @@ -29,7 +29,7 @@ int64_t calc_out_size( const int64_t stride, const int64_t padding, const int64_t dilation, - const bool ceil_mode); + const bool ceil_mode = false); api::utils::ivec2 reverse(ComputeGraph& graph, ValueRef vref); diff --git a/backends/vulkan/test/test_vulkan_delegate.py b/backends/vulkan/test/test_vulkan_delegate.py index d90cfad7bbe..d305fd19663 100644 --- a/backends/vulkan/test/test_vulkan_delegate.py +++ b/backends/vulkan/test/test_vulkan_delegate.py @@ -496,3 +496,30 @@ def forward(self, x): sample_inputs, memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED], ) + + def test_vulkan_backend_conv2d(self): + class Conv2dModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d( + in_channels=6, + out_channels=8, + kernel_size=(3, 3), + padding=(2, 3), + stride=(1, 2), + dilation=1, + groups=1, + bias=True, + ) + + def forward(self, x): + return self.conv(x) + + conv2d_module = Conv2dModule() + sample_inputs = (torch.randn(size=(1, 6, 40, 50), dtype=torch.float32),) + + self.lower_module_and_test_output( + conv2d_module, + sample_inputs, + memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED], + ) diff --git a/backends/vulkan/test/utils/test_utils.cpp b/backends/vulkan/test/utils/test_utils.cpp index ea051474a31..ae64f90ef9d 100644 --- a/backends/vulkan/test/utils/test_utils.cpp +++ b/backends/vulkan/test/utils/test_utils.cpp @@ -108,6 +108,41 @@ void record_image_to_nchw_op( v_src.cpu_sizes_ubo()->buffer()); } +void record_conv2d_prepack_weights_op( + api::Context* const context, + api::VulkanBuffer& src_buffer, + vTensor& v_dst, + const std::vector& original_sizes, + const std::vector& padded_sizes) { + api::PipelineBarrier pipeline_barrier{}; + + std::stringstream kernel_name; + kernel_name << "conv2d_prepack_weights"; + apply_dtype_suffix(kernel_name, v_dst); + api::ShaderInfo shader = VK_KERNEL_FROM_STR(kernel_name.str()); + + api::UniformParamsBuffer original_sizes_ubo( + context, api::utils::make_ivec4(original_sizes, /*reverse=*/true)); + api::UniformParamsBuffer padded_sizes_ubo( + context, api::utils::make_ivec4(padded_sizes, /*reverse=*/true)); + + context->submit_compute_job( + shader, + pipeline_barrier, + v_dst.virtual_extents(), + adaptive_work_group_size(v_dst.virtual_extents()), + VK_NULL_HANDLE, + v_dst.image( + pipeline_barrier, + api::PipelineStage::COMPUTE, + api::MemoryAccessType::WRITE), + src_buffer, + v_dst.gpu_sizes_ubo()->buffer(), + v_dst.cpu_sizes_ubo()->buffer(), + original_sizes_ubo.buffer(), + padded_sizes_ubo.buffer()); +} + void record_binary_op( api::Context* const context, const std::string& op_name, diff --git a/backends/vulkan/test/utils/test_utils.h b/backends/vulkan/test/utils/test_utils.h index 8dcba015520..2d7d0b0746f 100644 --- a/backends/vulkan/test/utils/test_utils.h +++ b/backends/vulkan/test/utils/test_utils.h @@ -81,6 +81,13 @@ void record_image_to_nchw_op( vTensor& v_src, api::VulkanBuffer& dst_buffer); +void record_conv2d_prepack_weights_op( + api::Context* const context, + api::VulkanBuffer& src_buffer, + vTensor& v_dst, + const std::vector& original_sizes, + const std::vector& padded_sizes); + void record_binary_op( api::Context* const context, const std::string& op_name, diff --git a/backends/vulkan/test/vulkan_compute_api_test.cpp b/backends/vulkan/test/vulkan_compute_api_test.cpp index 3265191180b..763c6e575a3 100644 --- a/backends/vulkan/test/vulkan_compute_api_test.cpp +++ b/backends/vulkan/test/vulkan_compute_api_test.cpp @@ -1144,3 +1144,56 @@ TEST(VulkanComputeGraphOpsTest, max_pool2d_smoke_test) { /*base_val=*/10.0f, kernel); } + +TEST(VulkanComputeGraphOpsTest, conv2d_prepack_test) { + const auto original_sizes = std::vector{2, 3, 1, 2}; + const auto padded_sizes = std::vector{4, 4, 1, 2}; + const auto gpu_sizes = std::vector{4, 1, 8}; + + vTensor vten = vTensor( + api::context(), + gpu_sizes, + api::kFloat, + api::StorageType::TEXTURE_2D, + api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED); + + // Create and fill input staging buffer + const int64_t in_numel = api::utils::multiply_integers(original_sizes); + api::StorageBuffer staging_buffer_in(api::context(), api::kFloat, in_numel); + + std::vector data_in(in_numel); + for (int i = 0; i < in_numel; i++) { + data_in[i] = i + 1; + } + copy_ptr_to_staging( + data_in.data(), staging_buffer_in, sizeof(float) * in_numel); + + // Output staging buffer + const int64_t out_numel = api::utils::multiply_integers(padded_sizes); + api::StorageBuffer staging_buffer_out(api::context(), api::kFloat, out_numel); + + // Copy data in and out of the tensor + record_conv2d_prepack_weights_op( + api::context(), + staging_buffer_in.buffer(), + vten, + original_sizes, + padded_sizes); + record_image_to_nchw_op(api::context(), vten, staging_buffer_out.buffer()); + + // Execute command buffer + submit_to_gpu(); + + // Extract data from output staging buffer + std::vector data_out(out_numel); + copy_staging_to_ptr( + staging_buffer_out, data_out.data(), sizeof(float) * out_numel); + + // Check data matches results copied from ATen-VK + std::vector data_out_expected = {1, 3, 5, 0, 2, 4, 6, 0, 7, 9, 11, + 0, 8, 10, 12, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + for (int i = 0; i < vten.numel(); i++) { + CHECK_VALUE(data_out, i, data_out_expected[i]); + } +}