forked from apache/tvm
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_frontend_stablehlo.py
More file actions
387 lines (305 loc) · 10.7 KB
/
test_frontend_stablehlo.py
File metadata and controls
387 lines (305 loc) · 10.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=c-extension-no-member
import functools
from typing import Union, Tuple, List
import pytest
import numpy as np
import tvm
import tvm.testing
from tvm import relax
from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R
from tvm.relax.frontend.stablehlo import from_stablehlo
def generate_np_inputs(
input_shapes: Union[Tuple, List[Tuple]], dtype: str = "float32"
) -> Union[np.ndarray, List[np.ndarray]]:
"""Generate numpy data as the inputs of model
Parameters
----------
input_shapes: Union[Tuple, List[Tuple]]
shapes for inputs
dtype: str
the data type of inputs
Results
-------
out: List[np.ndarray]
numpy input data
"""
if not isinstance(input_shapes[0], (list, tuple)):
return [np.random.uniform(size=input_shapes).astype(dtype)]
out = []
for input_shape in input_shapes:
out.append(np.random.uniform(size=input_shape).astype(dtype))
return out
def np2jnp(inputs_np: Union[np.ndarray, List[np.ndarray]]):
"""Convert data from numpy to jax.numpy
Parameters
----------
inputs_np: Union[np.ndarray, List[np.ndarray]]
numpy input data
Results
-------
out: Union[jnp.ndarray, List[jnp.ndarray]]
jax numpy data
"""
import jax.numpy as jnp
# Use jnp.asarray to avoid unnecessary memory copies
inputs_jnp = []
if isinstance(inputs_np, (tuple, list)):
for input_np in inputs_np:
inputs_jnp.append(jnp.asarray(input_np))
return inputs_jnp
return jnp.asarray(inputs_np)
def check_correctness(
jax_jit_mod,
input_shapes: Union[Tuple, List[Tuple]],
dtype: str = "float32",
) -> None:
"""Run a jax model and the translated TVM IRModule,
verify the inference accuracy.
Parameters
----------
jax_jit_mod: jaxlib.xla_extension.CompiledFunction
The input jax jitted model
input_shapes: Union[Tuple, List[Tuple]]
shapes for inputs
dtype: str
the data type of inputs
"""
# Generate numpy inputs
inputs_np = generate_np_inputs(input_shapes, dtype)
# Get the jax numpy data
inputs_jnp = np2jnp(inputs_np)
# lower the jitted function to StableHLO
lowered = jax_jit_mod.lower(*inputs_np)
# lowered.as_text(dialect="stablehlo") generates text format
# compiler_ir generates the related jaxlib.mlir.Module
stablehlo_module = lowered.compiler_ir(dialect="stablehlo")
# Convert the StableHLO IR to Relax
ir_mod = from_stablehlo(stablehlo_module)
# Run the jax jitted model with the input jax numpy data
jax_output = jax_jit_mod(*inputs_jnp)
# TODO (yongwww): support multiple targets,
# "llvm" should be good for this check
target = tvm.target.Target("llvm", host="llvm")
# Compile and run
ex = relax.build(ir_mod, target)
vm = relax.VirtualMachine(ex, tvm.cpu())
vm.set_input("main", *inputs_np)
vm.invoke_stateful("main")
tvm_output = vm.get_outputs("main")
# Single ouput
if isinstance(tvm_output, tvm.nd.NDArray):
tvm.testing.assert_allclose(tvm_output.numpy(), jax_output, rtol=1e-5, atol=1e-5)
return
# Multiple ouputs
assert len(tvm_output) == len(jax_output), "numbers of outputs mismatch"
for tvm_out, jax_out in zip(tvm_output, jax_output):
tvm.testing.assert_allclose(tvm_out.numpy(), jax_out, rtol=1e-5, atol=1e-5)
def get_vm_res(
ir_mod: tvm.IRModule, weights: Union[np.ndarray, List[np.ndarray]]
) -> Union[tvm.nd.NDArray, List[tvm.nd.NDArray]]:
"""Compile and run an ir_module on Relax VM
Parameters
----------
ir_mod: tvm.IRModule
input ir module
weights: Union[np.ndarray, List[np.ndarray]]
input weights
Results
-------
out: Union[tvm.nd.NDArray, List[tvm.nd.NDArray]]
inference result
"""
target = tvm.target.Target("llvm", host="llvm")
# Compile and run
ex = relax.build(ir_mod, target)
vm = relax.VirtualMachine(ex, tvm.cpu())
vm.set_input("main", *weights)
vm.invoke_stateful("main")
tvm_output = vm.get_outputs("main")
return tvm_output
@tvm.testing.requires_gpu
def test_add_dynamic():
add_dyn = """
func.func @test(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
%1 = stablehlo.add %arg0, %arg1 : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
func.return %1 : tensor<?x?xf32>
}
"""
mod = from_stablehlo(add_dyn)
@I.ir_module
class Expected:
@R.function
def main(
arg0: R.Tensor(("n_0", "n_1"), dtype="float32"),
arg1: R.Tensor(("n_2", "n_3"), dtype="float32"),
) -> R.Tensor(dtype="float32", ndim=2):
n_0 = T.int64()
n_1 = T.int64()
n_2 = T.int64()
n_3 = T.int64()
with R.dataflow():
lv: R.Tensor(dtype="float32", ndim=2) = R.add(arg0, arg1)
gv: R.Tensor(dtype="float32", ndim=2) = lv
R.output(gv)
return gv
tvm.ir.assert_structural_equal(mod, Expected)
@tvm.testing.requires_gpu
@pytest.mark.skip(
reason="jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed."
)
# TODO(mshr-h): may be fixed by upgrading jax to >=0.4.33
def test_unary():
import jax
def _rsqrt(x):
return jax.lax.rsqrt(x)
def _sqrt(x):
return jax.lax.sqrt(x)
def _sin(x):
return jax.lax.sin(x)
def _sinh(x):
return jax.lax.sinh(x)
def _cos(x):
return jax.lax.cos(x)
def _cosh(x):
return jax.lax.cos(x)
def _exp(x):
return jax.lax.exp(x)
def _round(x):
return jax.lax.round(x)
input_shapes = (2, 3, 4)
for fn in [_rsqrt, _sqrt, _sin, _cos, _cosh, _exp, _round]:
check_correctness(jax.jit(fn), input_shapes)
@tvm.testing.requires_gpu
@pytest.mark.skip(
reason="jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed."
)
# TODO(mshr-h): may be fixed by upgrading jax to >=0.4.33
def test_binary():
import jax
def fn(x, y):
r1 = x + y
r2 = r1 * r1
r3 = r2 / r1
r = r2 - r3
return r
input_shape = (1, 2, 3)
input_shapes = (input_shape, input_shape)
# jit the function
jit_fn = jax.jit(fn)
# verify inference accuracy
check_correctness(jit_fn, input_shapes)
@tvm.testing.requires_gpu
@pytest.mark.skip(
reason="jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed."
)
# TODO(mshr-h): may be fixed by upgrading jax to >=0.4.33
def test_const():
import jax
def fn(x):
return x + 1
check_correctness(jax.jit(fn), (2,))
@tvm.testing.requires_gpu
@pytest.mark.skip(
reason="jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed."
)
# TODO(mshr-h): may be fixed by upgrading jax to >=0.4.33
def test_maximum():
import jax
import jax.numpy as jnp
def fn(x, y):
return jnp.maximum(x, y)
check_correctness(jax.jit(fn), ((2, 3), (2, 3)))
@tvm.testing.requires_gpu
@pytest.mark.skip(
reason="jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed."
)
# TODO(mshr-h): may be fixed by upgrading jax to >=0.4.33
def test_minimum():
import jax
import jax.numpy as jnp
def fn(x, y):
return jnp.minimum(x, y)
check_correctness(jax.jit(fn), ((2, 3), (2, 3)))
@tvm.testing.requires_gpu
@pytest.mark.skip(
reason="jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed."
)
# TODO(mshr-h): may be fixed by upgrading jax to >=0.4.33
def test_reduce():
import jax
import jax.numpy as jnp
def fn(x):
return jnp.mean(x, axis=(1, 2))
check_correctness(jax.jit(fn), (2, 3, 4, 5))
@tvm.testing.requires_gpu
@pytest.mark.skip(
reason="jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed."
)
# TODO(mshr-h): may be fixed by upgrading jax to >=0.4.33
def test_reduce_window():
import jax
from flax import linen as nn
def fn(x):
return nn.max_pool(x, (3, 3), strides=(2, 2), padding="SAME")
check_correctness(jax.jit(fn), (2, 3, 4))
@tvm.testing.requires_gpu
@pytest.mark.skip(
reason="jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed."
)
# TODO(mshr-h): may be fixed by upgrading jax to >=0.4.33
def test_dot_general():
import jax
def fn(x, y):
return jax.lax.dot_general(x, y, (([1], [0]), ([], [])))
input_shapes = ((1, 512), (512, 2))
check_correctness(jax.jit(fn), input_shapes)
@tvm.testing.requires_gpu
@pytest.mark.skip(
reason="jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed."
)
# TODO(yongwww): fix flaky error of "invalid device ordinal"
def test_conv():
import jax
from flax import linen as nn
import jax.random as jrandom
conv = nn.Conv(64, (7, 7), (2, 2), padding=[(3, 3), (3, 3)], name="conv_init")
input_shape = (7, 7, 5, 64)
input_np = generate_np_inputs(input_shape)[0]
input_jnp = np2jnp(input_np)
# initialize the conv
weights = conv.init(jrandom.PRNGKey(0), input_jnp)
# get jax inference output
jax_output = conv.apply(weights, input_jnp)
# assemble numpy data using weights generated above
kernel_np = np.asarray(weights["params"]["kernel"])
bias_np = np.asarray(weights["params"]["bias"])
inputs_np = [bias_np, kernel_np, input_np]
# jit and lower to StableHLO
apply = functools.partial(conv.apply)
stablehlo_module = jax.jit(apply).lower(weights, input_jnp).compiler_ir(dialect="stablehlo")
# convert in Relax
ir_mod = from_stablehlo(stablehlo_module)
# compile and run
tvm_output = get_vm_res(ir_mod, inputs_np)
# verify accuracy
tvm.testing.assert_allclose(tvm_output.numpy(), jax_output, rtol=1e-5, atol=1e-5)
if __name__ == "__main__":
tvm.testing.main()