forked from pytorch/executorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_elu.py
More file actions
98 lines (86 loc) · 3.08 KB
/
test_elu.py
File metadata and controls
98 lines (86 loc) · 3.08 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import unittest
import torch
from executorch.backends.xnnpack.test.tester import Tester
class TestElu(unittest.TestCase):
    """Tests that aten.elu.default is fully lowered to the XNNPACK delegate.

    Each test exports a small ELU module, lowers it, verifies no edge-dialect
    elu op survives outside the delegate, then runs the serialized program and
    compares outputs against eager mode.

    NOTE(review): the test methods use a leading underscore (``_test_...``) in
    addition to ``@unittest.skip``, so they are doubly disabled — unittest
    discovery only picks up ``test_``-prefixed names. Names are kept as-is to
    preserve the public interface; drop the underscore when re-enabling.
    """

    def setUp(self):
        # Reset dynamo caches so each test exports from a clean state.
        torch._dynamo.reset()

    class ELU(torch.nn.Module):
        """Module-style ELU with a non-default alpha (0.5)."""

        def __init__(self):
            super().__init__()
            self.elu = torch.nn.ELU(alpha=0.5)

        def forward(self, x):
            return self.elu(x)

    class ELUFunctional(torch.nn.Module):
        """Functional-style ELU with a non-default alpha (1.2)."""

        def forward(self, x):
            return torch.nn.functional.elu(x, alpha=1.2)

    def _test_elu(self, inputs):
        """Export, lower, serialize, and run self.ELU on ``inputs``.

        Asserts exactly one aten elu op before lowering, exactly one delegate
        call after lowering, and that no edge-dialect elu op escapes delegation.
        """
        (
            Tester(self.ELU(), inputs)
            .export()
            .check_count({"torch.ops.aten.elu.default": 1})
            .to_edge_transform_and_lower()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .check_not(
                [
                    "executorch_exir_dialects_edge__ops_aten_elu_default",
                ]
            )
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )

    @unittest.skip("PyTorch Pin Update Required")
    def _test_fp16_elu(self):
        inputs = (torch.randn(1, 3, 3).to(torch.float16),)
        self._test_elu(inputs)

    @unittest.skip("PyTorch Pin Update Required")
    def _test_fp32_elu(self):
        inputs = (torch.randn(1, 3, 3),)
        self._test_elu(inputs)

    @unittest.skip("Update Quantizer to quantize Elu")
    def _test_qs8_elu(self):
        """Quantized (QS8) lowering of the module-style ELU."""
        inputs = (torch.randn(1, 3, 4, 4),)
        (
            Tester(self.ELU(), inputs)
            .quantize()
            .export()
            .check_count({"torch.ops.aten.elu.default": 1})
            .check(["torch.ops.quantized_decomposed"])
            .to_edge_transform_and_lower()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .check_not(
                [
                    "executorch_exir_dialects_edge__ops_aten_elu_default",
                    "torch.ops.quantized_decomposed",
                ]
            )
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )

    @unittest.skip("Update Quantizer to quantize Elu")
    def _test_qs8_elu_functional(self):
        """Quantized (QS8) lowering of the functional-style ELU.

        Fix: previously instantiated ``self.ELU()`` (copy-paste from
        ``_test_qs8_elu``), which duplicated that test and left
        ``ELUFunctional`` uncovered. Now exercises ``self.ELUFunctional()``
        as the method name intends.
        """
        inputs = (torch.randn(1, 3, 4, 4),)
        (
            Tester(self.ELUFunctional(), inputs)
            .quantize()
            .export()
            .check_count({"torch.ops.aten.elu.default": 1})
            .check(["torch.ops.quantized_decomposed"])
            .to_edge_transform_and_lower()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .check_not(
                [
                    "executorch_exir_dialects_edge__ops_aten_elu_default",
                    "torch.ops.quantized_decomposed",
                ]
            )
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )