Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
lintrunner -a
  • Loading branch information
George Pawelczak committed Dec 18, 2023
commit 2f3dda4a83512a71fe02d112d9fbb4693f0596eb
35 changes: 24 additions & 11 deletions backends/apple/mps/mps_preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ def get_param_from_node(
return None


def create_mpsgraph_constant_tensor(tensor: torch.Tensor, mpsGraph, convert_model_to_fp16: bool):
def create_mpsgraph_constant_tensor(
tensor: torch.Tensor, mpsGraph, convert_model_to_fp16: bool
):
dtype = get_mps_data_type(tensor.dtype)
if convert_model_to_fp16 and dtype == get_mps_data_type(torch.float32):
tensor = tensor.half()
Expand Down Expand Up @@ -178,19 +180,24 @@ def _apply_to_structure(value, func):
def get_node(self, key, cast_from_type=None, cast_to_type=None):
value = dict.__getitem__(self, key)
if cast_from_type:
assert(cast_to_type)
assert cast_to_type

def handle(value):
current_data_type = mpsGraph.get_data_type(value)
if current_data_type == get_mps_data_type(cast_from_type):
value = mpsGraph.cast_tensor(value, get_mps_data_type(cast_to_type))
value = mpsGraph.cast_tensor(
value, get_mps_data_type(cast_to_type)
)
return value

value = GraphNodesDict._apply_to_structure(value, handle)
return value

def __getitem__(self, key):
if self._convert_model_to_fp16:
return self.get_node(key, cast_from_type=torch.float32, cast_to_type=torch.float16)
return self.get_node(
key, cast_from_type=torch.float32, cast_to_type=torch.float16
)
return self.get_node(key)

def __setitem__(self, key, value):
Expand All @@ -212,7 +219,8 @@ def __repr__(self):
graphNodes[node.name] = create_mpsgraph_constant_tensor(
tensor=attr,
mpsGraph=mpsGraph,
convert_model_to_fp16=convert_model_to_fp16)
convert_model_to_fp16=convert_model_to_fp16,
)

# Handle inputs to the graph.
elif node.op == "placeholder":
Expand All @@ -223,7 +231,8 @@ def __repr__(self):
graphNodes[node.name] = create_mpsgraph_constant_tensor(
tensor=lifted_param_or_buffer,
mpsGraph=mpsGraph,
convert_model_to_fp16=convert_model_to_fp16)
convert_model_to_fp16=convert_model_to_fp16,
)
else:
if node.meta["val"] is None:
continue
Expand Down Expand Up @@ -818,15 +827,19 @@ def __repr__(self):
# Handle output nodes in the graph.
elif node.op == "output":
output_nodes = []
assert(isinstance(node.meta["val"], (tuple, list)))
assert(len(node.args) == 1)
assert(len(node.meta["val"]) == len(node.args[0]))
assert isinstance(node.meta["val"], (tuple, list))
assert len(node.args) == 1
assert len(node.meta["val"]) == len(node.args[0])
for i in range(len(node.args[0])):
cast_kwargs = {}
if get_mps_data_type(node.meta["val"][i].dtype) == get_mps_data_type(torch.float32):
if get_mps_data_type(
node.meta["val"][i].dtype
) == get_mps_data_type(torch.float32):
cast_kwargs["cast_from_type"] = torch.float16
cast_kwargs["cast_to_type"] = torch.float32
output_nodes.append(graphNodes.get_node(node.args[0][i].name, **cast_kwargs))
output_nodes.append(
graphNodes.get_node(node.args[0][i].name, **cast_kwargs)
)
mpsGraph.set_outputs(*output_nodes)
else:
torch._assert(
Expand Down
35 changes: 29 additions & 6 deletions backends/apple/mps/test/test_mps.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,9 @@ def run_model(
# Step 3: Lower to MPSGraph
logging.info("Step 3: Lowering to MPSGraph...")
compile_specs = [CompileSpec("use_fp16", bytes([use_fp16]))]
lowered_module = to_backend(MPSBackend.__name__, edge.exported_program, compile_specs)
lowered_module = to_backend(
MPSBackend.__name__, edge.exported_program, compile_specs
)

logging.info("Step 4: Capturing executorch program with lowered module...")

Expand Down Expand Up @@ -468,7 +470,14 @@ def test_conv1d(self):
example_inputs = (torch.randn(1, 57, 40),)
stride = random.randint(1, 4)
padding = random.randint(1, 4)
conv = torch.nn.Conv1d(57, 20, stride=stride, padding=padding, kernel_size=3, bias=random.choice([True, False]))
conv = torch.nn.Conv1d(
57,
20,
stride=stride,
padding=padding,
kernel_size=3,
bias=random.choice([True, False]),
)
conv.eval()
self.lower_and_test_with_partitioner(
conv, example_inputs, func_name=inspect.stack()[0].function[5:]
Expand All @@ -484,9 +493,16 @@ def test_conv2d(self):
weight_memory_format = torch.contiguous_format
strideX = random.randint(1, 4)
strideY = random.randint(1, 4)
example_inputs = (torch.randn(N, C, H, W).to(memory_format=input_memory_format), )
example_inputs = (
torch.randn(N, C, H, W).to(memory_format=input_memory_format),
)
conv = torch.nn.Conv2d(
in_channels=N, out_channels=C, kernel_size=H, groups=groups, stride=(strideX, strideY))
in_channels=N,
out_channels=C,
kernel_size=H,
groups=groups,
stride=(strideX, strideY),
)
conv.weight.data = conv.weight.to(memory_format=weight_memory_format)
conv.eval()
self.lower_and_test_with_partitioner(
Expand All @@ -503,9 +519,16 @@ def test_conv2d_to_depthwise_conv_3d(self):
weight_memory_format = torch.contiguous_format
strideX = random.randint(1, 4)
strideY = random.randint(1, 4)
example_inputs = (torch.randn(N, C, H, W).to(memory_format=input_memory_format), )
example_inputs = (
torch.randn(N, C, H, W).to(memory_format=input_memory_format),
)
conv = torch.nn.Conv2d(
in_channels=N, out_channels=C, kernel_size=H, groups=groups, stride=(strideX, strideY))
in_channels=N,
out_channels=C,
kernel_size=H,
groups=groups,
stride=(strideX, strideY),
)
conv.weight.data = conv.weight.to(memory_format=weight_memory_format)
conv.eval()
self.lower_and_test_with_partitioner(
Expand Down
Loading