From 6ca821f5399f68a60ab1d92c2e1a605181f708e3 Mon Sep 17 00:00:00 2001 From: Jorge Pineda Date: Fri, 5 Apr 2024 15:16:27 -0700 Subject: [PATCH 1/2] [ET-VK][Ops] aten.convolution (Bias=False) The final touches to get ET-VK convolution on-par with ATen-VK's convolution. ## Idea In our shaders, we add the bias to our sum. ``` ${VEC4_T[DTYPE]} sum = texelFetch(bias_in, ivec2(pos.z, 0), 0); ``` To keep our shaders as is, we implement having no bias by allocating a buffer of zeros. Then, our shader adds zero to our sum. ## Issue If `Bias=False`, dummy buffer of zeros is not serialized with the graph. The bias ValueRef is deserialized in the runtime as `TypeTag::NONE`, not `TypeTag::TENSORREF`. ## Solution If `TypeTag::NONE` is given, (1) create the `vTensor` using the `out_channels` value from the weights and (2) allocate a StagingBuffer of that size. The StagingBuffer will be transferred to GPU memory and initialized to zeros. Differential Revision: [D55814589](https://our.internmc.facebook.com/intern/diff/D55814589/) [ghstack-poisoned] --- .../vulkan/runtime/graph/ops/PrepackNode.cpp | 24 +++++++++++++---- .../vulkan/runtime/graph/ops/PrepackNode.h | 3 +++ .../vulkan/runtime/graph/ops/impl/Conv2d.cpp | 18 ++++++++----- backends/vulkan/test/test_vulkan_delegate.py | 27 +++++++++++++++++++ 4 files changed, 60 insertions(+), 12 deletions(-) diff --git a/backends/vulkan/runtime/graph/ops/PrepackNode.cpp b/backends/vulkan/runtime/graph/ops/PrepackNode.cpp index 192d7496045..48e4e941ac7 100644 --- a/backends/vulkan/runtime/graph/ops/PrepackNode.cpp +++ b/backends/vulkan/runtime/graph/ops/PrepackNode.cpp @@ -32,17 +32,31 @@ PrepackNode::PrepackNode( graph.update_descriptor_counts(shader, /*execute = */ false); } -void PrepackNode::encode(ComputeGraph* graph) { - api::Context* const context = graph->context(); - api::PipelineBarrier pipeline_barrier{}; - - TensorRef& tref = graph->get_val(tref_).toTensorRef(); +api::StorageBuffer PrepackNode::create_staging_buffer(ComputeGraph* graph) { vTensor& packed = graph->get_val(packed_).toTensor(); + // If no TensorRef is provided, create a zeroed staging buffer according to + // the vTensor metadata. + if (graph->get_val(tref_).isNone()) { + size_t numel = api::utils::multiply_integers(packed.sizes()); + api::StorageBuffer staging(graph->context(), packed.dtype(), numel); + return staging; + } + + TensorRef& tref = graph->get_val(tref_).toTensorRef(); size_t numel = api::utils::multiply_integers(tref.sizes); api::StorageBuffer staging(graph->context(), tref.dtype, numel); size_t nbytes = numel * api::element_size(tref.dtype); copy_ptr_to_staging(tref.data, staging, nbytes); + return staging; +} + +void PrepackNode::encode(ComputeGraph* graph) { + api::Context* const context = graph->context(); + api::PipelineBarrier pipeline_barrier{}; + + vTensor& packed = graph->get_val(packed_).toTensor(); + api::StorageBuffer staging = create_staging_buffer(graph); std::unique_lock cmd_lock = context->dispatch_lock(); diff --git a/backends/vulkan/runtime/graph/ops/PrepackNode.h b/backends/vulkan/runtime/graph/ops/PrepackNode.h index dd31be12b37..a6759c81b14 100644 --- a/backends/vulkan/runtime/graph/ops/PrepackNode.h +++ b/backends/vulkan/runtime/graph/ops/PrepackNode.h @@ -47,6 +47,9 @@ class PrepackNode final { const ValueRef packed_; // TODO(T180906457): allow re-computing param buffers. std::vector> params_; + + private: + api::StorageBuffer create_staging_buffer(ComputeGraph* graph); }; } // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/Conv2d.cpp b/backends/vulkan/runtime/graph/ops/impl/Conv2d.cpp index 896a26a06db..c6ff2f7490c 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Conv2d.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Conv2d.cpp @@ -52,13 +52,17 @@ void resize_conv2d_node( out.virtual_resize(new_out_sizes); } -ValueRef prepack_biases(ComputeGraph& graph, const ValueRef vref) { - if (graph.get_val(vref).isNone()) { - VK_THROW("aten.convolution.default: Null bias is not supported yet!"); - } +ValueRef prepack_biases( + ComputeGraph& graph, + const ValueRef vref, + const ValueRef weight, + const bool transposed) { + TensorRef& tref = graph.get_val(weight).toTensorRef(); + const int64_t out_channels = transposed ? tref.sizes.at(1) : tref.sizes.at(0); - ValueRef v = graph.add_tensor_like( - vref, + ValueRef v = graph.add_tensor( + {out_channels}, + tref.dtype, api::StorageType::TEXTURE_2D, api::GPUMemoryLayout::TENSOR_WIDTH_PACKED); vTensor& t = graph.get_val(v).toTensor(); @@ -301,7 +305,7 @@ void add_conv2d_node( ValueRef arg_in = prepack_if_tensor_ref(graph, in); ValueRef arg_weight = prepack_weights(graph, weight, method); - ValueRef arg_bias = prepack_biases(graph, bias); + ValueRef arg_bias = prepack_biases(graph, bias, weight, transposed_val); vTensor& t_in = graph.get_val(arg_in).toTensor(); vTensor& t_out = graph.get_val(out).toTensor(); diff --git a/backends/vulkan/test/test_vulkan_delegate.py b/backends/vulkan/test/test_vulkan_delegate.py index 8ba695524cd..dd2142eee47 100644 --- a/backends/vulkan/test/test_vulkan_delegate.py +++ b/backends/vulkan/test/test_vulkan_delegate.py @@ -601,3 +601,30 @@ def forward(self, x): sample_inputs, memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED], ) + + def test_vulkan_backend_conv2d_bias_false(self): + class Conv2dModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d( + in_channels=6, + out_channels=8, + kernel_size=(3, 3), + padding=(2, 3), + stride=(1, 2), + dilation=1, + groups=1, + bias=False, + ) + + def forward(self, x): + return self.conv(x) + + conv2d_module = Conv2dModule() + sample_inputs = (torch.randn(size=(1, 6, 40, 50), dtype=torch.float32),) + + self.lower_module_and_test_output( + conv2d_module, + sample_inputs, + memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED], + ) From bf72506cb5b3021c8ae03873f7c7020ec9759964 Mon Sep 17 00:00:00 2001 From: Jorge Pineda Date: Mon, 8 Apr 2024 13:01:46 -0700 Subject: [PATCH 2/2] Update on "[ET-VK][Ops] aten.convolution (Bias=False)" The final touches to get ET-VK convolution on-par with ATen-VK's convolution. ## Idea In our shaders, we add the bias to our sum. ``` ${VEC4_T[DTYPE]} sum = texelFetch(bias_in, ivec2(pos.z, 0), 0); ``` To keep our shaders as is, we implement having no bias by allocating a buffer of zeros. Then, our shader adds zero to our sum. ## Issue If `Bias=False`, dummy buffer of zeros is not serialized with the graph. The bias ValueRef is deserialized in the runtime as `TypeTag::NONE`, not `TypeTag::TENSORREF`. ## Solution If `TypeTag::NONE` is given, (1) create the `vTensor` using the `out_channels` value from the weights and (2) allocate a StagingBuffer of that size. The StagingBuffer will be transferred to GPU memory and initialized to zeros. Differential Revision: [D55814589](https://our.internmc.facebook.com/intern/diff/D55814589/) [ghstack-poisoned] --- backends/vulkan/runtime/graph/ops/PrepackNode.cpp | 4 +++- backends/vulkan/runtime/graph/ops/utils/StagingUtils.cpp | 7 +++++++ backends/vulkan/runtime/graph/ops/utils/StagingUtils.h | 2 ++ 3 files changed, 12 insertions(+), 1 deletion(-) diff --git a/backends/vulkan/runtime/graph/ops/PrepackNode.cpp b/backends/vulkan/runtime/graph/ops/PrepackNode.cpp index 48e4e941ac7..26ce76ebcdc 100644 --- a/backends/vulkan/runtime/graph/ops/PrepackNode.cpp +++ b/backends/vulkan/runtime/graph/ops/PrepackNode.cpp @@ -35,11 +35,13 @@ PrepackNode::PrepackNode( api::StorageBuffer PrepackNode::create_staging_buffer(ComputeGraph* graph) { vTensor& packed = graph->get_val(packed_).toTensor(); - // If no TensorRef is provided, create a zeroed staging buffer according to + // If no TensorRef is provided, create a staging buffer of zeros according to // the vTensor metadata. if (graph->get_val(tref_).isNone()) { size_t numel = api::utils::multiply_integers(packed.sizes()); api::StorageBuffer staging(graph->context(), packed.dtype(), numel); + size_t nbytes = numel * api::element_size(packed.dtype()); + copy_zeros_to_staging(staging, nbytes); return staging; } diff --git a/backends/vulkan/runtime/graph/ops/utils/StagingUtils.cpp b/backends/vulkan/runtime/graph/ops/utils/StagingUtils.cpp index 7f5ae409d44..f2b380cd8db 100644 --- a/backends/vulkan/runtime/graph/ops/utils/StagingUtils.cpp +++ b/backends/vulkan/runtime/graph/ops/utils/StagingUtils.cpp @@ -89,6 +89,13 @@ void copy_staging_to_ptr( memcpy_from_mapping(mapping, dst, nbytes, staging.dtype()); } +void copy_zeros_to_staging(api::StorageBuffer& staging, const size_t nbytes) { + void* data = malloc(nbytes); + memset(data, 0, nbytes); + copy_ptr_to_staging(data, staging, nbytes); + free(data); +} + api::ShaderInfo get_nchw_to_image_shader(const vTensor& v_dst) { if (v_dst.is_quantized()) { VK_THROW("Quantized Tensors are currently not supported!"); diff --git a/backends/vulkan/runtime/graph/ops/utils/StagingUtils.h b/backends/vulkan/runtime/graph/ops/utils/StagingUtils.h index 0634d8d02e7..4ef25656904 100644 --- a/backends/vulkan/runtime/graph/ops/utils/StagingUtils.h +++ b/backends/vulkan/runtime/graph/ops/utils/StagingUtils.h @@ -25,6 +25,8 @@ void copy_staging_to_ptr( void* dst, const size_t nbytes); +void copy_zeros_to_staging(api::StorageBuffer& staging, const size_t nbytes); + // // Functions to get shaders //