Skip to content

Commit 1982101

Browse files
Lunderbergmehrdadh
authored andcommitted
[MicroTVM][PyTest] Explicitly skip MicroTVM unittests. (apache#9335)
* [MicroTVM][PyTest] Explicitly skip MicroTVM unittests. Refactor unit tests so they will show as skipped if `USE_MICRO=OFF`. * Updates following PR review. - Updated to avoid name shadowing of BaseTestHandler - Updated test_micro_transport to use fixture for setup. Ended up needing to refactor to use pytest instead of unittest, split up test functionality during refactor.
1 parent 5dc6452 commit 1982101

4 files changed

Lines changed: 253 additions & 173 deletions

File tree

tests/python/conftest.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,3 @@
3737
# collect_ignore.append("unittest/test_auto_scheduler_measure.py") # exception ignored
3838

3939
collect_ignore.append("unittest/test_tir_intrin.py")
40-
41-
if tvm.support.libinfo().get("USE_MICRO", "OFF") != "ON":
42-
collect_ignore.append("unittest/test_micro_transport.py")

tests/python/unittest/test_crt.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,6 @@
3535
from tvm.topi.utils import get_const_tuple
3636
from tvm.topi.testing import conv2d_nchw_python
3737

38-
pytest.importorskip("tvm.micro.testing")
39-
from tvm.micro.testing import check_tune_log
40-
4138
BUILD = True
4239
DEBUG = False
4340

@@ -222,6 +219,7 @@ def test_platform_timer():
222219
def test_autotune():
223220
"""Verify that autotune works with micro."""
224221
import tvm.relay as relay
222+
from tvm.micro.testing import check_tune_log
225223

226224
data = relay.var("data", relay.TensorType((1, 3, 64, 64), "float32"))
227225
weight = relay.var("weight", relay.TensorType((8, 3, 5, 5), "float32"))

tests/python/unittest/test_micro_project_api.py

Lines changed: 92 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -26,45 +26,52 @@
2626

2727
import tvm
2828

29-
pytest.importorskip("tvm.micro")
30-
from tvm.micro import project_api
3129

30+
# Implementing as a fixture so that the tvm.micro import doesn't occur
31+
# until fixture setup time. This is necessary for pytest's collection
32+
# phase to work when USE_MICRO=OFF, while still explicitly listing the
33+
# tests as skipped.
34+
@tvm.testing.fixture
35+
def BaseTestHandler():
36+
from tvm.micro import project_api
37+
38+
class BaseTestHandler_Impl(project_api.server.ProjectAPIHandler):
39+
40+
DEFAULT_TEST_SERVER_INFO = project_api.server.ServerInfo(
41+
platform_name="platform_name",
42+
is_template=True,
43+
model_library_format_path="./model-library-format-path.sh",
44+
project_options=[
45+
project_api.server.ProjectOption(name="foo", help="Option foo"),
46+
project_api.server.ProjectOption(name="bar", choices=["qux"], help="Option bar"),
47+
],
48+
)
3249

33-
class BaseTestHandler(project_api.server.ProjectAPIHandler):
34-
35-
DEFAULT_TEST_SERVER_INFO = project_api.server.ServerInfo(
36-
platform_name="platform_name",
37-
is_template=True,
38-
model_library_format_path="./model-library-format-path.sh",
39-
project_options=[
40-
project_api.server.ProjectOption(name="foo", help="Option foo"),
41-
project_api.server.ProjectOption(name="bar", choices=["qux"], help="Option bar"),
42-
],
43-
)
50+
def server_info_query(self, tvm_version):
51+
return self.DEFAULT_TEST_SERVER_INFO
4452

45-
def server_info_query(self, tvm_version):
46-
return self.DEFAULT_TEST_SERVER_INFO
53+
def generate_project(self, model_library_format_path, crt_path, project_path, options):
54+
assert False, "generate_project is not implemented for this test"
4755

48-
def generate_project(self, model_library_format_path, crt_path, project_path, options):
49-
assert False, "generate_project is not implemented for this test"
56+
def build(self, options):
57+
assert False, "build is not implemented for this test"
5058

51-
def build(self, options):
52-
assert False, "build is not implemented for this test"
59+
def flash(self, options):
60+
assert False, "flash is not implemented for this test"
5361

54-
def flash(self, options):
55-
assert False, "flash is not implemented for this test"
62+
def open_transport(self, options):
63+
assert False, "open_transport is not implemented for this test"
5664

57-
def open_transport(self, options):
58-
assert False, "open_transport is not implemented for this test"
65+
def close_transport(self, options):
66+
assert False, "open_transport is not implemented for this test"
5967

60-
def close_transport(self, options):
61-
assert False, "open_transport is not implemented for this test"
68+
def read_transport(self, n, timeout_sec):
69+
assert False, "read_transport is not implemented for this test"
6270

63-
def read_transport(self, n, timeout_sec):
64-
assert False, "read_transport is not implemented for this test"
71+
def write_transport(self, data, timeout_sec):
72+
assert False, "write_transport is not implemented for this test"
6573

66-
def write_transport(self, data, timeout_sec):
67-
assert False, "write_transport is not implemented for this test"
74+
return BaseTestHandler_Impl
6875

6976

7077
class Transport:
@@ -100,6 +107,8 @@ def write(self, data):
100107

101108
class ClientServerFixture:
102109
def __init__(self, handler):
110+
from tvm.micro import project_api
111+
103112
self.handler = handler
104113
self.client_to_server = Transport()
105114
self.server_to_client = Transport()
@@ -121,7 +130,8 @@ def _process_server_request(self):
121130
), "Server failed to process request"
122131

