From 0d9f2bde998a2d1835b143e49a4d8b4b2d697f57 Mon Sep 17 00:00:00 2001 From: Max Ren Date: Thu, 11 Apr 2024 14:32:19 -0700 Subject: [PATCH] Dynamic Shapes (#2442) Summary: Pull Request resolved: https://github.com/pytorch/executorch/pull/2442 Only need to look at tester.py file for the tester changes. Change is from `.run_method().compare_outputs() ` to `.run_method_and_compare_outputs()` now if Tester is initialized with dynamic inputs, we will generate random dynamic inputs (according to the specification of the dynamic shapes) to run on the model. This allows us to test that the inputs fed into the model can be dynamic. We ad a num_runs to run_method_and_compare_outputs so that we can choose to run a number of different dynamic inputs with dynamic shapes. Reviewed By: digantdesai, kirklandsign Differential Revision: D54650121 --- backends/xnnpack/test/models/deeplab_v3.py | 3 +- backends/xnnpack/test/models/edsr.py | 6 +- backends/xnnpack/test/models/emformer_rnnt.py | 28 ++-- backends/xnnpack/test/models/inception_v3.py | 6 +- backends/xnnpack/test/models/inception_v4.py | 6 +- .../xnnpack/test/models/llama2_et_example.py | 3 +- backends/xnnpack/test/models/mobilebert.py | 3 +- backends/xnnpack/test/models/mobilenet_v2.py | 6 +- backends/xnnpack/test/models/mobilenet_v3.py | 6 +- backends/xnnpack/test/models/resnet.py | 6 +- .../xnnpack/test/models/torchvision_vit.py | 3 +- .../xnnpack/test/models/very_big_model.py | 3 +- backends/xnnpack/test/models/w2l.py | 6 +- backends/xnnpack/test/ops/abs.py | 3 +- backends/xnnpack/test/ops/add.py | 24 +-- backends/xnnpack/test/ops/avgpool2d.py | 3 +- backends/xnnpack/test/ops/bilinear2d.py | 6 +- backends/xnnpack/test/ops/cat.py | 114 ++++++------- backends/xnnpack/test/ops/ceil.py | 3 +- backends/xnnpack/test/ops/clamp.py | 6 +- backends/xnnpack/test/ops/conv1d.py | 3 +- backends/xnnpack/test/ops/conv2d.py | 3 +- backends/xnnpack/test/ops/div.py | 6 +- backends/xnnpack/test/ops/elu.py | 9 +- backends/xnnpack/test/ops/floor.py | 3 +- backends/xnnpack/test/ops/hardswish.py | 6 +- backends/xnnpack/test/ops/hardtanh.py | 9 +- backends/xnnpack/test/ops/leaky_relu.py | 12 +- backends/xnnpack/test/ops/linear.py | 10 +- backends/xnnpack/test/ops/max_dim.py | 6 +- backends/xnnpack/test/ops/maximum.py | 6 +- backends/xnnpack/test/ops/maxpool2d.py | 6 +- backends/xnnpack/test/ops/mean_dim.py | 6 +- backends/xnnpack/test/ops/minimum.py | 3 +- backends/xnnpack/test/ops/multiply.py | 15 +- backends/xnnpack/test/ops/negate.py | 3 +- backends/xnnpack/test/ops/permute.py | 12 +- backends/xnnpack/test/ops/pow.py | 3 +- backends/xnnpack/test/ops/prelu.py | 3 +- .../xnnpack/test/ops/quantize_per_tensor.py | 6 +- backends/xnnpack/test/ops/relu.py | 3 +- backends/xnnpack/test/ops/sdpa.py | 3 +- backends/xnnpack/test/ops/sigmoid.py | 3 +- backends/xnnpack/test/ops/slice_copy.py | 6 +- backends/xnnpack/test/ops/softmax.py | 3 +- backends/xnnpack/test/ops/sqrt.py | 8 +- backends/xnnpack/test/ops/square.py | 3 +- .../xnnpack/test/ops/static_constant_pad.py | 9 +- backends/xnnpack/test/ops/sub.py | 15 +- .../test/passes/test_batch_norm_fusion.py | 6 +- .../test_channels_last_tagged_reshape.py | 15 +- .../test/passes/test_convert_to_linear.py | 3 +- .../test/passes/test_remove_get_item_pass.py | 12 +- .../passes/test_tag_implicit_q_dq_pass.py | 3 +- backends/xnnpack/test/tester/tester.py | 151 ++++++++++++++---- 55 files changed, 298 insertions(+), 319 deletions(-) diff --git a/backends/xnnpack/test/models/deeplab_v3.py b/backends/xnnpack/test/models/deeplab_v3.py index ccaccb898d2..c5f6bfe17bc 100644 --- a/backends/xnnpack/test/models/deeplab_v3.py +++ b/backends/xnnpack/test/models/deeplab_v3.py @@ -36,6 +36,5 @@ def test_fp32_dl3(self): .partition() .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) diff --git a/backends/xnnpack/test/models/edsr.py b/backends/xnnpack/test/models/edsr.py index d748e35bb74..ca080b20b49 100644 --- a/backends/xnnpack/test/models/edsr.py +++ b/backends/xnnpack/test/models/edsr.py @@ -25,8 +25,7 @@ def test_fp32_edsr(self): .partition() .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_qs8_edsr(self): @@ -38,6 +37,5 @@ def test_qs8_edsr(self): .partition() .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) diff --git a/backends/xnnpack/test/models/emformer_rnnt.py b/backends/xnnpack/test/models/emformer_rnnt.py index 3728c9b07c9..3992c828964 100644 --- a/backends/xnnpack/test/models/emformer_rnnt.py +++ b/backends/xnnpack/test/models/emformer_rnnt.py @@ -21,8 +21,8 @@ def __init__(self): self.rnnt = decoder.model class Joiner(EmformerRnnt): - def forward(self, predict_inputs): - return self.rnnt.join(*predict_inputs) + def forward(self, a, b, c, d): + return self.rnnt.join(a, b, c, d) def get_example_inputs(self): join_inputs = ( @@ -31,7 +31,7 @@ def get_example_inputs(self): torch.rand([1, 128, 1024]), torch.tensor([128]), ) - return (join_inputs,) + return join_inputs def test_fp32_emformer_joiner(self): joiner = self.Joiner() @@ -43,21 +43,19 @@ def test_fp32_emformer_joiner(self): .check(["torch.ops.higher_order.executorch_call_delegate"]) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) class Predictor(EmformerRnnt): - def forward(self, predict_inputs): - return self.rnnt.predict(*predict_inputs) + def forward(self, a, b): + return self.rnnt.predict(a, b, None) def get_example_inputs(self): predict_inputs = ( torch.zeros([1, 128], dtype=int), torch.tensor([128], dtype=int), - None, ) - return (predict_inputs,) + return predict_inputs @unittest.skip("T183426271") def test_fp32_emformer_predictor(self): @@ -70,20 +68,19 @@ def test_fp32_emformer_predictor(self): .check(["torch.ops.higher_order.executorch_call_delegate"]) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) class Transcriber(EmformerRnnt): - def forward(self, predict_inputs): - return self.rnnt.transcribe(*predict_inputs) + def forward(self, a, b): + return self.rnnt.transcribe(a, b) def get_example_inputs(self): transcribe_inputs = ( torch.randn(1, 128, 80), torch.tensor([128]), ) - return (transcribe_inputs,) + return transcribe_inputs def test_fp32_emformer_transcriber(self): transcriber = self.Transcriber() @@ -95,6 +92,5 @@ def test_fp32_emformer_transcriber(self): .check(["torch.ops.higher_order.executorch_call_delegate"]) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) diff --git a/backends/xnnpack/test/models/inception_v3.py b/backends/xnnpack/test/models/inception_v3.py index 58839014557..b861afc5cd5 100644 --- a/backends/xnnpack/test/models/inception_v3.py +++ b/backends/xnnpack/test/models/inception_v3.py @@ -42,8 +42,7 @@ def test_fp32_ic3(self): .check_not(list(self.all_operators)) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_qs8_ic3(self): @@ -63,6 +62,5 @@ def test_qs8_ic3(self): .check_not(list(ops_after_quantization)) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) diff --git a/backends/xnnpack/test/models/inception_v4.py b/backends/xnnpack/test/models/inception_v4.py index 534fb90ad6c..528512c82f2 100644 --- a/backends/xnnpack/test/models/inception_v4.py +++ b/backends/xnnpack/test/models/inception_v4.py @@ -39,8 +39,7 @@ def test_fp32_ic4(self): .check_not(list(self.all_operators)) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_qs8_ic4(self): @@ -60,6 +59,5 @@ def test_qs8_ic4(self): .check_not(list(ops_after_quantization)) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) diff --git a/backends/xnnpack/test/models/llama2_et_example.py b/backends/xnnpack/test/models/llama2_et_example.py index 46dae356cd8..4716f2d6a95 100644 --- a/backends/xnnpack/test/models/llama2_et_example.py +++ b/backends/xnnpack/test/models/llama2_et_example.py @@ -45,6 +45,5 @@ def _test(self, dtype: torch.dtype = torch.float): .dump_artifact() .to_executorch() .serialize() - .run_method() - .compare_outputs(atol=5e-2) + .run_method_and_compare_outputs(atol=5e-2) ) diff --git a/backends/xnnpack/test/models/mobilebert.py b/backends/xnnpack/test/models/mobilebert.py index bf6b2dfc408..df66ffd4507 100644 --- a/backends/xnnpack/test/models/mobilebert.py +++ b/backends/xnnpack/test/models/mobilebert.py @@ -38,6 +38,5 @@ def test_fp32_mobilebert(self): .check_not(list(self.supported_ops)) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) diff --git a/backends/xnnpack/test/models/mobilenet_v2.py b/backends/xnnpack/test/models/mobilenet_v2.py index dbd9bc744b4..9a0e19b9290 100644 --- a/backends/xnnpack/test/models/mobilenet_v2.py +++ b/backends/xnnpack/test/models/mobilenet_v2.py @@ -40,8 +40,7 @@ def test_fp32_mv2(self): .check_not(list(self.all_operators)) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_qs8_mv2(self): @@ -61,6 +60,5 @@ def test_qs8_mv2(self): .check_not(list(ops_after_quantization)) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) diff --git a/backends/xnnpack/test/models/mobilenet_v3.py b/backends/xnnpack/test/models/mobilenet_v3.py index 20d04b119e1..1287bb6e969 100644 --- a/backends/xnnpack/test/models/mobilenet_v3.py +++ b/backends/xnnpack/test/models/mobilenet_v3.py @@ -42,8 +42,7 @@ def test_fp32_mv3(self): .check_not(list(self.all_operators)) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_qs8_mv3(self): @@ -63,6 +62,5 @@ def test_qs8_mv3(self): .check_not(list(ops_after_lowering)) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) diff --git a/backends/xnnpack/test/models/resnet.py b/backends/xnnpack/test/models/resnet.py index 73e68c855e9..26ffcffbb56 100644 --- a/backends/xnnpack/test/models/resnet.py +++ b/backends/xnnpack/test/models/resnet.py @@ -23,8 +23,7 @@ def test_fp32_resnet18(self): .partition() .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_qs8_resnet18(self): @@ -37,6 +36,5 @@ def test_qs8_resnet18(self): .partition() .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) diff --git a/backends/xnnpack/test/models/torchvision_vit.py b/backends/xnnpack/test/models/torchvision_vit.py index 226cc73f401..de5f263641d 100644 --- a/backends/xnnpack/test/models/torchvision_vit.py +++ b/backends/xnnpack/test/models/torchvision_vit.py @@ -57,6 +57,5 @@ def test_fp32_vit(self): .check_not(list(lowerable_xnn_operators)) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) diff --git a/backends/xnnpack/test/models/very_big_model.py b/backends/xnnpack/test/models/very_big_model.py index 2200b50a6b2..f3f06380414 100644 --- a/backends/xnnpack/test/models/very_big_model.py +++ b/backends/xnnpack/test/models/very_big_model.py @@ -39,6 +39,5 @@ def test_very_big_model(self): .check(["torch.ops.higher_order.executorch_call_delegate"]) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) diff --git a/backends/xnnpack/test/models/w2l.py b/backends/xnnpack/test/models/w2l.py index 10d7ca15b08..c95fc29d8cc 100644 --- a/backends/xnnpack/test/models/w2l.py +++ b/backends/xnnpack/test/models/w2l.py @@ -34,8 +34,7 @@ def test_fp32_w2l(self): .check(["torch.ops.higher_order.executorch_call_delegate"]) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_qs8_w2l(self): @@ -54,6 +53,5 @@ def test_qs8_w2l(self): .check(["torch.ops.higher_order.executorch_call_delegate"]) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) diff --git a/backends/xnnpack/test/ops/abs.py b/backends/xnnpack/test/ops/abs.py index c71fe5ab4e0..2906654dfb7 100644 --- a/backends/xnnpack/test/ops/abs.py +++ b/backends/xnnpack/test/ops/abs.py @@ -31,8 +31,7 @@ def _test_abs(self, inputs): .check_not(["executorch_exir_dialects_edge__ops_aten_abs_default"]) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_fp16_abs(self): diff --git a/backends/xnnpack/test/ops/add.py b/backends/xnnpack/test/ops/add.py index 3a56e0f4c6a..8b0d0c6234d 100644 --- a/backends/xnnpack/test/ops/add.py +++ b/backends/xnnpack/test/ops/add.py @@ -54,8 +54,7 @@ def _test_add(self, inputs): .check_not(["executorch_exir_dialects_edge__ops_aten_add_Tensor"]) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_fp16_add(self): @@ -79,8 +78,7 @@ def test_fp32_add_constant(self): .check_not(["executorch_exir_dialects_edge__ops_aten_add_Tensor"]) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_qs8_add_constant(self): @@ -121,8 +119,7 @@ def test_qs8_add(self): ) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_qs8_add2(self): @@ -145,8 +142,7 @@ def test_qs8_add2(self): ) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_qs8_add3(self): @@ -169,8 +165,7 @@ def test_qs8_add3(self): ) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) class AddRelu(torch.nn.Module): @@ -194,8 +189,7 @@ def test_fp32_add_relu(self): .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_qs8_add_relu(self): @@ -214,8 +208,7 @@ def test_qs8_add_relu(self): .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_qs8_add_relu_seq(self): @@ -261,6 +254,5 @@ def forward(self, x, z): .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) diff --git a/backends/xnnpack/test/ops/avgpool2d.py b/backends/xnnpack/test/ops/avgpool2d.py index 2dd46932988..edb92d09a35 100644 --- a/backends/xnnpack/test/ops/avgpool2d.py +++ b/backends/xnnpack/test/ops/avgpool2d.py @@ -42,8 +42,7 @@ def _test_argpool2d(self, inputs): .check_not(["executorch_exir_dialects_edge__ops_aten_avg_pool2d_default"]) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_fp16_avgpool2d(self): diff --git a/backends/xnnpack/test/ops/bilinear2d.py b/backends/xnnpack/test/ops/bilinear2d.py index 2e80eaf2bc5..ab9d3d3c11d 100644 --- a/backends/xnnpack/test/ops/bilinear2d.py +++ b/backends/xnnpack/test/ops/bilinear2d.py @@ -87,8 +87,7 @@ def test_fp32_static_resize_bilinear2d(self): .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_fp32_static_resize_bilinear2d_with_align_cornesr(self): @@ -103,8 +102,7 @@ def test_fp32_static_resize_bilinear2d_with_align_cornesr(self): .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_fp32_static_resize_bilinear2d_antialiased(self): diff --git a/backends/xnnpack/test/ops/cat.py b/backends/xnnpack/test/ops/cat.py index 8cb9b760b0d..85c5b51a2c7 100644 --- a/backends/xnnpack/test/ops/cat.py +++ b/backends/xnnpack/test/ops/cat.py @@ -11,16 +11,31 @@ class TestCat(unittest.TestCase): - class Cat(torch.nn.Module): - def forward(self, xs): + class Cat2(torch.nn.Module): + def forward(self, arg1, arg2): + xs = [arg1, arg2] x = torch.cat(xs) return x + x # Quantize by propagation. - class Cat2(torch.nn.Module): - def forward(self, xs): - return torch.cat(xs) + class Cat3(torch.nn.Module): + def forward(self, arg1, arg2, arg3): + xs = [arg1, arg2, arg3] + x = torch.cat(xs) + return x + x # Quantize by propagation. + + class Cat4(torch.nn.Module): + def forward(self, arg1, arg2, arg3, arg4): + xs = [arg1, arg2, arg3, arg4] + x = torch.cat(xs) + return x + x # Quantize by propagation. - def _test_cat(self, module, inputs, quant=False, quant_ops=2): + class Cat5(torch.nn.Module): + def forward(self, arg1, arg2, arg3, arg4, arg5): + xs = [arg1, arg2, arg3, arg4, arg5] + x = torch.cat(xs) + return x + x # Quantize by propagation. + + def _test_cat(self, module, inputs, cat_num=1, quant=False, quant_ops=2): tester = Tester(module, inputs) if quant: @@ -36,7 +51,7 @@ def _test_cat(self, module, inputs, quant=False, quant_ops=2): # Q/DQ pair for each input and quantized op. For most tests, there are # two quantized ops - cat and add. torch.ops.quantized_decomposed.quantize_per_tensor.default: ( - len(inputs[0]) + quant_ops + cat_num + quant_ops ) } ) @@ -55,8 +70,7 @@ def _test_cat(self, module, inputs, quant=False, quant_ops=2): .check_not(["executorch_exir_dialects_edge__ops_aten_cat"]) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_fp16_cat2(self): @@ -64,10 +78,8 @@ def test_fp16_cat2(self): Using Clamp2 because fp16 add is done in fp32 ATM. Need to fix that first. """ inputs = ( - ( - torch.ones(1, 2, 3).to(torch.float16), - torch.ones(3, 2, 3).to(torch.float16), - ), + torch.ones(1, 2, 3).to(torch.float16), + torch.ones(3, 2, 3).to(torch.float16), ) self._test_cat(self.Cat2(), inputs) @@ -76,81 +88,71 @@ def test_fp16_cat3(self): Using Clamp2 because fp16 add is done in fp32 ATM. Need to fix that first. """ inputs = ( - ( - torch.ones(1, 2, 3).to(torch.float16), - torch.ones(3, 2, 3).to(torch.float16), - torch.ones(2, 2, 3).to(torch.float16), - ), + torch.ones(1, 2, 3).to(torch.float16), + torch.ones(3, 2, 3).to(torch.float16), + torch.ones(2, 2, 3).to(torch.float16), ) - self._test_cat(self.Cat2(), inputs) + self._test_cat(self.Cat3(), inputs) def test_fp16_cat4(self): """ Using Clamp2 because fp16 add is done in fp32 ATM. Need to fix that first. """ inputs = ( - ( - torch.ones(1, 2, 3).to(torch.float16), - torch.ones(3, 2, 3).to(torch.float16), - torch.ones(2, 2, 3).to(torch.float16), - torch.ones(5, 2, 3).to(torch.float16), - ), + torch.ones(1, 2, 3).to(torch.float16), + torch.ones(3, 2, 3).to(torch.float16), + torch.ones(2, 2, 3).to(torch.float16), + torch.ones(5, 2, 3).to(torch.float16), ) - self._test_cat(self.Cat2(), inputs) + self._test_cat(self.Cat4(), inputs) def test_fp32_cat2(self): - inputs = ((torch.ones(1, 2, 3), torch.ones(3, 2, 3)),) - self._test_cat(self.Cat(), inputs) + inputs = (torch.ones(1, 2, 3), torch.ones(3, 2, 3)) + self._test_cat(self.Cat2(), inputs) def test_fp32_cat3(self): - inputs = ((torch.ones(1, 2, 3), torch.ones(3, 2, 3), torch.ones(2, 2, 3)),) - self._test_cat(self.Cat(), inputs) + inputs = (torch.ones(1, 2, 3), torch.ones(3, 2, 3), torch.ones(2, 2, 3)) + self._test_cat(self.Cat3(), inputs) def test_fp32_cat4(self): inputs = ( - ( - torch.ones(1, 2, 3), - torch.ones(3, 2, 3), - torch.ones(2, 2, 3), - torch.ones(5, 2, 3), - ), + torch.ones(1, 2, 3), + torch.ones(3, 2, 3), + torch.ones(2, 2, 3), + torch.ones(5, 2, 3), ) - self._test_cat(self.Cat(), inputs) + self._test_cat(self.Cat4(), inputs) def test_qs8_cat2(self): - inputs = ((torch.ones(1, 2, 3), torch.ones(3, 2, 3)),) - self._test_cat(self.Cat(), inputs, quant=True) + inputs = (torch.ones(1, 2, 3), torch.ones(3, 2, 3)) + self._test_cat(self.Cat2(), inputs, cat_num=2, quant=True) def test_qs8_cat3(self): - inputs = ((torch.ones(1, 2, 3), torch.ones(3, 2, 3), torch.ones(2, 2, 3)),) - self._test_cat(self.Cat(), inputs, quant=True) + inputs = (torch.ones(1, 2, 3), torch.ones(3, 2, 3), torch.ones(2, 2, 3)) + self._test_cat(self.Cat3(), inputs, cat_num=3, quant=True) def test_qs8_cat4(self): inputs = ( - ( - torch.ones(1, 2, 3), - torch.ones(3, 2, 3), - torch.ones(2, 2, 3), - torch.ones(5, 2, 3), - ), + torch.ones(1, 2, 3), + torch.ones(3, 2, 3), + torch.ones(2, 2, 3), + torch.ones(5, 2, 3), ) - self._test_cat(self.Cat(), inputs, quant=True) + self._test_cat(self.Cat4(), inputs, cat_num=4, quant=True) def test_fp32_cat_unsupported(self): """ XNNPACK only supports concatenating up to 4 values, so it should not delegate here. """ inputs = ( - ( - torch.ones(1, 2, 3), - torch.ones(3, 2, 3), - torch.ones(2, 2, 3), - torch.ones(5, 2, 3), - torch.ones(1, 2, 3), - ), + torch.ones(1, 2, 3), + torch.ones(3, 2, 3), + torch.ones(2, 2, 3), + torch.ones(5, 2, 3), + torch.ones(1, 2, 3), ) ( - Tester(self.Cat(), inputs) + Tester(self.Cat5(), inputs) .export() .check_count({"torch.ops.aten.cat": 1}) .to_edge() diff --git a/backends/xnnpack/test/ops/ceil.py b/backends/xnnpack/test/ops/ceil.py index 853de03ff1d..8d59f3b35d7 100644 --- a/backends/xnnpack/test/ops/ceil.py +++ b/backends/xnnpack/test/ops/ceil.py @@ -31,8 +31,7 @@ def _test_ceil(self, inputs): .check_not(["executorch_exir_dialects_edge__ops_aten_ceil_default"]) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_fp16_ceil(self): diff --git a/backends/xnnpack/test/ops/clamp.py b/backends/xnnpack/test/ops/clamp.py index 6ffaed3fe1b..c52fd011f8b 100644 --- a/backends/xnnpack/test/ops/clamp.py +++ b/backends/xnnpack/test/ops/clamp.py @@ -33,8 +33,7 @@ def _test_clamp(self, module, inputs): .check_not(["executorch_exir_dialects_edge__ops_aten_clamp_default"]) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_fp16_clamp(self): @@ -77,6 +76,5 @@ def test_qs8_clamp(self): ) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) diff --git a/backends/xnnpack/test/ops/conv1d.py b/backends/xnnpack/test/ops/conv1d.py index 604e37c724c..50f9aa3a996 100644 --- a/backends/xnnpack/test/ops/conv1d.py +++ b/backends/xnnpack/test/ops/conv1d.py @@ -97,8 +97,7 @@ def _test_conv1d(self, module, inputs, conv_count, quantized=False): .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_fp16_conv1d(self): diff --git a/backends/xnnpack/test/ops/conv2d.py b/backends/xnnpack/test/ops/conv2d.py index 3eb80072a68..9a2bb25dc8d 100644 --- a/backends/xnnpack/test/ops/conv2d.py +++ b/backends/xnnpack/test/ops/conv2d.py @@ -152,8 +152,7 @@ def _test( .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) .to_executorch() .serialize() - .run_method() - .compare_outputs(qtol=1) + .run_method_and_compare_outputs(qtol=1) ) def test_fp16_conv2d(self) -> None: diff --git a/backends/xnnpack/test/ops/div.py b/backends/xnnpack/test/ops/div.py index 007122db981..2882c59b875 100644 --- a/backends/xnnpack/test/ops/div.py +++ b/backends/xnnpack/test/ops/div.py @@ -39,8 +39,7 @@ def _test_div(self, inputs): .check_not(["executorch_exir_dialects_edge__ops_aten_div_Tensor"]) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_fp16_div(self): @@ -64,6 +63,5 @@ def test_fp32_div_single_input(self): .check_not(["executorch_exir_dialects_edge__ops_aten_div_Tensor"]) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) diff --git a/backends/xnnpack/test/ops/elu.py b/backends/xnnpack/test/ops/elu.py index f1f8d7628a6..89fef6f9d4b 100644 --- a/backends/xnnpack/test/ops/elu.py +++ b/backends/xnnpack/test/ops/elu.py @@ -39,8 +39,7 @@ def _test_elu(self, inputs): ) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) @unittest.skip("T171810227 - Missing recomposition for ELU") @@ -74,8 +73,7 @@ def test_qs8_elu(self): ) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) @unittest.skip("T171810227 - Missing recomposition for ELU") @@ -99,6 +97,5 @@ def test_qs8_elu_functional(self): ) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) diff --git a/backends/xnnpack/test/ops/floor.py b/backends/xnnpack/test/ops/floor.py index 31c3da09b42..cb65ca2aa58 100644 --- a/backends/xnnpack/test/ops/floor.py +++ b/backends/xnnpack/test/ops/floor.py @@ -31,8 +31,7 @@ def _test_floor(self, inputs): .check_not(["executorch_exir_dialects_edge__ops_aten_floor_default"]) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_fp16_floor(self): diff --git a/backends/xnnpack/test/ops/hardswish.py b/backends/xnnpack/test/ops/hardswish.py index d35e7ab5d78..8f6a190412c 100644 --- a/backends/xnnpack/test/ops/hardswish.py +++ b/backends/xnnpack/test/ops/hardswish.py @@ -41,8 +41,7 @@ def _test_hardswish(self, inputs): ) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) @unittest.skip("T158969708 - Missing recomposition pass for hardswish") @@ -75,6 +74,5 @@ def test_fp32_hardswish_functional(self): ) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) diff --git a/backends/xnnpack/test/ops/hardtanh.py b/backends/xnnpack/test/ops/hardtanh.py index fdcfb7c7efe..d13624663ca 100644 --- a/backends/xnnpack/test/ops/hardtanh.py +++ b/backends/xnnpack/test/ops/hardtanh.py @@ -38,8 +38,7 @@ def test_fp32_hardtanh(self): .check_not(["executorch_exir_dialects_edge__ops_aten_hardtanh_default"]) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_fp32_hardtanh_bound(self): @@ -58,8 +57,7 @@ def test_fp32_hardtanh_bound(self): .check_not(["executorch_exir_dialects_edge__ops_aten_hardtanh_default"]) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_qs8_hardtanh(self): @@ -90,6 +88,5 @@ def test_qs8_hardtanh(self): ) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) diff --git a/backends/xnnpack/test/ops/leaky_relu.py b/backends/xnnpack/test/ops/leaky_relu.py index 477188ed752..ae5f2e3197e 100644 --- a/backends/xnnpack/test/ops/leaky_relu.py +++ b/backends/xnnpack/test/ops/leaky_relu.py @@ -43,8 +43,7 @@ def _test_leaky_relu(self, module, inputs): ) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_fp16_leaky_relu(self): @@ -76,8 +75,7 @@ def test_fp32_leaky_relu_functional(self): ) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) @unittest.skip("T172863987 - Missing quantizer support.") @@ -107,8 +105,7 @@ def test_qs8_leaky_relu(self): ) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) @unittest.skip("T172863987 - Missing quantizer support.") @@ -143,6 +140,5 @@ def test_qs8_leaky_relu_default_slope(self): ) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) diff --git a/backends/xnnpack/test/ops/linear.py b/backends/xnnpack/test/ops/linear.py index b4a9cb62856..3b3669feba0 100644 --- a/backends/xnnpack/test/ops/linear.py +++ b/backends/xnnpack/test/ops/linear.py @@ -595,8 +595,7 @@ def _test_manual_dq_linear( ) .to_executorch() .serialize() - .run_method() - .compare_outputs(atol=atol, rtol=rtol) + .run_method_and_compare_outputs(atol=atol, rtol=rtol) ) def _run_manual_dqlinear_tests(self, weight_n_bit: int, op_dtype: torch.dtype): @@ -731,9 +730,7 @@ def _test_linear( tester.to_executorch() tester.serialize() - tester.run_method() - tester.compare_outputs(qtol=quant, atol=atol) - print("success") + tester.run_method_and_compare_outputs(qtol=quant, atol=atol) def _test_dqlinear( self, @@ -779,5 +776,4 @@ def _test_dqlinear( tester.to_executorch() tester.serialize() - tester.run_method() - tester.compare_outputs(atol=5e-02) + tester.run_method_and_compare_outputs(atol=5e-02) diff --git a/backends/xnnpack/test/ops/max_dim.py b/backends/xnnpack/test/ops/max_dim.py index b43d1ce4e82..9cab1236e4c 100644 --- a/backends/xnnpack/test/ops/max_dim.py +++ b/backends/xnnpack/test/ops/max_dim.py @@ -37,8 +37,7 @@ def _test_max_dim(self, inputs): .check_not(["executorch_exir_dialects_edge__ops_aten_max_dim"]) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) @unittest.skip("T171468483 - Fails to partition due to index output dtype.") @@ -65,6 +64,5 @@ def test_fp32_max_dim_no_indices(self): .check_not(["executorch_exir_dialects_edge__ops_aten_max_dim"]) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) diff --git a/backends/xnnpack/test/ops/maximum.py b/backends/xnnpack/test/ops/maximum.py index 5ce05d33e37..feff02744d3 100644 --- a/backends/xnnpack/test/ops/maximum.py +++ b/backends/xnnpack/test/ops/maximum.py @@ -30,8 +30,7 @@ def _test_maximum(self, inputs): .check_not(["executorch_exir_dialects_edge__ops_aten_maximum_default"]) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_fp16_maximum(self): @@ -64,6 +63,5 @@ def test_fp32_maximum_broadcast(self): .check_not(["executorch_exir_dialects_edge__ops_aten_maximum_default"]) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) diff --git a/backends/xnnpack/test/ops/maxpool2d.py b/backends/xnnpack/test/ops/maxpool2d.py index 84c76a6e6c9..7e510dd9155 100644 --- a/backends/xnnpack/test/ops/maxpool2d.py +++ b/backends/xnnpack/test/ops/maxpool2d.py @@ -64,8 +64,7 @@ def _test_maxpool2d(self, inputs): ) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_fp16_maxpool2d(self): @@ -135,6 +134,5 @@ def forward(self, x): ) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) diff --git a/backends/xnnpack/test/ops/mean_dim.py b/backends/xnnpack/test/ops/mean_dim.py index b8d7e77a224..750b0e8f508 100644 --- a/backends/xnnpack/test/ops/mean_dim.py +++ b/backends/xnnpack/test/ops/mean_dim.py @@ -33,8 +33,7 @@ def _test_mean_dim(self, inputs): .check_not(["executorch_exir_dialects_edge__ops_aten_mean_dim"]) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_fp16_mean_dim(self): @@ -85,6 +84,5 @@ def test_qs8_mean_dim(self): ) .to_executorch() .serialize() - .run_method() - .compare_outputs(qtol=1) + .run_method_and_compare_outputs(qtol=1) ) diff --git a/backends/xnnpack/test/ops/minimum.py b/backends/xnnpack/test/ops/minimum.py index 5d6f08fd1a2..121fbeb1852 100644 --- a/backends/xnnpack/test/ops/minimum.py +++ b/backends/xnnpack/test/ops/minimum.py @@ -30,8 +30,7 @@ def _test_minimum(self, inputs): .check_not(["executorch_exir_dialects_edge__ops_aten_minimum_default"]) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_fp16_minimum(self): diff --git a/backends/xnnpack/test/ops/multiply.py b/backends/xnnpack/test/ops/multiply.py index 09f9b39ea60..d151f58bd6a 100644 --- a/backends/xnnpack/test/ops/multiply.py +++ b/backends/xnnpack/test/ops/multiply.py @@ -43,8 +43,7 @@ def _test_mul(self, inputs): .check_not(["executorch_exir_dialects_edge__ops_aten_mul_Tensor"]) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_fp16_mul(self): @@ -78,8 +77,7 @@ def test_qs8_mul(self): ) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_qs8_mul2(self): @@ -102,8 +100,7 @@ def test_qs8_mul2(self): ) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_qs8_mul_functional(self): @@ -126,8 +123,7 @@ def test_qs8_mul_functional(self): ) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_qs8_mul_relu(self): @@ -156,6 +152,5 @@ def test_qs8_mul_relu(self): ) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) diff --git a/backends/xnnpack/test/ops/negate.py b/backends/xnnpack/test/ops/negate.py index b7777136f5a..c4a47bb93ce 100644 --- a/backends/xnnpack/test/ops/negate.py +++ b/backends/xnnpack/test/ops/negate.py @@ -31,8 +31,7 @@ def _test_negate(self, inputs): .check_not(["executorch_exir_dialects_edge__ops_aten_neg_default"]) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_fp16_negate(self): diff --git a/backends/xnnpack/test/ops/permute.py b/backends/xnnpack/test/ops/permute.py index 3441acb6315..2c995376753 100644 --- a/backends/xnnpack/test/ops/permute.py +++ b/backends/xnnpack/test/ops/permute.py @@ -45,8 +45,7 @@ def _test_permute(self, inputs): .check_not(["executorch_exir_dialects_edge__ops_aten_permute_copy_default"]) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_fp16_permute(self): @@ -72,8 +71,7 @@ def test_fp32_permute_copy(self): .check_not(["executorch_exir_dialects_edge__ops_aten_permute_copy_default"]) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_qs8_permute(self): @@ -102,8 +100,7 @@ def test_qs8_permute(self): ) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_qs8_permute_copy(self): @@ -132,6 +129,5 @@ def test_qs8_permute_copy(self): ) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) diff --git a/backends/xnnpack/test/ops/pow.py b/backends/xnnpack/test/ops/pow.py index b4bd6b5862c..d99f2c546e6 100644 --- a/backends/xnnpack/test/ops/pow.py +++ b/backends/xnnpack/test/ops/pow.py @@ -34,8 +34,7 @@ def _test_pow2(self, inputs): .check_not(["executorch_exir_dialects_edge__ops_aten_pow_Tensor_Scalar"]) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_fp16_pow2(self): diff --git a/backends/xnnpack/test/ops/prelu.py b/backends/xnnpack/test/ops/prelu.py index a4e9ef7df95..985ddecf363 100644 --- a/backends/xnnpack/test/ops/prelu.py +++ b/backends/xnnpack/test/ops/prelu.py @@ -36,8 +36,7 @@ def _test_prelu(self, module, inputs): ) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) @unittest.skip("T158653285 - Missing recomposition for PReLU") diff --git a/backends/xnnpack/test/ops/quantize_per_tensor.py b/backends/xnnpack/test/ops/quantize_per_tensor.py index 82aaca0b6f7..f912428a8ab 100644 --- a/backends/xnnpack/test/ops/quantize_per_tensor.py +++ b/backends/xnnpack/test/ops/quantize_per_tensor.py @@ -39,8 +39,7 @@ def forward(self, x): ) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_qs8_dequantize_per_tenstor(self): @@ -76,6 +75,5 @@ def forward(self, x): ) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) diff --git a/backends/xnnpack/test/ops/relu.py b/backends/xnnpack/test/ops/relu.py index c52055e45f1..3ab1c72b57d 100644 --- a/backends/xnnpack/test/ops/relu.py +++ b/backends/xnnpack/test/ops/relu.py @@ -33,6 +33,5 @@ def test_fp32_relu(self): .check_not(["executorch_exir_dialects_edge__ops_aten_relu_default"]) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) diff --git a/backends/xnnpack/test/ops/sdpa.py b/backends/xnnpack/test/ops/sdpa.py index 5cf8534c928..d68bcab2086 100644 --- a/backends/xnnpack/test/ops/sdpa.py +++ b/backends/xnnpack/test/ops/sdpa.py @@ -70,8 +70,7 @@ def _test(self, module, inputs, atol=1e-03, rtol=1e-03): ) .to_executorch() .serialize() - .run_method() - .compare_outputs(atol=atol, rtol=rtol) + .run_method_and_compare_outputs(atol=atol, rtol=rtol) ) def test_fp16_sdpa_mask2d(self): diff --git a/backends/xnnpack/test/ops/sigmoid.py b/backends/xnnpack/test/ops/sigmoid.py index be8eda605ee..5ed6fc64402 100644 --- a/backends/xnnpack/test/ops/sigmoid.py +++ b/backends/xnnpack/test/ops/sigmoid.py @@ -32,8 +32,7 @@ def _test_sigmoid(self, inputs): .check_not(["executorch_exir_dialects_edge__ops_aten_sigmoid_default"]) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_fp16_sigmoid(self): diff --git a/backends/xnnpack/test/ops/slice_copy.py b/backends/xnnpack/test/ops/slice_copy.py index 99b5842313f..2d0f150dd15 100644 --- a/backends/xnnpack/test/ops/slice_copy.py +++ b/backends/xnnpack/test/ops/slice_copy.py @@ -27,8 +27,7 @@ def _test_slice_copy(self, module, inputs, copy_count=1, edge_copy_count=1): .check_not(["executorch_exir_dialects_edge__ops_aten_slice_copy_Tensor"]) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_fp16_slice_copy(self): @@ -143,6 +142,5 @@ def forward(self, x): .check_not(["executorch_exir_dialects_edge__ops_aten_slice_copy_Tensor"]) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) diff --git a/backends/xnnpack/test/ops/softmax.py b/backends/xnnpack/test/ops/softmax.py index 43ff89f1206..d3f674d7ae2 100644 --- a/backends/xnnpack/test/ops/softmax.py +++ b/backends/xnnpack/test/ops/softmax.py @@ -38,8 +38,7 @@ def _test_softmax(self, inputs): .check_not(["executorch_exir_dialects_edge__ops_aten__softmax_default"]) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_fp16_softmax(self): diff --git a/backends/xnnpack/test/ops/sqrt.py b/backends/xnnpack/test/ops/sqrt.py index 99ab8f72340..e2a5f4ac2f6 100644 --- a/backends/xnnpack/test/ops/sqrt.py +++ b/backends/xnnpack/test/ops/sqrt.py @@ -16,6 +16,7 @@ def __init__(self): super().__init__() def forward(self, x): + x = torch.abs(x) z = torch.sqrt(x) return z @@ -31,14 +32,13 @@ def _test_sqrt(self, inputs): .check_not(["executorch_exir_dialects_edge__ops_aten_sqrt_default"]) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_fp16_sqrt(self): - inputs = (torch.randn(20).to(torch.float16).abs(),) + inputs = (torch.randn(20).to(torch.float16),) self._test_sqrt(inputs) def test_fp32_sqrt(self): - inputs = (torch.randn(20).abs(),) + inputs = (torch.randn(20),) self._test_sqrt(inputs) diff --git a/backends/xnnpack/test/ops/square.py b/backends/xnnpack/test/ops/square.py index faad836becf..02dc12e16e4 100644 --- a/backends/xnnpack/test/ops/square.py +++ b/backends/xnnpack/test/ops/square.py @@ -37,8 +37,7 @@ def _test_square(self, inputs): .check_not(["executorch_exir_dialects_edge__ops_aten_pow_Tensor_Scalar"]) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_fp16_square(self): diff --git a/backends/xnnpack/test/ops/static_constant_pad.py b/backends/xnnpack/test/ops/static_constant_pad.py index 6b8563e291d..c836b404ac7 100644 --- a/backends/xnnpack/test/ops/static_constant_pad.py +++ b/backends/xnnpack/test/ops/static_constant_pad.py @@ -99,8 +99,7 @@ def _test_static_constant_pad_functional(self, inputs): ) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_fp16_static_constant_pad_functional(self): @@ -154,8 +153,7 @@ def forward(self, x): ) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_qs8_static_constant_pad_2d(self): @@ -180,6 +178,5 @@ def test_qs8_static_constant_pad_2d(self): ) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) diff --git a/backends/xnnpack/test/ops/sub.py b/backends/xnnpack/test/ops/sub.py index bcb4f389bd6..d3cc6e8aa80 100644 --- a/backends/xnnpack/test/ops/sub.py +++ b/backends/xnnpack/test/ops/sub.py @@ -39,8 +39,7 @@ def _test_sub(self, inputs): .check_not(["executorch_exir_dialects_edge__ops_aten_sub_Tensor"]) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_fp16_sub(self): @@ -75,8 +74,7 @@ def test_qs8_sub(self): ) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) @unittest.skip("T171957656 - Quantized sub not implemented.") @@ -100,8 +98,7 @@ def test_qs8_sub2(self): ) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) @unittest.skip("T171957656 - Quantized sub not implemented.") @@ -125,8 +122,7 @@ def test_qs8_sub3(self): ) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) @unittest.skip("T171957656 - Quantized sub not implemented.") @@ -166,6 +162,5 @@ def forward(self, x, y): ) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) diff --git a/backends/xnnpack/test/passes/test_batch_norm_fusion.py b/backends/xnnpack/test/passes/test_batch_norm_fusion.py index ab9b02af4bf..06517c526c8 100644 --- a/backends/xnnpack/test/passes/test_batch_norm_fusion.py +++ b/backends/xnnpack/test/passes/test_batch_norm_fusion.py @@ -40,8 +40,7 @@ def test_fp32_batch_norm_fusion(self): .to_edge() .run_passes(self.PassStage) .check_count({self.bn_name: 1}) - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_q8_batch_norm_fusion(self): @@ -52,8 +51,7 @@ def test_q8_batch_norm_fusion(self): .to_edge() .run_passes(self.PassStage) .check_count({self.bn_name: 1}) - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_fp32_batch_norm_no_fusion_doesnt_partition(self): diff --git a/backends/xnnpack/test/passes/test_channels_last_tagged_reshape.py b/backends/xnnpack/test/passes/test_channels_last_tagged_reshape.py index abb18a8c0b2..36e566abc36 100644 --- a/backends/xnnpack/test/passes/test_channels_last_tagged_reshape.py +++ b/backends/xnnpack/test/passes/test_channels_last_tagged_reshape.py @@ -42,8 +42,7 @@ def test_fp32_channels_last_tagged_reshape_pass(self): self.to_copy_name: num_reshape, } ) - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_qs8_channels_last_tagged_reshape_pass(self): @@ -64,8 +63,7 @@ def test_qs8_channels_last_tagged_reshape_pass(self): ] * num_reshape ) - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) class ConvRelu(torch.nn.Module): @@ -86,8 +84,7 @@ def test_fp32_channels_last_tagged_reshape_pass_conv_relu(self): .check( [self.to_copy_name, self.conv_name, self.relu_name, self.to_copy_name] ) - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_qs8_channels_last_tagged_reshape_pass_conv_relu(self): @@ -109,8 +106,7 @@ def test_qs8_channels_last_tagged_reshape_pass_conv_relu(self): self.to_copy_name, ] ) - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) class Conv2dBnHardtanhMeanSequenceModule(torch.nn.Module): @@ -175,6 +171,5 @@ def test_fp32_channels_last_tagged_reshape_pass_conv_bn_hardtanh_mean_seq(self): self.to_copy_name: 4, } ) - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) diff --git a/backends/xnnpack/test/passes/test_convert_to_linear.py b/backends/xnnpack/test/passes/test_convert_to_linear.py index 783336a01cd..0fa80246fd6 100644 --- a/backends/xnnpack/test/passes/test_convert_to_linear.py +++ b/backends/xnnpack/test/passes/test_convert_to_linear.py @@ -35,6 +35,5 @@ def test_fp32_convert_to_linear(self): .check_count( {"executorch_exir_dialects_edge__ops_aten_linear_default": 1} ) - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) diff --git a/backends/xnnpack/test/passes/test_remove_get_item_pass.py b/backends/xnnpack/test/passes/test_remove_get_item_pass.py index 35bd4d8b966..fa68c403e38 100644 --- a/backends/xnnpack/test/passes/test_remove_get_item_pass.py +++ b/backends/xnnpack/test/passes/test_remove_get_item_pass.py @@ -42,8 +42,7 @@ def test_fp32_max_pool2d_remove_getitem(self): .to_edge() .run_passes(self.PassStage) .check_count({self.max_pool2d_name: 1}) - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_q8_max_pool2d_remove_getitem(self): @@ -54,8 +53,7 @@ def test_q8_max_pool2d_remove_getitem(self): .to_edge() .run_passes(self.PassStage) .check_count({self.max_pool2d_name: 1}) - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) class MaxModule(torch.nn.Module): @@ -79,8 +77,7 @@ def test_fp32_max_remove_getitem(self): self.amax_name: 1, } ) - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_q8_max_remove_getitem(self): @@ -95,6 +92,5 @@ def test_q8_max_remove_getitem(self): self.amax_name: 1, } ) - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) diff --git a/backends/xnnpack/test/passes/test_tag_implicit_q_dq_pass.py b/backends/xnnpack/test/passes/test_tag_implicit_q_dq_pass.py index 97c31c3d43a..dc67a6582df 100644 --- a/backends/xnnpack/test/passes/test_tag_implicit_q_dq_pass.py +++ b/backends/xnnpack/test/passes/test_tag_implicit_q_dq_pass.py @@ -55,8 +55,7 @@ def test_tag_implicit_q_dq_test(self): .export() .to_edge() .run_passes(self.PassStage) - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() .get_artifact(Tester.stage_name(self.PassStage)) ) diff --git a/backends/xnnpack/test/tester/tester.py b/backends/xnnpack/test/tester/tester.py index ec03fa2529d..e0115a29eef 100644 --- a/backends/xnnpack/test/tester/tester.py +++ b/backends/xnnpack/test/tester/tester.py @@ -7,6 +7,7 @@ import copy import logging +import random import sys from abc import ABC, abstractmethod from collections import Counter, OrderedDict @@ -26,7 +27,7 @@ ) from executorch.exir.backend.backend_api import validation_disabled from executorch.exir.backend.partitioner import Partitioner -from executorch.exir.passes.spec_prop_pass import SpecPropPass +from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass from executorch.exir.print_program import pretty_print, print_program logger = logging.getLogger(__name__) @@ -177,11 +178,18 @@ def graph_module(self) -> str: @register_stage class Export(Stage): - def __init__(self): + def __init__(self, dynamic_shapes: Optional[Tuple[Any]] = None): self.exported_program = None + self.dynamic_shapes = dynamic_shapes - def run(self, artifact: torch.nn.Module, inputs) -> None: - self.exported_program = export(artifact, inputs) + def run( + self, + artifact: torch.nn.Module, + inputs: Tuple[torch.Tensor], + ) -> None: + self.exported_program = export( + artifact, inputs, dynamic_shapes=self.dynamic_shapes + ) @property def artifact(self) -> ExportedProgram: @@ -261,8 +269,8 @@ def __init__( config: Optional[ExecutorchBackendConfig] = None, ): self.config = config or ExecutorchBackendConfig( - passes=[SpecPropPass()], extract_delegate_segments=True, + sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(), ) self.executorch_program = None @@ -334,11 +342,13 @@ def __init__( self, module: torch.nn.Module, inputs: Tuple[torch.Tensor], + dynamic_shapes: Optional[Tuple[Any]] = None, ): module.eval() self.original_module = module self.inputs = inputs + self.dynamic_shapes = dynamic_shapes self.stages: Dict[str, Stage] = OrderedDict.fromkeys(list(_stages_.keys())) self.pipeline = { self.stage_name(Quantize): [self.stage_name(Export)], @@ -371,6 +381,59 @@ def __init__( # Artifact output from stage self.stage_output = None + def generate_random_inputs(self): + # Get shapes of inputs + input_shapes = [] + if self.dynamic_shapes is None: + for tensor_arg in self.inputs: + assert isinstance(tensor_arg, torch.Tensor) + input_shapes.append(tensor_arg.shape) + else: + # Random shapes depending on dynamic shape constraint + dim_name_to_size = {} + for arg_idx in range(len(self.inputs)): + assert isinstance(self.inputs[arg_idx], torch.Tensor) + ex_shape = list(self.inputs[arg_idx].shape) + dynamic_dim_spec = self.dynamic_shapes[arg_idx] + for dim_idx, dim_spec in dynamic_dim_spec.items(): + assert dim_idx < len(ex_shape) + if isinstance(dim_spec, torch.export.dynamic_shapes._DerivedDim): + # derived dims are of the form {0: 2 * torch.export.Dim() // 2} + # The root contains the min/max of the export dim and fn contains + # the function to compute the derived dim. + dim_spec = dim_spec.root + fn = dim_spec.fn + elif isinstance(dim_spec, torch.export.dynamic_shapes._Dim): + # Not derived dim so fn is just itself + def fn(x): + return x + + else: + raise RuntimeError( + f"Expected Dynamic Dims to be of type _DerivedDim or _Dim but got {type(dim_spec)}" + ) + dim_name = dim_spec.__name__ + if dim_name not in dim_name_to_size: + upper_bound = min( + dim_spec.max, 1000 + ) # unbounded int max is too large + lower_bound = ( + dim_spec.min if dim_spec.min != 2 else 1 + ) # 0/1 specialization means dim_spec.min can never be 1 + dim_name_to_size[dim_name] = fn( + random.randint(lower_bound, upper_bound) + ) + ex_shape[dim_idx] = dim_name_to_size[dim_spec.__name__] + input_shapes.append(torch.Size(ex_shape)) + # create random tensor inputs with the shapes given above: + random_inputs = [] + for arg_idx in range(len(self.inputs)): + random_inputs.append( + torch.randn(input_shapes[arg_idx]).to(dtype=self.inputs[arg_idx].dtype) + ) + + yield tuple(random_inputs) + @staticmethod def stage_name(stage) -> str: t = stage if isinstance(stage, type) else type(stage) @@ -406,7 +469,9 @@ def quantize(self, quantize_stage: Optional[Quantize] = None): return self._run_stage(quantize_stage or Quantize(), self.inputs) def export(self, export_stage: Optional[Export] = None): - return self._run_stage(export_stage or Export(), self.inputs) + return self._run_stage( + export_stage or Export(dynamic_shapes=self.dynamic_shapes), self.inputs + ) def to_edge(self, to_edge_stage: Optional[ToEdge] = None): return self._run_stage(to_edge_stage or ToEdge()) @@ -469,21 +534,39 @@ def check_node_count(self, input: Dict[Any, int]): return self - def run_method( - self, stage: Optional[str] = None, inputs: Optional[Tuple[torch.Tensor]] = None + def run_method_and_compare_outputs( + self, + stage: Optional[str] = None, + inputs: Optional[Tuple[torch.Tensor]] = None, + num_runs=1, + atol=1e-03, + rtol=1e-03, + qtol=0, ): - inputs_to_run = inputs or self.inputs - export_stage = self.stages[self.stage_name(Export)] - - # Reference output (and quantization scale) - ( - self.reference_output, - self.quantization_scale, - ) = self._calculate_reference_output(export_stage.artifact, inputs_to_run) + number_of_runs = 1 if inputs is not None else num_runs + reference_stage = self.stages[self.stage_name(Export)] - # Output from running artifact at stage stage = stage or self.cur - self.stage_output = self.stages[stage].run_artifact(inputs_to_run) + + print(f"Comparing Stage {stage} with Stage {reference_stage}") + for run_iteration in range(number_of_runs): + inputs_to_run = inputs if inputs else next(self.generate_random_inputs()) + input_shapes = [generated_input.shape for generated_input in inputs_to_run] + print(f"Run {run_iteration} with input shapes: {input_shapes}") + + # Reference output (and quantization scale) + ( + reference_output, + quantization_scale, + ) = self._calculate_reference_output( + reference_stage.artifact, inputs_to_run + ) + + # Output from running artifact at stage + stage_output = self.stages[stage].run_artifact(inputs_to_run) + self._compare_outputs( + reference_output, stage_output, quantization_scale, atol, rtol, qtol + ) return self @@ -521,33 +604,37 @@ def _assert_outputs_equal(model_output, ref_output, atol=1e-03, rtol=1e-03): f"\t Min: {model.min()}, {ref.min()}\n" ) - def compare_outputs(self, atol=1e-03, rtol=1e-03, qtol=0): + @staticmethod + def _compare_outputs( + reference_output, + stage_output, + quantization_scale=None, + atol=1e-03, + rtol=1e-03, + qtol=0, + ): """ Compares the original of the original nn module with the output of the generated artifact. This requres calling run_method before calling compare_outputs. As that runs the generated artifact on the sample inputs and sets the stage output to be compared against the reference. """ - assert self.reference_output is not None - assert self.stage_output is not None - # Wrap both outputs as tuple, since executor output is always a tuple even if single tensor - if isinstance(self.reference_output, torch.Tensor): - self.reference_output = (self.reference_output,) - if isinstance(self.stage_output, torch.Tensor): - self.stage_output = (self.stage_output,) + if isinstance(reference_output, torch.Tensor): + reference_output = (reference_output,) + if isinstance(stage_output, torch.Tensor): + stage_output = (stage_output,) # If a qtol is provided and we found an dequantization node prior to the output, relax the # atol by qtol quant units. - if self.quantization_scale is not None: - atol += self.quantization_scale * qtol + if quantization_scale is not None: + atol += quantization_scale * qtol - self._assert_outputs_equal( - self.stage_output, - self.reference_output, + Tester._assert_outputs_equal( + stage_output, + reference_output, atol=atol, rtol=rtol, ) - return self @staticmethod def _calculate_reference_output(