From 79943cade63cab039e101c915cb478b977627255 Mon Sep 17 00:00:00 2001 From: Wei Lu Date: Tue, 9 Apr 2024 17:43:40 -0700 Subject: [PATCH] add aten.sum.default (#2807) Summary: The operator `aten.sum.dim_IntList` could take an empty list as the parameter for `dims`. We modify `vulkan_graph_builder.py` to accommodate the empty list. Moreover, the op `aten.sum.default` is implemented as a [decomposition](https://www.internalfb.com/code/fbsource/[96e496f9db8f92967b4394bd4f60e39ab916740b]/xplat/caffe2/torch/_decomp/decompositions.py?lines=4676) into `aten.sum.dim_IntList` with empty `dims`. So we will support `aten.sum.default` with the changes. Context: `torch.sum(x, ())` and `torch.sum(x)` are two ways to compute the sum of all elements in tensor `x`. Reviewed By: SS-JIA, jorgep31415 Differential Revision: D55630993 --- .../vulkan/runtime/graph/ops/impl/Sum.cpp | 15 +++++++++++---- .../serialization/vulkan_graph_builder.py | 10 ++++++++-- backends/vulkan/test/test_vulkan_delegate.py | 19 +++++++++++++++++++ 3 files changed, 38 insertions(+), 6 deletions(-) diff --git a/backends/vulkan/runtime/graph/ops/impl/Sum.cpp b/backends/vulkan/runtime/graph/ops/impl/Sum.cpp index ff235a2357f..da9347d2714 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Sum.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Sum.cpp @@ -120,10 +120,17 @@ void add_sum_dim_IntList( const auto& dims_to_sum = graph.get_val(opt_dim).toIntList(); int64_t in_dim = in_tensor.sizes().size(); - for (const auto& dim : dims_to_sum) { - // Normalize (negative) dim into range [0, self.dim() - 1] - int64_t dim_normalized = normalize(dim, in_dim); - dims_set.insert(dim_normalized); + if (dims_to_sum.empty()) { + // If dim is not specified, reduce over all dims + for (int64_t i = 0; i < in_dim; ++i) { + dims_set.insert(i); + } + } else { + for (const auto& dim : dims_to_sum) { + // Normalize (negative) dim into range [0, self.dim() - 1] + int64_t dim_normalized = normalize(dim, in_dim); + dims_set.insert(dim_normalized); + } } // Reduce the higher dimensionalities first, otherwise when keepdim is diff --git a/backends/vulkan/serialization/vulkan_graph_builder.py b/backends/vulkan/serialization/vulkan_graph_builder.py index 3f7e473c27c..b6e8df7466e 100644 --- a/backends/vulkan/serialization/vulkan_graph_builder.py +++ b/backends/vulkan/serialization/vulkan_graph_builder.py @@ -178,7 +178,11 @@ def create_tensor_value(self, spec: TensorSpec, constant_id: int = -1) -> int: def create_scalar_list_value(self, arg: List[_ScalarType]) -> int: new_id = len(self.values) - if isinstance(arg[0], bool): + if len(arg) == 0: + self.values.append( + vk_graph_schema.VkValue(vk_graph_schema.IntList(items=[])) + ) + elif isinstance(arg[0], bool): self.values.append( vk_graph_schema.VkValue( vk_graph_schema.BoolList(items=[cast(bool, e) for e in arg]) @@ -227,7 +231,9 @@ def get_or_create_value_for(self, arg: _Argument): return self.create_scalar_value(arg) elif isinstance(arg, TensorSpec): return self.create_tensor_value(arg) - elif isinstance(arg, list) and isinstance(arg[0], _ScalarType): + elif isinstance(arg, list) and ( + len(arg) == 0 or isinstance(arg[0], _ScalarType) + ): # pyre-ignore[6] return self.create_scalar_list_value(arg) elif isinstance(arg, list) and isinstance(arg[0], Node): diff --git a/backends/vulkan/test/test_vulkan_delegate.py b/backends/vulkan/test/test_vulkan_delegate.py index 8ba695524cd..97a28eb6002 100644 --- a/backends/vulkan/test/test_vulkan_delegate.py +++ b/backends/vulkan/test/test_vulkan_delegate.py @@ -497,6 +497,25 @@ def forward(self, x): memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED], ) + def test_vulkan_backend_sum(self): + class SumModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + x = torch.sum(x, (), keepdim=True) + x = torch.sum(x) + return x + + module = SumModule() + sample_inputs = (torch.rand(size=(3, 2, 7, 5), dtype=torch.float32),) + + self.lower_module_and_test_output( + module, + sample_inputs, + memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED], + ) + def test_vulkan_backend_conv2d(self): class Conv2dModule(torch.nn.Module): def __init__(self):