123132

124-
def test_server_info_query():
133+
@tvm.testing.requires_micro
134+
def test_server_info_query(BaseTestHandler):
125135
fixture = ClientServerFixture(BaseTestHandler())
126136

127137
# Examine reply explicitly because these are the defaults for all derivative test cases.
@@ -136,7 +146,10 @@ def test_server_info_query():
136146
]
137147

138148

139-
def test_server_info_query_wrong_tvm_version():
149+
@tvm.testing.requires_micro
150+
def test_server_info_query_wrong_tvm_version(BaseTestHandler):
151+
from tvm.micro import project_api
152+
140153
def server_info_query(tvm_version):
141154
raise project_api.server.UnsupportedTVMVersionError()
142155

@@ -148,7 +161,10 @@ def server_info_query(tvm_version):
148161
assert "UnsupportedTVMVersionError" in str(exc_info.value)
149162

150163

151-
def test_server_info_query_wrong_protocol_version():
164+
@tvm.testing.requires_micro
165+
def test_server_info_query_wrong_protocol_version(BaseTestHandler):
166+
from tvm.micro import project_api
167+
152168
ServerInfoProtocol = collections.namedtuple(
153169
"ServerInfoProtocol", list(project_api.server.ServerInfo._fields) + ["protocol_version"]
154170
)
@@ -166,7 +182,8 @@ def server_info_query(tvm_version):
166182
assert "microTVM API Server supports protocol version 0; want 1" in str(exc_info.value)
167183

168184

169-
def test_base_test_handler():
185+
@tvm.testing.requires_micro
186+
def test_base_test_handler(BaseTestHandler):
170187
"""All methods should raise AssertionError on BaseTestHandler."""
171188
fixture = ClientServerFixture(BaseTestHandler())
172189

@@ -180,22 +197,27 @@ def test_base_test_handler():
180197
assert (exc_info.exception) == f"{method} is not implemented for this test"
181198

182199

183-
def test_build():
200+
@tvm.testing.requires_micro
201+
def test_build(BaseTestHandler):
184202
with mock.patch.object(BaseTestHandler, "build", return_value=None) as patch:
185203
fixture = ClientServerFixture(BaseTestHandler())
186204
fixture.client.build(options={"bar": "baz"})
187205

188206
fixture.handler.build.assert_called_once_with(options={"bar": "baz"})
189207

190208

191-
def test_flash():
209+
@tvm.testing.requires_micro
210+
def test_flash(BaseTestHandler):
192211
with mock.patch.object(BaseTestHandler, "flash", return_value=None) as patch:
193212
fixture = ClientServerFixture(BaseTestHandler())
194213
fixture.client.flash(options={"bar": "baz"})
195214
fixture.handler.flash.assert_called_once_with(options={"bar": "baz"})
196215

197216

