From f708c12c5c79e59b851d1a195494facac6bbb0f1 Mon Sep 17 00:00:00 2001 From: Jorge Pineda Date: Fri, 5 Apr 2024 15:16:19 -0700 Subject: [PATCH] [ET-VK][Ops] aten.convolution (Depthwise Output-Tile) We port an optimization from ATen-VK for specific weight sizes: [`conv2d_dw_output_tile.glsl`](https://github.com/pytorch/pytorch/blob/09c72eaa3f69f90402c86a30abf4fc621298578c/aten/src/ATen/native/vulkan/glsl/conv2d_dw_output_tile.glsl) Differential Revision: [D55814588](https://our.internmc.facebook.com/intern/diff/D55814588/) [ghstack-poisoned] --- .../runtime/graph/ops/glsl/conv2d_dw.glsl | 6 +- .../graph/ops/glsl/conv2d_dw_output_tile.glsl | 83 +++++++++++++++++++ .../graph/ops/glsl/conv2d_dw_output_tile.yaml | 21 +++++ .../vulkan/runtime/graph/ops/impl/Conv2d.cpp | 19 ++++- 4 files changed, 122 insertions(+), 7 deletions(-) create mode 100644 backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.yaml diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw.glsl index ab15658111f..50b60ad956d 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw.glsl @@ -70,9 +70,9 @@ void main() { int kx = 0; for (int y = start.y; y < end.y; y += params.dilation.y) { for (int x = start.x; x < end.x; x += params.dilation.x) { - // The weight kernel was rearranged so that every NxN filter is flattened - // to fits in one row. Each filter was then stacked on top of each other - // vertically. + // The weight kernel was rearranged such that every NxN filter is + // flattened to fit in one row. Each filter was then stacked on top of + // each other vertically. 
const ${VEC4_T[DTYPE]} in_texel = texelFetch(image_in, ivec3(x, y, pos.z), 0); sum = fma(in_texel, texelFetch(kernel_in, ivec2(kx, pos.z), 0), sum); ++kx; diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl new file mode 100644 index 00000000000..470eef6cdeb --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl @@ -0,0 +1,83 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#include "indexing_utils.h" + +layout(std430) buffer; + +layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out; +layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in; +layout(set = 0, binding = 2) uniform PRECISION sampler2D kernel_in; +layout(set = 0, binding = 3) uniform PRECISION sampler2D bias_in; + +layout(set = 0, binding = 4) uniform PRECISION restrict OutExtents { + uvec4 data; +} +out_extents; + +layout(set = 0, binding = 5) uniform PRECISION restrict InExtents { + uvec4 data; +} +in_extents; + +layout(set = 0, binding = 6) uniform PRECISION restrict Params { + ivec2 kernel_size; + ivec2 stride; + ivec2 padding; + ivec2 dilation; +} +params; + +// If fields are separated, SwiftShader cannot identify in_group_size. +layout(set = 0, binding = 7) uniform PRECISION restrict ExtraParams { + ivec2 overlay_region; + int in_group_size; +} +extra_params; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +/* + * Computes a depthwise convolution. Each shader invocation calculates the + * output at a single output location. 
+ */ +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + if (any(greaterThanEqual(pos, out_extents.data.xyz))) { + return; + } + + // Compute the index of the top-left element of the overlay region. Negative + // indices indicate that the top-left element is in a region added by padding. + const ivec2 ipos = pos.xy * params.stride - params.padding; + + // Compute the start and end of the input indices to load. Padding is assumed + // to be constant 0 padding, so any reads from the padding region are skipped. + const ivec2 start = ipos; + const ivec2 end = ipos + extra_params.overlay_region.xy; + + ${VEC4_T[DTYPE]} sum = texelFetch(bias_in, ivec2(pos.z, 0), 0); + int kx = 0; + for (int y = start.y, i = 0; i < ${TILE_SIZE}; y += params.dilation.y, i++) { + for (int x = start.x, j = 0; j < ${TILE_SIZE}; x += params.dilation.x, j++) { + // The weight kernel was rearranged such that every NxN filter is + // flattened to fit in one row. Each filter was then stacked on top of + // each other vertically. + const vec4 in_texel = texelFetch(image_in, ivec3(x, y, pos.z), 0); + sum = fma(in_texel, texelFetch(kernel_in, ivec2(kx, pos.z), 0), sum); + kx++; + } + } + + imageStore(image_out, pos, sum); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.yaml b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.yaml new file mode 100644 index 00000000000..1d4405e0276 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.yaml @@ -0,0 +1,21 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +conv2d_dw_output_tile: + parameter_names_with_default_values: + NDIM: 3 + DTYPE: float + TILE_SIZE: 3 + generate_variant_forall: + DTYPE: + - VALUE: half + SUFFIX: half + - VALUE: float + SUFFIX: float + shader_variants: + - NAME: conv2d_dw_output_tile_3x3 + - NAME: conv2d_dw_output_tile_5x5 + TILE_SIZE: 5 diff --git a/backends/vulkan/runtime/graph/ops/impl/Conv2d.cpp b/backends/vulkan/runtime/graph/ops/impl/Conv2d.cpp index dc4421e04c5..5ee180b6f8f 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Conv2d.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Conv2d.cpp @@ -87,13 +87,24 @@ enum class Conv2dMethod : uint8_t { }; api::ShaderInfo get_conv2d_shader( + ComputeGraph& graph, const vTensor& t_out, const bool prepack_weights, - const Conv2dMethod method) { + const Conv2dMethod method, + const ValueRef weight) { std::stringstream kernel_name; switch (method) { case Conv2dMethod::Depthwise: kernel_name << "conv2d_dw"; + if (!prepack_weights) { + const auto weight_sizes = graph.get_val(weight).toTensorRef().sizes; + if (weight_sizes.at(2) == 3 && weight_sizes.at(3) == 3) { + kernel_name << "_output_tile_3x3"; + } + if (weight_sizes.at(2) == 5 && weight_sizes.at(3) == 5) { + kernel_name << "_output_tile_5x5"; + } + } break; case Conv2dMethod::SlidingWindow: kernel_name << "conv2d"; @@ -169,7 +180,7 @@ ValueRef prepack_weights( api::utils::uvec3 local_size = adaptive_work_group_size(global_size); api::ShaderInfo shader = - get_conv2d_shader(t, /*prepack_weights = */ true, method); + get_conv2d_shader(graph, t, /*prepack_weights = */ true, method, vref); const auto padded_sizes = get_padded_sizes(original_sizes, method); @@ -298,8 +309,8 @@ void add_conv2d_node( check_conv2d_params(kernel_params, transposed_val); - api::ShaderInfo shader = - get_conv2d_shader(t_out, /*prepack_weights = */ false, method); + api::ShaderInfo shader = get_conv2d_shader( + graph, t_out, /*prepack_weights = */ false, method, weight); graph.execute_nodes().emplace_back(new ExecuteNode( 
graph,