Skip to content

Commit 793307e

Browse files
authored
Import & Cache Mechanism (#26)
* Import \& Cache Mechanism * unused * use None as default cache_dir * invalid name + more layout * fix bert * remove
1 parent 3c6882a commit 793307e

2 files changed

Lines changed: 240 additions & 0 deletions

File tree

Lines changed: 195 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
import multiprocessing
2+
import os
3+
import pickle
4+
from typing import Dict, List, Optional, Tuple
5+
6+
import tvm
7+
import tvm.relay.testing
8+
from tvm import relay
9+
from tvm.ir import IRModule
10+
from tvm.runtime import NDArray, load_param_dict, save_param_dict
11+
12+
SUPPORTED = [
13+
# TorchVision
14+
"resnet_18",
15+
"resnet_50",
16+
"mobilenet_v2",
17+
"mobilenet_v3",
18+
"wide_resnet_50",
19+
"resnext_50",
20+
"resnet3d_18",
21+
"inception_v3",
22+
"densenet_121",
23+
"vgg_16",
24+
# Transformer
25+
"bert_tiny",
26+
"bert_base",
27+
"bert_medium",
28+
"bert_large",
29+
# Relay testing
30+
"dcgan",
31+
]
32+
33+
34+
def _get_network(
35+
args: Tuple[str, List[int]]
36+
) -> Tuple[IRModule, bytearray, Tuple[str, List[int], str]]:
37+
name: str
38+
input_shape: List[int]
39+
name, input_shape = args
40+
41+
mod: IRModule
42+
43+
if name in [
44+
"resnet_18",
45+
"resnet_50",
46+
"wide_resnet_50",
47+
"resnext_50",
48+
"mobilenet_v2",
49+
"mobilenet_v3",
50+
"inception_v3",
51+
"densenet_121",
52+
"resnet3d_18",
53+
"vgg_16",
54+
]:
55+
# torchvision>=0.9.0
56+
import torch # type: ignore
57+
import torchvision.models as models # type: ignore
58+
59+
if name in ["resnet_18", "resnet_50"]:
60+
model = getattr(models, name.replace("_", ""))(pretrained=False)
61+
elif name == "wide_resnet_50":
62+
model = getattr(models, "wide_resnet50_2")(pretrained=False)
63+
elif name == "resnext_50":
64+
model = getattr(models, "resnext50_32x4d")(pretrained=False)
65+
elif name == "mobilenet_v2":
66+
model = getattr(models, name)(pretrained=False)
67+
elif name == "mobilenet_v3":
68+
model = getattr(models, name + "_large")(pretrained=False)
69+
elif name == "inception_v3":
70+
model = getattr(models, name)(pretrained=False, aux_logits=False)
71+
elif name == "densenet_121":
72+
model = getattr(models, name.replace("_", ""))(pretrained=False)
73+
elif name == "resnet3d_18":
74+
model = models.video.r3d_18(pretrained=False)
75+
elif name == "vgg_16":
76+
model = getattr(models, name.replace("_", ""))(pretrained=False)
77+
78+
dtype = "float32"
79+
input_data = torch.randn(input_shape).type(
80+
{
81+
"float32": torch.float32,
82+
}[dtype]
83+
)
84+
scripted_model = torch.jit.trace(model, input_data).eval()
85+
input_name = "input0"
86+
shape_list = [(input_name, input_shape)]
87+
mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)
88+
with tvm.transform.PassContext(opt_level=3):
89+
mod = tvm.transform.Sequential(
90+
[
91+
relay.transform.RemoveUnusedFunctions(),
92+
relay.transform.ConvertLayout(
93+
{
94+
"nn.conv2d": ["NHWC", "default"],
95+
"nn.conv3d": ["NDHWC", "default"],
96+
"nn.max_pool2d": ["NHWC", "default"],
97+
"nn.avg_pool2d": ["NHWC", "default"],
98+
}
99+
),
100+
]
101+
)(mod)
102+
inputs = (input_name, input_shape, dtype)
103+
elif name in ["bert_tiny", "bert_base", "bert_medium", "bert_large"]:
104+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
105+
# pip3 install transformers==3.5 torch==1.7
106+
import torch # type: ignore
107+
import transformers # type: ignore
108+
109+
config_dict = {
110+
"bert_tiny": transformers.BertConfig(
111+
num_hidden_layers=6,
112+
hidden_size=512,
113+
intermediate_size=2048,
114+
num_attention_heads=8,
115+
return_dict=False,
116+
),
117+
"bert_base": transformers.BertConfig(
118+
num_hidden_layers=12,
119+
hidden_size=768,
120+
intermediate_size=3072,
121+
num_attention_heads=12,
122+
return_dict=False,
123+
),
124+
"bert_medium": transformers.BertConfig(
125+
num_hidden_layers=12,
126+
hidden_size=1024,
127+
intermediate_size=4096,
128+
num_attention_heads=16,
129+
return_dict=False,
130+
),
131+
"bert_large": transformers.BertConfig(
132+
num_hidden_layers=24,
133+
hidden_size=1024,
134+
intermediate_size=4096,
135+
num_attention_heads=16,
136+
return_dict=False,
137+
),
138+
}
139+
configuration = config_dict[name]
140+
model = transformers.BertModel(configuration)
141+
input_name = "input_ids"
142+
input_dtype = "int64"
143+
A = torch.randint(10000, input_shape)
144+
model.eval()
145+
scripted_model = torch.jit.trace(model, [A], strict=False)
146+
input_name = "input_ids"
147+
shape_list = [(input_name, input_shape)]
148+
mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)
149+
mod = relay.transform.FastMath()(mod)
150+
mod = relay.transform.CombineParallelBatchMatmul()(mod)
151+
inputs = (input_name, input_shape, input_dtype)
152+
elif name == "dcgan":
153+
output_shape = input_shape
154+
batch_size = output_shape[0]
155+
oshape = output_shape[1:]
156+
mod, params = relay.testing.dcgan.get_workload(
157+
batch_size=batch_size,
158+
oshape=oshape,
159+
layout="NHWC",
160+
)
161+
inputs = ("data", [100], "float32")
162+
else:
163+
raise ValueError("Invalid name: " + name)
164+
165+
params_bytearray: bytearray = save_param_dict(params)
166+
return mod, params_bytearray, inputs
167+
168+
169+
def get_network(
170+
name: str,
171+
input_shape: List[int],
172+
cache_dir: Optional[str] = None,
173+
) -> Tuple[IRModule, Dict[str, NDArray], Tuple[str, List[int], str]]:
174+
mod: IRModule
175+
params_bytearray: bytearray
176+
params: Dict[str, NDArray]
177+
inputs: Tuple[str, List[int], str]
178+
keyword = f'{name}-{",".join(str(i) for i in input_shape)}.json'
179+
if cache_dir is not None:
180+
path = os.path.join(cache_dir, keyword)
181+
if os.path.exists(path):
182+
print(f"Load cached network file: {path}")
183+
with open(path, "rb") as i_f:
184+
mod, params_bytearray, inputs = pickle.load(i_f)
185+
params = load_param_dict(params_bytearray)
186+
return mod, params, inputs
187+
with multiprocessing.Pool(processes=1) as pool:
188+
result = pool.map(_get_network, [(name, input_shape)])
189+
((mod, params_bytearray, inputs),) = result
190+
params = load_param_dict(params_bytearray)
191+
if cache_dir is not None:
192+
path = os.path.join(cache_dir, keyword)
193+
with open(path, "wb") as o_f:
194+
pickle.dump((mod, params_bytearray, inputs), o_f)
195+
return mod, params, inputs
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
from tvm.meta_schedule.testing.e2e import get_network
2+
3+
4+
def test_import():
5+
network_keys = []
6+
for name in [
7+
"resnet_18",
8+
"resnet_50",
9+
"mobilenet_v2",
10+
"mobilenet_v3",
11+
"wide_resnet_50",
12+
"resnext_50",
13+
"densenet_121",
14+
]:
15+
for batch_size in [1, 4, 8]:
16+
for image_size in [224, 240, 256]:
17+
network_keys.append((name, [batch_size, 3, image_size, image_size]))
18+
# inception-v3
19+
for name in ["inception_v3"]:
20+
for batch_size in [1, 2, 4]:
21+
for image_size in [299]:
22+
network_keys.append((name, [batch_size, 3, image_size, image_size]))
23+
# resnet3d
24+
for name in ["resnet3d_18"]:
25+
for batch_size in [1, 2, 4]:
26+
for image_size in [112, 128, 144]:
27+
network_keys.append((name, [batch_size, 3, image_size, image_size, 16]))
28+
# bert
29+
for name in ["bert_tiny", "bert_base", "bert_medium", "bert_large"]:
30+
for batch_size in [1, 2, 4]:
31+
for seq_length in [64, 128, 256]:
32+
network_keys.append((name, [batch_size, seq_length]))
33+
# dcgan
34+
for name in ["dcgan"]:
35+
for batch_size in [1, 4, 8]:
36+
for image_size in [64]:
37+
network_keys.append((name, [batch_size, 3, image_size, image_size]))
38+
39+
for i, (name, input_shape) in enumerate(network_keys, 1):
40+
print(f"[{i} / {len(network_keys)}] {name}, input_shape = {input_shape}")
41+
get_network(name, input_shape, cache_dir="/tmp/relay/")
42+
43+
44+
if __name__ == "__main__":
45+
test_import()

0 commit comments

Comments
 (0)