1111from executorch import exir
1212from executorch .backends .example .example_partitioner import ExamplePartitioner
1313from executorch .backends .example .example_quantizer import ExampleQuantizer
14- from executorch .exir . backend . backend_api import to_backend
14+ from executorch .exir import to_edge
1515
1616from executorch .exir .backend .canonical_partitioners .duplicate_dequant_node_pass import (
1717 DuplicateDequantNodePass ,
1818)
1919from executorch .exir .delegate import executorch_call_delegate
2020
2121from torch .ao .quantization .quantize_pt2e import convert_pt2e , prepare_pt2e
22+ from torch .export import export
2223
23- # @manual=//pytorch/vision:torchvision
2424from torchvision .models .quantization import mobilenet_v2
2525
2626
@@ -40,7 +40,6 @@ def get_example_inputs():
4040
4141 model = Conv2dModule ()
4242 example_inputs = Conv2dModule .get_example_inputs ()
43- CAPTURE_CONFIG = exir .CaptureConfig (enable_aot = True )
4443 EDGE_COMPILE_CONFIG = exir .EdgeCompileConfig (
4544 _check_ir_validity = False ,
4645 )
@@ -59,24 +58,23 @@ def get_example_inputs():
5958 m = convert_pt2e (m )
6059
6160 quantized_gm = m
62- exported_program = exir .capture (
63- quantized_gm , copy .deepcopy (example_inputs ), CAPTURE_CONFIG
64- ).to_edge (EDGE_COMPILE_CONFIG )
61+ exported_program = to_edge (
62+ export (quantized_gm , copy .deepcopy (example_inputs )),
63+ compile_config = EDGE_COMPILE_CONFIG ,
64+ )
6565
66- lowered_export_program = to_backend (
67- exported_program .exported_program ,
66+ lowered_export_program = exported_program .to_backend (
6867 ExamplePartitioner (),
6968 )
7069
7170 print ("After lowering to qnn backend: " )
72- lowered_export_program .graph .print_tabular ()
71+ lowered_export_program .exported_program (). graph .print_tabular ()
7372
7473 def test_delegate_mobilenet_v2 (self ):
7574 model = mobilenet_v2 (num_classes = 3 )
7675 model .eval ()
7776 example_inputs = (torch .rand (1 , 3 , 320 , 240 ),)
7877
79- CAPTURE_CONFIG = exir .CaptureConfig (enable_aot = True )
8078 EDGE_COMPILE_CONFIG = exir .EdgeCompileConfig (
8179 _check_ir_validity = False ,
8280 )
@@ -91,20 +89,22 @@ def test_delegate_mobilenet_v2(self):
9189 m = convert_pt2e (m )
9290
9391 quantized_gm = m
94- exported_program = exir .capture (
95- quantized_gm , copy .deepcopy (example_inputs ), CAPTURE_CONFIG
96- ).to_edge (EDGE_COMPILE_CONFIG )
92+ exported_program = to_edge (
93+ export (quantized_gm , copy .deepcopy (example_inputs )),
94+ compile_config = EDGE_COMPILE_CONFIG ,
95+ )
9796
98- lowered_export_program = to_backend (
99- exported_program .transform (DuplicateDequantNodePass ()).exported_program ,
97+ lowered_export_program = exported_program .transform (
98+ [DuplicateDequantNodePass ()]
99+ ).to_backend (
100100 ExamplePartitioner (),
101101 )
102102
103- lowered_export_program .graph .print_tabular ()
103+ lowered_export_program .exported_program (). graph .print_tabular ()
104104
105105 call_deleage_node = [
106106 node
107- for node in lowered_export_program .graph .nodes
107+ for node in lowered_export_program .exported_program (). graph .nodes
108108 if node .target == executorch_call_delegate
109109 ]
110110 self .assertEqual (len (call_deleage_node ), 1 )
0 commit comments