-
Notifications
You must be signed in to change notification settings - Fork 974
Expand file tree
/
Copy pathexport_example.py
More file actions
65 lines (46 loc) · 1.98 KB
/
export_example.py
File metadata and controls
65 lines (46 loc) · 1.98 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# Example script for exporting simple models to flatbuffer
import logging
from .meta_registrations import * # noqa
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from ...portable.utils import save_pte_program
from .compiler import export_to_edge
from .quantizer import (
QuantFusion,
ReplacePT2DequantWithXtensaDequant,
ReplacePT2QuantWithXtensaQuant,
XtensaBaseQuantizer,
)
# Layout of every log record: "[LEVEL timestamp file:line] message".
FORMAT = (
    "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] "
    "%(message)s"
)

# Configure the root logger once at import time so the export steps below
# report at INFO level using the layout above.
logging.basicConfig(format=FORMAT, level=logging.INFO)
def export_xtensa_model(model, example_inputs, file_name: str = "XtensaDemoModel") -> None:
    """Quantize ``model``, lower it through the edge dialect, and save the result.

    The steps, in order:

    1. Capture the model with ``capture_pre_autograd_graph``.
    2. Insert observers with ``prepare_pt2e`` using an ``XtensaBaseQuantizer``,
       then run the prepared model once on ``example_inputs``.
    3. Convert with ``convert_pt2e`` and apply ``QuantFusion`` using the
       patterns collected from the quantizer's sub-quantizers.
    4. Build the edge program with ``export_to_edge`` and replace the PT2
       quant/dequant ops with their Xtensa counterparts.
    5. Call ``to_executorch()``, log the final graph module, and hand the
       program to ``save_pte_program``.

    Args:
        model: The model to export; passed as-is to
            ``capture_pre_autograd_graph``.
        example_inputs: Positional inputs for ``model``. They are used for
            graph capture, unpacked into one forward call of the prepared
            model, and forwarded to ``export_to_edge``.
        file_name: Name passed to ``save_pte_program`` for the output file.
            Defaults to ``"XtensaDemoModel"``, the value that used to be
            hard-coded, so existing callers are unaffected.

    Returns:
        None. The function's products are the log message and whatever
        ``save_pte_program`` writes.
    """
    # Quantizer
    quantizer = XtensaBaseQuantizer()

    # Export
    model_exp = capture_pre_autograd_graph(model, example_inputs)

    # Prepare
    prepared_model = prepare_pt2e(model_exp, quantizer)
    # One forward pass over the example inputs; the return value is not needed.
    # NOTE(review): presumably this is the PT2E calibration step -- confirm.
    prepared_model(*example_inputs)

    # Convert
    converted_model = convert_pt2e(prepared_model)

    # Fuse using the patterns gathered from each sub-quantizer.
    # NOTE(review): the result of QuantFusion is discarded, so it presumably
    # rewrites converted_model in place -- confirm against QuantFusion.
    # pyre-fixme[16]: Pyre doesn't get that XtensaQuantizer has a patterns attribute
    patterns = [q.pattern for q in quantizer.quantizers]
    QuantFusion(patterns)(converted_model)

    # Get edge program (note: the name will change to export_to_xtensa in future PRs)
    edge_prog_manager = export_to_edge(converted_model, example_inputs, pt2_quant=True)

    # Run a couple required passes for quant/dequant ops
    xtensa_prog_manager = edge_prog_manager.transform(
        [ReplacePT2QuantWithXtensaQuant(), ReplacePT2DequantWithXtensaDequant()],
        check_ir_validity=False,
    )

    exec_prog = xtensa_prog_manager.to_executorch()

    # Lazy %-style arguments: the graph module is only rendered to text when
    # the INFO level is actually enabled. The emitted message is unchanged.
    logging.info(
        "Final exported graph module:\n%s",
        exec_prog.exported_program().graph_module,
    )

    # Save the program under the requested name (XtensaDemoModel.pte by default)
    save_pte_program(exec_prog, file_name)