2626
2727import tvm
2828
29- pytest .importorskip ("tvm.micro" )
30- from tvm .micro import project_api
3129
30+ # Implementing as a fixture so that the tvm.micro import doesn't occur
31+ # until fixture setup time. This is necessary for pytest's collection
32+ # phase to work when USE_MICRO=OFF, while still explicitly listing the
33+ # tests as skipped.
34+ @tvm .testing .fixture
35+ def BaseTestHandler ():
36+ from tvm .micro import project_api
37+
38+ class BaseTestHandler_Impl (project_api .server .ProjectAPIHandler ):
39+
40+ DEFAULT_TEST_SERVER_INFO = project_api .server .ServerInfo (
41+ platform_name = "platform_name" ,
42+ is_template = True ,
43+ model_library_format_path = "./model-library-format-path.sh" ,
44+ project_options = [
45+ project_api .server .ProjectOption (name = "foo" , help = "Option foo" ),
46+ project_api .server .ProjectOption (name = "bar" , choices = ["qux" ], help = "Option bar" ),
47+ ],
48+ )
3249
33- class BaseTestHandler (project_api .server .ProjectAPIHandler ):
34-
35- DEFAULT_TEST_SERVER_INFO = project_api .server .ServerInfo (
36- platform_name = "platform_name" ,
37- is_template = True ,
38- model_library_format_path = "./model-library-format-path.sh" ,
39- project_options = [
40- project_api .server .ProjectOption (name = "foo" , help = "Option foo" ),
41- project_api .server .ProjectOption (name = "bar" , choices = ["qux" ], help = "Option bar" ),
42- ],
43- )
50+ def server_info_query (self , tvm_version ):
51+ return self .DEFAULT_TEST_SERVER_INFO
4452
45- def server_info_query (self , tvm_version ):
46- return self . DEFAULT_TEST_SERVER_INFO
53+ def generate_project (self , model_library_format_path , crt_path , project_path , options ):
54+ assert False , "generate_project is not implemented for this test"
4755
48- def generate_project ( self , model_library_format_path , crt_path , project_path , options ):
49- assert False , "generate_project is not implemented for this test"
56+ def build ( self , options ):
57+ assert False , "build is not implemented for this test"
5058
51- def build (self , options ):
52- assert False , "build is not implemented for this test"
59+ def flash (self , options ):
60+ assert False , "flash is not implemented for this test"
5361
54- def flash (self , options ):
55- assert False , "flash is not implemented for this test"
62+ def open_transport (self , options ):
63+ assert False , "open_transport is not implemented for this test"
5664
57- def open_transport (self , options ):
58- assert False , "open_transport is not implemented for this test"
65+ def close_transport (self , options ):
66+ assert False , "open_transport is not implemented for this test"
5967
60- def close_transport (self , options ):
61- assert False , "open_transport is not implemented for this test"
68+ def read_transport (self , n , timeout_sec ):
69+ assert False , "read_transport is not implemented for this test"
6270
63- def read_transport (self , n , timeout_sec ):
64- assert False , "read_transport is not implemented for this test"
71+ def write_transport (self , data , timeout_sec ):
72+ assert False , "write_transport is not implemented for this test"
6573
66- def write_transport (self , data , timeout_sec ):
67- assert False , "write_transport is not implemented for this test"
74+ return BaseTestHandler_Impl
6875
6976
7077class Transport :
@@ -100,6 +107,8 @@ def write(self, data):
100107
101108class ClientServerFixture :
102109 def __init__ (self , handler ):
110+ from tvm .micro import project_api
111+
103112 self .handler = handler
104113 self .client_to_server = Transport ()
105114 self .server_to_client = Transport ()
@@ -121,7 +130,8 @@ def _process_server_request(self):
121130 ), "Server failed to process request"
122131
123132
124- def test_server_info_query ():
133+ @tvm .testing .requires_micro
134+ def test_server_info_query (BaseTestHandler ):
125135 fixture = ClientServerFixture (BaseTestHandler ())
126136
127137 # Examine reply explicitly because these are the defaults for all derivative test cases.
@@ -136,7 +146,10 @@ def test_server_info_query():
136146 ]
137147
138148
139- def test_server_info_query_wrong_tvm_version ():
149+ @tvm .testing .requires_micro
150+ def test_server_info_query_wrong_tvm_version (BaseTestHandler ):
151+ from tvm .micro import project_api
152+
140153 def server_info_query (tvm_version ):
141154 raise project_api .server .UnsupportedTVMVersionError ()
142155
@@ -148,7 +161,10 @@ def server_info_query(tvm_version):
148161 assert "UnsupportedTVMVersionError" in str (exc_info .value )
149162
150163
151- def test_server_info_query_wrong_protocol_version ():
164+ @tvm .testing .requires_micro
165+ def test_server_info_query_wrong_protocol_version (BaseTestHandler ):
166+ from tvm .micro import project_api
167+
152168 ServerInfoProtocol = collections .namedtuple (
153169 "ServerInfoProtocol" , list (project_api .server .ServerInfo ._fields ) + ["protocol_version" ]
154170 )
@@ -166,7 +182,8 @@ def server_info_query(tvm_version):
166182 assert "microTVM API Server supports protocol version 0; want 1" in str (exc_info .value )
167183
168184
169- def test_base_test_handler ():
185+ @tvm .testing .requires_micro
186+ def test_base_test_handler (BaseTestHandler ):
170187 """All methods should raise AssertionError on BaseTestHandler."""
171188 fixture = ClientServerFixture (BaseTestHandler ())
172189
@@ -180,22 +197,27 @@ def test_base_test_handler():
180197 assert (exc_info .exception ) == f"{ method } is not implemented for this test"
181198
182199
183- def test_build ():
200+ @tvm .testing .requires_micro
201+ def test_build (BaseTestHandler ):
184202 with mock .patch .object (BaseTestHandler , "build" , return_value = None ) as patch :
185203 fixture = ClientServerFixture (BaseTestHandler ())
186204 fixture .client .build (options = {"bar" : "baz" })
187205
188206 fixture .handler .build .assert_called_once_with (options = {"bar" : "baz" })
189207
190208
191- def test_flash ():
209+ @tvm .testing .requires_micro
210+ def test_flash (BaseTestHandler ):
192211 with mock .patch .object (BaseTestHandler , "flash" , return_value = None ) as patch :
193212 fixture = ClientServerFixture (BaseTestHandler ())
194213 fixture .client .flash (options = {"bar" : "baz" })
195214 fixture .handler .flash .assert_called_once_with (options = {"bar" : "baz" })
196215
197216
198- def test_open_transport ():
217+ @tvm .testing .requires_micro
218+ def test_open_transport (BaseTestHandler ):
219+ from tvm .micro import project_api
220+
199221 timeouts = project_api .server .TransportTimeouts (
200222 session_start_retry_timeout_sec = 1.0 ,
201223 session_start_timeout_sec = 2.0 ,
@@ -210,14 +232,18 @@ def test_open_transport():
210232 fixture .handler .open_transport .assert_called_once_with ({"bar" : "baz" })
211233
212234
213- def test_close_transport ():
235+ @tvm .testing .requires_micro
236+ def test_close_transport (BaseTestHandler ):
214237 with mock .patch .object (BaseTestHandler , "close_transport" , return_value = None ) as patch :
215238 fixture = ClientServerFixture (BaseTestHandler ())
216239 fixture .client .close_transport ()
217240 fixture .handler .close_transport .assert_called_once_with ()
218241
219242
220- def test_read_transport ():
243+ @tvm .testing .requires_micro
244+ def test_read_transport (BaseTestHandler ):
245+ from tvm .micro import project_api
246+
221247 with mock .patch .object (BaseTestHandler , "read_transport" , return_value = b"foo\x1b " ) as patch :
222248 fixture = ClientServerFixture (BaseTestHandler ())
223249 assert fixture .client .read_transport (128 , timeout_sec = 5.0 ) == {"data" : b"foo\x1b " }
@@ -239,7 +265,10 @@ def test_read_transport():
239265 assert fixture .handler .read_transport .call_count == 3
240266
241267
242- def test_write_transport ():
268+ @tvm .testing .requires_micro
269+ def test_write_transport (BaseTestHandler ):
270+ from tvm .micro import project_api
271+
243272 with mock .patch .object (BaseTestHandler , "write_transport" , return_value = None ) as patch :
244273 fixture = ClientServerFixture (BaseTestHandler ())
245274 assert fixture .client .write_transport (b"foo" , timeout_sec = 5.0 ) is None
@@ -264,7 +293,10 @@ class ProjectAPITestError(Exception):
264293 """An error raised in test."""
265294
266295
267- def test_method_raises_error ():
296+ @tvm .testing .requires_micro
297+ def test_method_raises_error (BaseTestHandler ):
298+ from tvm .micro import project_api
299+
268300 with mock .patch .object (
269301 BaseTestHandler , "close_transport" , side_effect = ProjectAPITestError
270302 ) as patch :
@@ -276,7 +308,10 @@ def test_method_raises_error():
276308 assert "ProjectAPITestError" in str (exc_info .value )
277309
278310
279- def test_method_not_found ():
311+ @tvm .testing .requires_micro
312+ def test_method_not_found (BaseTestHandler ):
313+ from tvm .micro import project_api
314+
280315 fixture = ClientServerFixture (BaseTestHandler ())
281316
282317 with pytest .raises (project_api .server .JSONRPCError ) as exc_info :
@@ -285,7 +320,10 @@ def test_method_not_found():
285320 assert exc_info .value .code == project_api .server .ErrorCode .METHOD_NOT_FOUND
286321
287322
288- def test_extra_param ():
323+ @tvm .testing .requires_micro
324+ def test_extra_param (BaseTestHandler ):
325+ from tvm .micro import project_api
326+
289327 fixture = ClientServerFixture (BaseTestHandler ())
290328
291329 # test one with has_preprocssing and one without
@@ -304,7 +342,10 @@ def test_extra_param():
304342 assert "open_transport: extra parameters: invalid_param_name" in str (exc_info .value )
305343
306344
307- def test_missing_param ():
345+ @tvm .testing .requires_micro
346+ def test_missing_param (BaseTestHandler ):
347+ from tvm .micro import project_api
348+
308349 fixture = ClientServerFixture (BaseTestHandler ())
309350
310351 # test one with has_preprocssing and one without
@@ -323,7 +364,10 @@ def test_missing_param():
323364 assert "open_transport: parameter options not given" in str (exc_info .value )
324365
325366
326- def test_incorrect_param_type ():
367+ @tvm .testing .requires_micro
368+ def test_incorrect_param_type (BaseTestHandler ):
369+ from tvm .micro import project_api
370+
327371 fixture = ClientServerFixture (BaseTestHandler ())
328372
329373 # The error message given at the JSON-RPC server level doesn't make sense when preprocessing is
@@ -338,7 +382,10 @@ def test_incorrect_param_type():
338382 )
339383
340384
341- def test_invalid_request ():
385+ @tvm .testing .requires_micro
386+ def test_invalid_request (BaseTestHandler ):
387+ from tvm .micro import project_api
388+
342389 fixture = ClientServerFixture (BaseTestHandler ())
343390
344391 # Invalid JSON does not get a reply.
0 commit comments