198-
def test_open_transport():
217+
@tvm.testing.requires_micro
218+
def test_open_transport(BaseTestHandler):
219+
from tvm.micro import project_api
220+
199221
timeouts = project_api.server.TransportTimeouts(
200222
session_start_retry_timeout_sec=1.0,
201223
session_start_timeout_sec=2.0,
@@ -210,14 +232,18 @@ def test_open_transport():
210232
fixture.handler.open_transport.assert_called_once_with({"bar": "baz"})
211233

212234

213-
def test_close_transport():
235+
@tvm.testing.requires_micro
236+
def test_close_transport(BaseTestHandler):
214237
with mock.patch.object(BaseTestHandler, "close_transport", return_value=None) as patch:
215238
fixture = ClientServerFixture(BaseTestHandler())
216239
fixture.client.close_transport()
217240
fixture.handler.close_transport.assert_called_once_with()
218241

219242

220-
def test_read_transport():
243+
@tvm.testing.requires_micro
244+
def test_read_transport(BaseTestHandler):
245+
from tvm.micro import project_api
246+
221247
with mock.patch.object(BaseTestHandler, "read_transport", return_value=b"foo\x1b") as patch:
222248
fixture = ClientServerFixture(BaseTestHandler())
223249
assert fixture.client.read_transport(128, timeout_sec=5.0) == {"data": b"foo\x1b"}
@@ -239,7 +265,10 @@ def test_read_transport():
239265
assert fixture.handler.read_transport.call_count == 3
240266

241267

242-
def test_write_transport():
268+
@tvm.testing.requires_micro
269+
def test_write_transport(BaseTestHandler):
270+
from tvm.micro import project_api
271+
243272
with mock.patch.object(BaseTestHandler, "write_transport", return_value=None) as patch:
244273
fixture = ClientServerFixture(BaseTestHandler())
245274
assert fixture.client.write_transport(b"foo", timeout_sec=5.0) is None
@@ -264,7 +293,10 @@ class ProjectAPITestError(Exception):
264293
"""An error raised in test."""
265294

266295

267-
def test_method_raises_error():
296+
@tvm.testing.requires_micro
297+
def test_method_raises_error(BaseTestHandler):
298+
from tvm.micro import project_api
299+
268300
with mock.patch.object(
269301
BaseTestHandler, "close_transport", side_effect=ProjectAPITestError
270302
) as patch:
@@ -276,7 +308,10 @@ def test_method_raises_error():
276308
assert "ProjectAPITestError" in str(exc_info.value)
277309

278310

279-
def test_method_not_found():
311+
@tvm.testing.requires_micro
312+
def test_method_not_found(BaseTestHandler):
313+
from tvm.micro import project_api
314+
280315
fixture = ClientServerFixture(BaseTestHandler())
281316

282317
with pytest.raises(project_api.server.JSONRPCError) as exc_info:
@@ -285,7 +320,10 @@ def test_method_not_found():
285320
assert exc_info.value.code == project_api.server.ErrorCode.METHOD_NOT_FOUND
286321

287322

288-
def test_extra_param():
323+
@tvm.testing.requires_micro
324+
def test_extra_param(BaseTestHandler):
325+
from tvm.micro import project_api
326+
289327
fixture = ClientServerFixture(BaseTestHandler())
290328

291329
# test one with has_preprocssing and one without
@@ -304,7 +342,10 @@ def test_extra_param():
304342
assert "open_transport: extra parameters: invalid_param_name" in str(exc_info.value)
305343

306344

307-
def test_missing_param():
345+
@tvm.testing.requires_micro
346+
def test_missing_param(BaseTestHandler):
347+
from tvm.micro import project_api
348+
308349
fixture = ClientServerFixture(BaseTestHandler())
309350

310351
# test one with has_preprocssing and one without
@@ -323,7 +364,10 @@ def test_missing_param():
323364
assert "open_transport: parameter options not given" in str(exc_info.value)
324365

325366

326-
def test_incorrect_param_type():
367+
@tvm.testing.requires_micro
368+
def test_incorrect_param_type(BaseTestHandler):
369+
from tvm.micro import project_api
370+
327371
fixture = ClientServerFixture(BaseTestHandler())
328372

329373
# The error message given at the JSON-RPC server level doesn't make sense when preprocessing is
@@ -338,7 +382,10 @@ def test_incorrect_param_type():
338382
)
339383

340384

341-
def test_invalid_request():
385+
@tvm.testing.requires_micro
386+
def test_invalid_request(BaseTestHandler):
387+
from tvm.micro import project_api
388+
342389
fixture = ClientServerFixture(BaseTestHandler())
343390

344391
# Invalid JSON does not get a reply.

0 commit comments

Comments
 (0)