forked from pytorch/executorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathaot_compiler.py
More file actions
148 lines (124 loc) · 4.89 KB
/
aot_compiler.py
File metadata and controls
148 lines (124 loc) · 4.89 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-unsafe
# Example script for exporting simple models to flatbuffer
import argparse
import logging
import torch
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import (
EdgeCompileConfig,
ExecutorchBackendConfig,
to_edge_transform_and_lower,
)
from executorch.extension.export_util.utils import save_pte_program
from ..models import MODEL_NAME_TO_MODEL
from ..models.model_factory import EagerModelFactory
from . import MODEL_NAME_TO_OPTIONS
from .quantization.utils import quantize
FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=FORMAT)


def _parse_args() -> argparse.Namespace:
    """Build the command-line parser and return the parsed arguments.

    Exits through ``parser.error`` (usage message, status 2) when
    ``--model_name`` is not a key of ``MODEL_NAME_TO_MODEL``, instead of
    failing later with a bare ``KeyError``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-m",
        "--model_name",
        required=True,
        help=f"Model name. Valid ones: {list(MODEL_NAME_TO_MODEL.keys())}",
    )
    parser.add_argument(
        "-q",
        "--quantize",
        action="store_true",
        required=False,
        default=False,
        help="Produce an 8-bit quantized model",
    )
    parser.add_argument(
        "-d",
        "--delegate",
        action="store_true",
        required=False,
        # NOTE(review): store_true combined with default=True means this flag
        # can never be False, so the script always lowers to XNNPACK. Kept
        # as-is for CLI compatibility.
        default=True,
        help="Produce an XNNPACK delegated model",
    )
    parser.add_argument(
        "-r",
        "--etrecord",
        required=False,
        default="",
        help="Generate and save an ETRecord to the given file location",
    )
    parser.add_argument(
        "-t",
        "--test_after_export",
        action="store_true",
        required=False,
        default=False,
        help="Test the pte with pybindings",
    )
    parser.add_argument("-o", "--output_dir", default=".", help="output directory")
    args = parser.parse_args()

    # Validate up front so an unknown name yields a readable CLI error rather
    # than a KeyError from the MODEL_NAME_TO_MODEL lookup further down.
    if args.model_name not in MODEL_NAME_TO_MODEL:
        parser.error(
            f"Model {args.model_name} is not a valid name. "
            f"Available models are {list(MODEL_NAME_TO_MODEL.keys())}."
        )
    return args


def _run_pte_with_pybindings(exec_prog, example_inputs) -> None:
    """Smoke-test the serialized program by running ``forward`` once via pybindings.

    Only checks that the program loads from ``exec_prog.buffer`` and executes
    on ``example_inputs``; outputs are not compared against the eager model.
    """
    logging.info("Testing the pte with pybind")
    from executorch.extension.pybindings.portable_lib import (
        _load_for_executorch_from_buffer,
    )

    # Import custom ops. This requires portable_lib to be loaded first.
    from executorch.extension.llm.custom_ops import (  # noqa: F401, F403
        custom_ops,
    )  # usort: skip

    # Import quantized ops. This requires portable_lib to be loaded first.
    from executorch.kernels import quantized  # usort: skip # noqa: F401, F403
    from torch.utils._pytree import tree_flatten

    m = _load_for_executorch_from_buffer(exec_prog.buffer)
    logging.info("Successfully loaded the model")
    # run_method takes a flat list of inputs, so unpack any nested structure.
    flattened = tree_flatten(example_inputs)[0]
    m.run_method("forward", flattened)
    logging.info("Successfully ran the model")


def main() -> None:
    """Export a model, optionally quantize it, lower it to XNNPACK and save a .pte file.

    Raises:
        NotImplementedError: quantization was requested without delegation.
        RuntimeError: quantization was requested for a model that has no
            entry in ``MODEL_NAME_TO_OPTIONS``.
    """
    args = _parse_args()

    if not args.delegate and args.quantize:
        raise NotImplementedError(
            "T161880157: Quantization-only without delegation is not supported yet"
        )

    # Quantization options are looked up only when quantizing. A model may be
    # exportable (present in MODEL_NAME_TO_MODEL) without being quantizable
    # (absent from MODEL_NAME_TO_OPTIONS), and fp32 export must still work.
    quant_type = None
    if args.quantize:
        if args.model_name not in MODEL_NAME_TO_OPTIONS:
            raise RuntimeError(
                f"Model {args.model_name} is not quantizable right now, "
                "please contact executorch team if you want to learn why or how to support "
                "quantization for the requested model. "
                f"Quantizable models are {list(MODEL_NAME_TO_OPTIONS.keys())}."
            )
        quant_type = MODEL_NAME_TO_OPTIONS[args.model_name].quantization

    model, example_inputs, _, _ = EagerModelFactory.create_model(
        *MODEL_NAME_TO_MODEL[args.model_name]
    )
    model = model.eval()

    # Export once to obtain a graph module that the quantizer can operate on.
    ep = torch.export.export(model, example_inputs, strict=False)
    model = ep.module()

    if args.quantize:
        logging.info("Quantizing Model...")
        # TODO(T165162973): This pass shall eventually be folded into quantizer
        model = quantize(model, example_inputs, quant_type)
        # Quantization rewrote the module, so it must be exported again.
        ep = torch.export.export(model, example_inputs, strict=False)

    edge = to_edge_transform_and_lower(
        ep,
        partitioner=[XnnpackPartitioner()],
        compile_config=EdgeCompileConfig(
            # IR validity checking is disabled for quantized graphs, as in the
            # original script.
            _check_ir_validity=not args.quantize,
            _skip_dim_order=True,  # TODO(T182187531): enable dim order in xnnpack
        ),
        # --etrecord carries a path; generation is wanted whenever it is set.
        generate_etrecord=bool(args.etrecord),
    )
    logging.info("Exported and lowered graph:\n%s", edge.exported_program().graph)
    exec_prog = edge.to_executorch(
        config=ExecutorchBackendConfig(extract_delegate_segments=False)
    )

    if args.etrecord:
        exec_prog.get_etrecord().save(args.etrecord)
        logging.info("Saved ETRecord to %s", args.etrecord)

    quant_tag = "q8" if args.quantize else "fp32"
    model_name = f"{args.model_name}_xnnpack_{quant_tag}"
    save_pte_program(exec_prog, model_name, args.output_dir)

    if args.test_after_export:
        _run_pte_with_pybindings(exec_prog, example_inputs)


if __name__ == "__main__":
    main()