From ae4940c5a168a770363d3591d038f282dbe19d6f Mon Sep 17 00:00:00 2001 From: Denis Vieriu Date: Mon, 29 Apr 2024 10:37:23 -0700 Subject: [PATCH] Add support for scatter_slice --- backends/apple/mps/operators/indexing_ops.py | 88 ++++++++++++++++++- backends/apple/mps/operators/node_visitor.py | 33 +++++++ backends/apple/mps/runtime/MPSGraphBuilder.h | 1 + .../mps/runtime/operations/IndexingOps.mm | 24 +++++ .../mps/runtime/operations/OperationUtils.mm | 1 + .../mps/serialization/mps_graph_schema.py | 8 ++ backends/apple/mps/serialization/schema.fbs | 9 ++ .../apple/mps/test/test_mps_indexing_ops.py | 41 ++++++++- 8 files changed, 200 insertions(+), 5 deletions(-) diff --git a/backends/apple/mps/operators/indexing_ops.py b/backends/apple/mps/operators/indexing_ops.py index 690549973a4..02506e11823 100644 --- a/backends/apple/mps/operators/indexing_ops.py +++ b/backends/apple/mps/operators/indexing_ops.py @@ -16,6 +16,7 @@ MPSIndexPut, MPSIndexSelect, MPSIndexTensor, + MPSScatter, ) from executorch.backends.apple.mps.utils.mps_utils import get_input_node from executorch.backends.transforms import get_shape @@ -65,12 +66,9 @@ def define_node( mps_graph.mps_nodes.append(mps_node) -# [MPS TODO]: Works on a single iteration of llama2, but subsequent tokens -# are wrong when using Index put. Disabling it for now. @register_node_visitor class IndexPutVisitor(NodeVisitor): - # target = "aten.index_put.default" - target = "disabled" + target = "aten.index_put.default" def __init__(self, *args) -> None: super().__init__(*args) @@ -115,6 +113,88 @@ def define_node( mps_graph.mps_nodes.append(mps_node) +@register_node_visitor +class SliceScatterVisitor(NodeVisitor): + target = "aten.slice_scatter.default" + + def __init__(self, *args) -> None: + super().__init__(*args) + self.invalid_val = 2**63 - 1 + + def maybe_wrap_dim(self, dim: int, n: int) -> List[int]: + if dim < 0: + wrapped_dim = dim + n + if wrapped_dim < 0: + wrapped_dim = 0 + return wrapped_dim + elif dim > n: + return n + return dim + + def get_exapnded_index(self, idx, shape, dim): + if idx.dim() == 0: + return idx.expand(shape) + + dim = self.maybe_wrap_dim(dim, len(shape)) + + # setup new_index_shape as [BS, 1, ..., idx_size, ..., 1] + # to reshape index_ + idx_size = idx.size(0) + new_index_shape = [1] * len(shape) + new_index_shape[dim] = idx_size + + # Now apply expand to index_ + index = idx.view(new_index_shape) + new_index_shape = list(shape) + new_index_shape[dim] = idx_size + index = index.expand(new_index_shape) + + return index + + def get_slice_scatter_indices( + self, dim, start, end, step, input_shape, dtype=torch.int64 + ): + idx = torch.arange(start, end, step, dtype=dtype) + return self.get_exapnded_index(idx, input_shape, dim) + + def define_node( + self, + node: torch.fx.Node, + mps_graph: MPSGraph, + ) -> None: + mps_node = self.create_unary_node(node, mps_graph, MPSScatter) + + start = None + end = None + step = 1 + + mps_node.mpsnode_union.src_id = self.define_tensor( + get_input_node(node, 1), mps_graph + ) + if len(node.args) >= 3: + mps_node.mpsnode_union.dim = cast(int, node.args[2]) + if len(node.args) >= 4: + start = cast(int, node.args[3]) + if len(node.args) >= 5 and node.args[4] != self.invalid_val: + end = cast(int, node.args[4]) + if len(node.args) >= 6: + step = cast(int, node.args[5]) + + input_shape = get_shape(get_input_node(node, 0)) + dim_len = input_shape[ + self.maybe_wrap_dim(mps_node.mpsnode_union.dim, len(input_shape)) + ] + + start_val = start if start is not None else 0 + end_val = end if end is not None else dim_len + + scatter_indices = self.get_slice_scatter_indices( + mps_node.mpsnode_union.dim, start_val, end_val, step, input_shape + ) + mps_node.mpsnode_union.idx_id = self.define_constant(scatter_indices, mps_graph) + mps_graph.mps_nodes.append(mps_node) + + @register_node_visitor class EmbeddingVisitor(NodeVisitor): target = "aten.embedding.default" diff --git a/backends/apple/mps/operators/node_visitor.py b/backends/apple/mps/operators/node_visitor.py index e9f879db88a..0b9b2d5512c 100644 --- a/backends/apple/mps/operators/node_visitor.py +++ b/backends/apple/mps/operators/node_visitor.py @@ -143,6 +143,38 @@ def define_tensor_list(self, node: torch.fx.Node, mps_graph: MPSGraph) -> List[i mps_graph.mps_values.append(mps_tensor) return self.tensor_to_id[node] + def define_constant( + self, + constant_tensor: torch.tensor, + mps_graph: MPSGraph, + ): + """Defines a scalar value into the MPSGraph serialization schema + + Args: + tensor (torch.fx.Node): EdgeIR tensor to define into mps_graph + mps_graph (MPSGraph): MPSGraph object for serializing into flatbuffer + """ + constant_tensor = constant_tensor.contiguous() + # MPS TODO: cache these values + id = len(mps_graph.mps_values) + self.tensor_to_id[constant_tensor] = id + mps_data_type = edge_dtype_to_mps_dtype(constant_tensor.dtype) + constant_buffer_size, constant_buffer, mps_data_type = self.get_serialized_data( + constant_tensor, mps_graph, mps_data_type, id + ) + dims = list(constant_tensor.shape) + + mps_tensor = MPSTensor( + datatype=mps_data_type, + num_dims=len(dims), + dims=dims, + constant_buffer_size=constant_buffer_size, + constant_buffer=constant_buffer, + ) + + mps_graph.mps_values.append(mps_tensor) + return id + def define_scalar( self, val: Union[float, int], @@ -157,6 +189,7 @@ def define_scalar( """ assert isinstance(val, int) or isinstance(val, float) + # MPS TODO: cache these values id = len(mps_graph.mps_values) self.tensor_to_id[val] = id diff --git a/backends/apple/mps/runtime/MPSGraphBuilder.h b/backends/apple/mps/runtime/MPSGraphBuilder.h index e4e89d68691..29b9471ae9a 100644 --- a/backends/apple/mps/runtime/MPSGraphBuilder.h +++ b/backends/apple/mps/runtime/MPSGraphBuilder.h @@ -123,6 +123,7 @@ class MPSGraphBuilder { _DEFINE_MPS_OP(Embedding); _DEFINE_MPS_OP(IndexTensor); _DEFINE_MPS_OP(IndexPut); + _DEFINE_MPS_OP(Scatter); // Linear algebra ops _DEFINE_MPS_OP(MatMul); _DEFINE_MPS_OP(Addmm); diff --git a/backends/apple/mps/runtime/operations/IndexingOps.mm b/backends/apple/mps/runtime/operations/IndexingOps.mm index b4dcf192b46..6536aa52cf3 100644 --- a/backends/apple/mps/runtime/operations/IndexingOps.mm +++ b/backends/apple/mps/runtime/operations/IndexingOps.mm @@ -204,6 +204,30 @@ return err; } +Error +MPSGraphBuilder::mpsScatterOp(NodePtr nodePtr) { + auto graphNode = nodePtr->mpsnode_union_as_MPSScatter(); + ET_LOG( + Debug, "%s %d: %d", + __FUNCTION__, graphNode->input1_id(), graphNode->output_id() + ); + + int64_t dim = graphNode->dim(); + MPSGraphTensor* inputTensor = getMPSGraphTensor(graphNode->input1_id()); + MPSGraphTensor* indicesTensor = getMPSGraphTensor(graphNode->idx_id()); + MPSGraphTensor* updatesTensor = getMPSGraphTensor(graphNode->src_id()); + + _idToMPSGraphTensor[graphNode->output_id()] = + [_mpsGraph scatterAlongAxis:dim + withDataTensor:inputTensor + updatesTensor:updatesTensor + indicesTensor:indicesTensor + mode:MPSGraphScatterModeSet + name:nil]; + return Error::Ok; +} + + } // namespace delegate } // namespace mps } // namespace executor diff --git a/backends/apple/mps/runtime/operations/OperationUtils.mm b/backends/apple/mps/runtime/operations/OperationUtils.mm index 648421ee2cd..21c4a0d3e7b 100644 --- a/backends/apple/mps/runtime/operations/OperationUtils.mm +++ b/backends/apple/mps/runtime/operations/OperationUtils.mm @@ -181,6 +181,7 @@ _DEFINE_MPS_NODE(Embedding); _DEFINE_MPS_NODE(IndexTensor); _DEFINE_MPS_NODE(IndexPut); + _DEFINE_MPS_NODE(Scatter); // Reduce ops _DEFINE_MPS_NODE(Mean); // Shape ops diff --git a/backends/apple/mps/serialization/mps_graph_schema.py b/backends/apple/mps/serialization/mps_graph_schema.py index 8134091a01d..6909926e8cf 100644 --- a/backends/apple/mps/serialization/mps_graph_schema.py +++ b/backends/apple/mps/serialization/mps_graph_schema.py @@ -456,6 +456,13 @@ class MPSIndexPut(MPSNode1x1): values_id: int = -1 +@dataclass +class MPSScatter(MPSNode1x1): + dim: int = 0 + idx_id: int = -1 + src_id: int = -1 + + ## ## Shape ops ## @@ -703,6 +710,7 @@ class MPSArange: MPSEmbedding, MPSIndexTensor, MPSIndexPut, + MPSScatter, # Shape ops MPSPermute, MPSView, diff --git a/backends/apple/mps/serialization/schema.fbs b/backends/apple/mps/serialization/schema.fbs index 6ba2c937f32..6e089d4526f 100644 --- a/backends/apple/mps/serialization/schema.fbs +++ b/backends/apple/mps/serialization/schema.fbs @@ -166,6 +166,14 @@ table MPSIndexPut { output_id:int; } +table MPSScatter { + input1_id:int; + output_id:int; + dim:long; + idx_id:int; + src_id:int; +} + // Shape ops. table MPSPermute { input1_id:int; @@ -390,6 +398,7 @@ union MPSNodeUnion { MPSEmbedding, MPSIndexTensor, MPSIndexPut, + MPSScatter, // Reduce ops MPSMean, diff --git a/backends/apple/mps/test/test_mps_indexing_ops.py b/backends/apple/mps/test/test_mps_indexing_ops.py index 7991f1a165a..03709fc891a 100644 --- a/backends/apple/mps/test/test_mps_indexing_ops.py +++ b/backends/apple/mps/test/test_mps_indexing_ops.py @@ -201,7 +201,6 @@ def forward(self, x): # ) def test_mps_indexing_put_1(self): - class IndexPut(torch.nn.Module): def __init__(self): super().__init__() @@ -223,3 +222,43 @@ def forward(self, x, y, z): self.lower_and_test_with_partitioner( module, model_inputs, func_name=inspect.stack()[0].function[5:] ) + + def test_mps_indexing_slice_scatter_1(self): + class IndexSliceScatter(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + return x.slice_scatter(y, start=6) + + module = IndexSliceScatter() + input = torch.zeros(8, 8) + src = torch.ones(2, 8) + model_inputs = ( + input, + src, + ) + + self.lower_and_test_with_partitioner( + module, model_inputs, func_name=inspect.stack()[0].function[5:] + ) + + def test_mps_indexing_slice_scatter_2(self): + class IndexSliceScatter(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + return x.slice_scatter(y, dim=1, start=2, end=6, step=2) + + module = IndexSliceScatter() + input = torch.zeros(8, 8) + src = torch.ones(8, 2) + model_inputs = ( + input, + src, + ) + + self.lower_and_test_with_partitioner( + module, model_inputs, func_name=inspect.stack()[0].function[5:] + )