From b7444833990de210b882b4bcbec28f264b5a8501 Mon Sep 17 00:00:00 2001 From: Justin Yip Date: Mon, 29 Apr 2024 13:30:37 -0700 Subject: [PATCH 1/3] [ET-VK][14/n] Add operators to Partitioner 1. Register aten operators in the vulkan partitioner. 2. Fix some minor operators name issue due to mismatch between the torch api and actual aten name Note: Permute is not yet registered due to tensor movement issues with the "Partial" model where the `Linear` operator is decomposed into `permute` and `addmm`. Will fix in later diffs. Differential Revision: [D56695929](https://our.internmc.facebook.com/intern/diff/D56695929/) [ghstack-poisoned] --- .../vulkan/partitioner/vulkan_partitioner.py | 6 + .../vulkan/runtime/graph/ops/impl/Split.cpp | 5 +- .../serialization/vulkan_graph_builder.py | 11 +- backends/vulkan/test/test_vulkan_delegate.py | 124 ++++++++++++++++++ 4 files changed, 138 insertions(+), 8 deletions(-) diff --git a/backends/vulkan/partitioner/vulkan_partitioner.py b/backends/vulkan/partitioner/vulkan_partitioner.py index e9ec9f2d84c..0e171ddd582 100644 --- a/backends/vulkan/partitioner/vulkan_partitioner.py +++ b/backends/vulkan/partitioner/vulkan_partitioner.py @@ -56,6 +56,12 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: exir_ops.edge.aten.select_copy.int, exir_ops.edge.aten.unsqueeze_copy.default, exir_ops.edge.aten.view_copy.default, + # Copy-releated operators + exir_ops.edge.aten.clone.default, + exir_ops.edge.aten.cat.default, + exir_ops.edge.aten.split_with_sizes_copy.default, + exir_ops.edge.aten.split.Tensor, + exir_ops.edge.aten.slice_copy.Tensor, # Other operator.getitem, exir_ops.edge.aten.full.default, diff --git a/backends/vulkan/runtime/graph/ops/impl/Split.cpp b/backends/vulkan/runtime/graph/ops/impl/Split.cpp index 3b40871a791..2d218f722a2 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Split.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Split.cpp @@ -106,7 +106,7 @@ void add_split_with_sizes_default_node( add_split_with_sizes_default_node(graph, in, split_sizes, dim, out); } -void split_with_sizes_default( +void split_with_sizes_copy_default( ComputeGraph& graph, const std::vector& args) { add_split_with_sizes_default_node(graph, args[0], args[1], args[2], args[3]); @@ -134,7 +134,8 @@ void split_tensor(ComputeGraph& graph, const std::vector& args) { } REGISTER_OPERATORS { - VK_REGISTER_OP(aten.split_with_sizes.default, split_with_sizes_default); + VK_REGISTER_OP( + aten.split_with_sizes_copy.default, split_with_sizes_copy_default); VK_REGISTER_OP(aten.split.Tensor, split_tensor); } diff --git a/backends/vulkan/serialization/vulkan_graph_builder.py b/backends/vulkan/serialization/vulkan_graph_builder.py index 9c12cb4a010..686b10ce8ab 100644 --- a/backends/vulkan/serialization/vulkan_graph_builder.py +++ b/backends/vulkan/serialization/vulkan_graph_builder.py @@ -6,7 +6,7 @@ import operator from types import NoneType -from typing import cast, List, Optional, Union +from typing import cast, List, Optional, Tuple, Union import executorch.backends.vulkan.serialization.vulkan_graph_schema as vk_graph_schema @@ -133,9 +133,9 @@ def create_node_value(self, node: Node) -> int: new_id = self.create_tensor_value(spec, constant_id) self.node_to_value_ids[node] = new_id return new_id - elif isinstance(spec, tuple): - # Create a Value for each element in the tuple, wrap Values in a - # ValueList, and map the Node to the ValueList id. + elif isinstance(spec, list) or isinstance(spec, tuple): + # pyre-ignore[6]: pyre having hard time to infer Node type inside + # the container. new_id = self.create_value_list_value(spec) self.node_to_value_ids[node] = new_id return new_id @@ -202,7 +202,7 @@ def create_scalar_list_value(self, arg: List[_ScalarType]) -> int: ) return new_id - def create_value_list_value(self, arg: List[Node] | tuple) -> int: + def create_value_list_value(self, arg: tuple | list) -> int: self.values.append( vk_graph_schema.VkValue( vk_graph_schema.ValueList( @@ -242,7 +242,6 @@ def get_or_create_value_for(self, arg: _Argument): # pyre-ignore[6] return self.create_scalar_list_value(arg) elif isinstance(arg, list) and isinstance(arg[0], Node): - # pyre-ignore[6] return self.create_value_list_value(arg) elif isinstance(arg, str): return self.create_string_value(arg) diff --git a/backends/vulkan/test/test_vulkan_delegate.py b/backends/vulkan/test/test_vulkan_delegate.py index a458fc1c24e..07c7de0c0c5 100644 --- a/backends/vulkan/test/test_vulkan_delegate.py +++ b/backends/vulkan/test/test_vulkan_delegate.py @@ -80,6 +80,10 @@ def run_test(memory_layout): compile_options = { "memory_layout_override": memory_layout, } + + # At least model should run in eager mode. + model(*sample_inputs) + program: ExportedProgram = export( model, sample_inputs, dynamic_shapes=dynamic_shapes ) @@ -798,3 +802,123 @@ def forward(self, x): sample_inputs, memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED], ) + + def DISABLED_test_vulkan_backend_permute_copy(self): + # aten.permute_copy.default is not enabled yet in partitioner + class PermuteModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.permute(x, [3, 0, 2, 1]) + + sample_inputs = (torch.randn(size=(3, 6, 2, 7), dtype=torch.float32),) + + self.lower_module_and_test_output( + PermuteModule(), + sample_inputs, + memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED], + ) + + def test_vulkan_backend_cat(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y, z, w): + return torch.cat([x, y, z, w], dim=1) + + sample_inputs = ( + torch.randn(size=(3, 6, 2, 7), dtype=torch.float32), + torch.randn(size=(3, 1, 2, 7), dtype=torch.float32), + torch.randn(size=(3, 9, 2, 7), dtype=torch.float32), + torch.randn(size=(3, 3, 2, 7), dtype=torch.float32), + ) + + self.lower_module_and_test_output( + TestModule(), + sample_inputs, + memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED], + ) + + def test_vulkan_backend_slice(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return x[:, 2:9:2, :] + + sample_inputs = (torch.randn(size=(3, 13, 7, 3), dtype=torch.float32),) + + self.lower_module_and_test_output( + TestModule(), + sample_inputs, + memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED], + ) + + def test_vulkan_backend_split_with_sizes(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.split(x, (3, 6, 1, 3), dim=1) + + sample_inputs = (torch.randn(size=(3, 13, 7, 3), dtype=torch.float32),) + + self.lower_module_and_test_output( + TestModule(), + sample_inputs, + memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED], + ) + + def test_vulkan_backend_split_tensor(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.tensor_split(x, 2, dim=1) + + sample_inputs = (torch.randn(size=(3, 14, 7, 3), dtype=torch.float32),) + + self.lower_module_and_test_output( + TestModule(), + sample_inputs, + memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED], + ) + + def test_vulkan_backend_clone(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.clone(x) + + sample_inputs = (torch.randn(size=(3, 14, 7, 3), dtype=torch.float32),) + + self.lower_module_and_test_output( + TestModule(), + sample_inputs, + memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED], + ) + + def DISABLED_test_vulkan_backend_t_default(self): + # aten.permute_copy.default is not enabled yet in partitioner + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + # torch.t is actually exported as aten::permute. + return torch.t(x) + + sample_inputs = (torch.randn(size=(3, 14), dtype=torch.float32),) + + self.lower_module_and_test_output( + TestModule(), + sample_inputs, + memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED], + ) From fd026ea70897bcdc7c44ecc07bc2ab316c912ae0 Mon Sep 17 00:00:00 2001 From: Justin Yip Date: Mon, 29 Apr 2024 13:48:08 -0700 Subject: [PATCH 2/3] Update on "[ET-VK][14/n] Add operators to Partitioner" 1. Register aten operators in the vulkan partitioner. 2. Fix some minor operators name issue due to mismatch between the torch api and actual aten name Note: Permute is not yet registered due to tensor movement issues with the "Partial" model where the `Linear` operator is decomposed into `permute` and `addmm`. Will fix in later diffs. Differential Revision: [D56695929](https://our.internmc.facebook.com/intern/diff/D56695929/) [ghstack-poisoned] --- backends/vulkan/serialization/vulkan_graph_builder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backends/vulkan/serialization/vulkan_graph_builder.py b/backends/vulkan/serialization/vulkan_graph_builder.py index 686b10ce8ab..4b3c804ceca 100644 --- a/backends/vulkan/serialization/vulkan_graph_builder.py +++ b/backends/vulkan/serialization/vulkan_graph_builder.py @@ -6,7 +6,7 @@ import operator from types import NoneType -from typing import cast, List, Optional, Tuple, Union +from typing import cast, List, Optional, Union import executorch.backends.vulkan.serialization.vulkan_graph_schema as vk_graph_schema From 567cd8cc27d47eb5c4ddced78ed16c622563a992 Mon Sep 17 00:00:00 2001 From: Justin Yip Date: Mon, 29 Apr 2024 14:33:58 -0700 Subject: [PATCH 3/3] Update on "[ET-VK][14/n] Add operators to Partitioner" 1. Register aten operators in the vulkan partitioner. 2. Fix some minor operators name issue due to mismatch between the torch api and actual aten name Note: Permute is not yet registered due to tensor movement issues with the "Partial" model where the `Linear` operator is decomposed into `permute` and `addmm`. Will fix in later diffs. Differential Revision: [D56695929](https://our.internmc.facebook.com/intern/diff/D56695929/) [ghstack-poisoned] --- backends/vulkan/test/op_tests/cases.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py index a326402cc39..329ae02aa95 100644 --- a/backends/vulkan/test/op_tests/cases.py +++ b/backends/vulkan/test/op_tests/cases.py @@ -575,6 +575,6 @@ def get_split_tensor_inputs(): "aten.clone.default": get_clone_inputs(), "aten.repeat.default": get_repeat_inputs(), "aten.cat.default": get_cat_inputs(), - "aten.split_with_sizes.default": get_split_with_sizes_inputs(), + "aten.split_with_sizes_copy.default": get_split_with_sizes_inputs(), "aten.split.Tensor": get_split_tensor_inputs(), }