Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 84 additions & 4 deletions backends/apple/mps/operators/indexing_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
MPSIndexPut,
MPSIndexSelect,
MPSIndexTensor,
MPSScatter,
)
from executorch.backends.apple.mps.utils.mps_utils import get_input_node
from executorch.backends.transforms import get_shape
Expand Down Expand Up @@ -65,12 +66,9 @@ def define_node(
mps_graph.mps_nodes.append(mps_node)


# [MPS TODO]: Works on a single iteration of llama2, but subsequent tokens
# are wrong when using Index put. Disabling it for now.
@register_node_visitor
class IndexPutVisitor(NodeVisitor):
# target = "aten.index_put.default"
target = "disabled"
target = "aten.index_put.default"

def __init__(self, *args) -> None:
super().__init__(*args)
Expand Down Expand Up @@ -115,6 +113,88 @@ def define_node(
mps_graph.mps_nodes.append(mps_node)


@register_node_visitor
class SliceScatterVisitor(NodeVisitor):
    """Lowers ``aten.slice_scatter.default`` to an ``MPSScatter`` node.

    ``slice_scatter(input, src, dim, start, end, step)`` writes ``src`` into
    the slice of ``input`` selected by ``(dim, start, end, step)``.  The MPS
    runtime implements this with a scatter-along-axis, so this visitor
    materializes the slice positions as a constant int64 index tensor
    expanded to the shape of ``src`` and serializes it into the graph.
    """

    target = "aten.slice_scatter.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)
        # aten encodes an unbounded `end` as INT64_MAX; treat that sentinel
        # as "slice to the end of the dimension".
        self.invalid_val = 2**63 - 1

    def maybe_wrap_dim(self, dim: int, n: int) -> int:
        """Normalize a possibly-negative dim index, clamping it to [0, n]."""
        if dim < 0:
            wrapped_dim = dim + n
            if wrapped_dim < 0:
                wrapped_dim = 0
            return wrapped_dim
        elif dim > n:
            return n
        return dim

    # NOTE: renamed from the typo'd `get_exapnded_index` (flagged in review).
    def get_expanded_index(self, idx, shape, dim):
        """Expand 1-D index tensor `idx` to `shape`, varying only along `dim`.

        A 0-dim index is broadcast to the full `shape` directly.
        """
        if idx.dim() == 0:
            return idx.expand(shape)

        dim = self.maybe_wrap_dim(dim, len(shape))

        # Reshape idx to [1, ..., idx_size, ..., 1] so it can be expanded
        # (broadcast) along every dimension except `dim`.
        idx_size = idx.size(0)
        new_index_shape = [1] * len(shape)
        new_index_shape[dim] = idx_size

        # Expand to the input shape with `dim` replaced by idx_size.
        index = idx.view(new_index_shape)
        new_index_shape = list(shape)
        new_index_shape[dim] = idx_size
        index = index.expand(new_index_shape)

        return index

    def get_slice_scatter_indices(
        self, dim, start, end, step, input_shape, dtype=torch.int64
    ):
        """Build the index tensor covering positions [start, end) with `step`."""
        idx = torch.arange(start, end, step, dtype=dtype)
        return self.get_expanded_index(idx, input_shape, dim)

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """Serialize a slice_scatter node into `mps_graph`.

        args layout: (input, src[, dim[, start[, end[, step]]]]); missing
        trailing args fall back to dim=0 (schema default), start=0,
        end=dim length, step=1.
        """
        mps_node = self.create_unary_node(node, mps_graph, MPSScatter)

        start = None
        end = None
        step = 1

        mps_node.mpsnode_union.src_id = self.define_tensor(
            get_input_node(node, 1), mps_graph
        )
        if len(node.args) >= 3:
            mps_node.mpsnode_union.dim = cast(int, node.args[2])
        if len(node.args) >= 4:
            start = cast(int, node.args[3])
        if len(node.args) >= 5 and node.args[4] != self.invalid_val:
            end = cast(int, node.args[4])
        if len(node.args) >= 6:
            step = cast(int, node.args[5])

        input_shape = get_shape(get_input_node(node, 0))
        dim_len = input_shape[
            self.maybe_wrap_dim(mps_node.mpsnode_union.dim, len(input_shape))
        ]

        start_val = start if start is not None else 0
        end_val = end if end is not None else dim_len

        scatter_indices = self.get_slice_scatter_indices(
            mps_node.mpsnode_union.dim, start_val, end_val, step, input_shape
        )
        mps_node.mpsnode_union.idx_id = self.define_constant(scatter_indices, mps_graph)
        mps_graph.mps_nodes.append(mps_node)


@register_node_visitor
class EmbeddingVisitor(NodeVisitor):
target = "aten.embedding.default"
Expand Down
33 changes: 33 additions & 0 deletions backends/apple/mps/operators/node_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,38 @@ def define_tensor_list(self, node: torch.fx.Node, mps_graph: MPSGraph) -> List[i
mps_graph.mps_values.append(mps_tensor)
return self.tensor_to_id[node]

def define_constant(
    self,
    constant_tensor: torch.Tensor,
    mps_graph: MPSGraph,
) -> int:
    """Serializes a constant tensor into the MPSGraph serialization schema.

    Args:
        constant_tensor (torch.Tensor): tensor holding the constant data
            to embed into the graph's flatbuffer
        mps_graph (MPSGraph): MPSGraph object for serializing into flatbuffer

    Returns:
        int: id of the serialized tensor inside ``mps_graph.mps_values``
    """
    constant_tensor = constant_tensor.contiguous()
    # MPS TODO: cache these values
    # (tensor keys hash by identity, so repeated calls with equal data
    # still allocate new ids.)
    tensor_id = len(mps_graph.mps_values)
    self.tensor_to_id[constant_tensor] = tensor_id
    mps_data_type = edge_dtype_to_mps_dtype(constant_tensor.dtype)
    constant_buffer_size, constant_buffer, mps_data_type = self.get_serialized_data(
        constant_tensor, mps_graph, mps_data_type, tensor_id
    )
    dims = list(constant_tensor.shape)

    mps_tensor = MPSTensor(
        datatype=mps_data_type,
        num_dims=len(dims),
        dims=dims,
        constant_buffer_size=constant_buffer_size,
        constant_buffer=constant_buffer,
    )

    mps_graph.mps_values.append(mps_tensor)
    return tensor_id

def define_scalar(
self,
val: Union[float, int],
Expand All @@ -157,6 +189,7 @@ def define_scalar(
"""
assert isinstance(val, int) or isinstance(val, float)

# MPS TODO: cache these values
id = len(mps_graph.mps_values)
self.tensor_to_id[val] = id

Expand Down
1 change: 1 addition & 0 deletions backends/apple/mps/runtime/MPSGraphBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ class MPSGraphBuilder {
_DEFINE_MPS_OP(Embedding);
_DEFINE_MPS_OP(IndexTensor);
_DEFINE_MPS_OP(IndexPut);
_DEFINE_MPS_OP(Scatter);
// Linear algebra ops
_DEFINE_MPS_OP(MatMul);
_DEFINE_MPS_OP(Addmm);
Expand Down
24 changes: 24 additions & 0 deletions backends/apple/mps/runtime/operations/IndexingOps.mm
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,30 @@
return err;
}

// Builds the runtime graph node for a serialized MPSScatter op.
// Writes the values of the `src_id` tensor into the `input1_id` tensor at
// the positions held by the precomputed `idx_id` index tensor, along `dim`
// (MPSGraphScatterModeSet: indexed elements are overwritten, not accumulated).
Error
MPSGraphBuilder::mpsScatterOp(NodePtr nodePtr) {
  auto graphNode = nodePtr->mpsnode_union_as_MPSScatter();
  ET_LOG(
    Debug, "%s %d: %d",
    __FUNCTION__, graphNode->input1_id(), graphNode->output_id()
  );

  int64_t dim = graphNode->dim();
  MPSGraphTensor* inputTensor = getMPSGraphTensor(graphNode->input1_id());
  MPSGraphTensor* indicesTensor = getMPSGraphTensor(graphNode->idx_id());
  MPSGraphTensor* updatesTensor = getMPSGraphTensor(graphNode->src_id());

  // Result tensor is registered under the node's output id so later
  // nodes can look it up by id.
  _idToMPSGraphTensor[graphNode->output_id()] =
    [_mpsGraph scatterAlongAxis:dim
           withDataTensor:inputTensor
            updatesTensor:updatesTensor
            indicesTensor:indicesTensor
                     mode:MPSGraphScatterModeSet
                     name:nil];
  return Error::Ok;
}


} // namespace delegate
} // namespace mps
} // namespace executor
Expand Down
1 change: 1 addition & 0 deletions backends/apple/mps/runtime/operations/OperationUtils.mm
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@
_DEFINE_MPS_NODE(Embedding);
_DEFINE_MPS_NODE(IndexTensor);
_DEFINE_MPS_NODE(IndexPut);
_DEFINE_MPS_NODE(Scatter);
// Reduce ops
_DEFINE_MPS_NODE(Mean);
// Shape ops
Expand Down
8 changes: 8 additions & 0 deletions backends/apple/mps/serialization/mps_graph_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,13 @@ class MPSIndexPut(MPSNode1x1):
values_id: int = -1


@dataclass
class MPSScatter(MPSNode1x1):
    # Serialization record for a scatter op (used to lower
    # aten.slice_scatter); input/output ids come from MPSNode1x1.
    # Axis along which the scatter writes.
    dim: int = 0
    # id of the constant index tensor holding the target positions.
    idx_id: int = -1
    # id of the source (updates) tensor whose values are written.
    src_id: int = -1


##
## Shape ops
##
Expand Down Expand Up @@ -703,6 +710,7 @@ class MPSArange:
MPSEmbedding,
MPSIndexTensor,
MPSIndexPut,
MPSScatter,
# Shape ops
MPSPermute,
MPSView,
Expand Down
9 changes: 9 additions & 0 deletions backends/apple/mps/serialization/schema.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,14 @@ table MPSIndexPut {
output_id:int;
}

// Scatter op (used to lower aten.slice_scatter): writes the values of the
// `src_id` tensor into `input1_id` at the positions held by the `idx_id`
// index tensor, along dimension `dim`; the result is `output_id`.
table MPSScatter {
  input1_id:int;
  output_id:int;
  dim:long;
  idx_id:int;
  src_id:int;
}

// Shape ops.
table MPSPermute {
input1_id:int;
Expand Down Expand Up @@ -390,6 +398,7 @@ union MPSNodeUnion {
MPSEmbedding,
MPSIndexTensor,
MPSIndexPut,
MPSScatter,

// Reduce ops
MPSMean,
Expand Down
41 changes: 40 additions & 1 deletion backends/apple/mps/test/test_mps_indexing_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,6 @@ def forward(self, x):
# )

def test_mps_indexing_put_1(self):

class IndexPut(torch.nn.Module):
def __init__(self):
super().__init__()
Expand All @@ -223,3 +222,43 @@ def forward(self, x, y, z):
self.lower_and_test_with_partitioner(
module, model_inputs, func_name=inspect.stack()[0].function[5:]
)

def test_mps_indexing_slice_scatter_1(self):
    """slice_scatter along the default dim 0 with only `start` given:
    rows 6..7 of an 8x8 zero tensor are overwritten with ones."""

    class SliceScatterModule(torch.nn.Module):
        def forward(self, x, y):
            return x.slice_scatter(y, start=6)

    model_inputs = (torch.zeros(8, 8), torch.ones(2, 8))

    self.lower_and_test_with_partitioner(
        SliceScatterModule(),
        model_inputs,
        func_name=inspect.stack()[0].function[5:],
    )

def test_mps_indexing_slice_scatter_2(self):
    """slice_scatter along dim 1 with start/end/step all specified:
    columns 2 and 4 of an 8x8 zero tensor are overwritten with ones."""

    class SliceScatterModule(torch.nn.Module):
        def forward(self, x, y):
            return x.slice_scatter(y, dim=1, start=2, end=6, step=2)

    model_inputs = (torch.zeros(8, 8), torch.ones(8, 2))

    self.lower_and_test_with_partitioner(
        SliceScatterModule(),
        model_inputs,
        func_name=inspect.stack()[0].function[5:],
    )