diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py index d0ceee4aa2a0..ac22af282345 100644 --- a/python/tvm/testing/utils.py +++ b/python/tvm/testing/utils.py @@ -1447,7 +1447,7 @@ def parameter(*values, ids=None, by_dict=None): # Optional cls parameter in case a parameter is defined inside a # class scope. - @pytest.fixture(params=values, ids=ids) + @pytest.fixture(params=values, ids=ids, scope="session") def as_fixture(*_cls, request): return request.param diff --git a/tests/python/testing/test_tvm_testing_features.py b/tests/python/testing/test_tvm_testing_features.py index 5c0e526f0d4d..6d394ebeb649 100644 --- a/tests/python/testing/test_tvm_testing_features.py +++ b/tests/python/testing/test_tvm_testing_features.py @@ -290,5 +290,16 @@ def test_uses_deepcopy(self, fixture_with_deepcopy): pass +class TestPytestCache: + param = tvm.testing.parameter(1, 2, 3) + + @pytest.fixture(scope="class") + def cached_fixture(self, param): + return param * param + + def test_uses_cached_fixture(self, param, cached_fixture): + assert cached_fixture == param * param + + if __name__ == "__main__": tvm.testing.main()