diff --git a/src/hyperactive/tests/test_all_objects.py b/src/hyperactive/tests/test_all_objects.py index 84257db0..986c0086 100644 --- a/src/hyperactive/tests/test_all_objects.py +++ b/src/hyperactive/tests/test_all_objects.py @@ -4,6 +4,7 @@ import shutil from skbase.testing import BaseFixtureGenerator as _BaseFixtureGenerator +from skbase.testing import QuickTester as _QuickTester from skbase.testing import TestAllObjects as _TestAllObjects from hyperactive._registry import all_objects @@ -154,7 +155,7 @@ class ExperimentFixtureGenerator(BaseFixtureGenerator): object_type_filter = "experiment" -class TestAllExperiments(ExperimentFixtureGenerator): +class TestAllExperiments(ExperimentFixtureGenerator, _QuickTester): """Module level tests for all experiment classes.""" def test_paramnames(self, object_class): @@ -204,7 +205,7 @@ class OptimizerFixtureGenerator(BaseFixtureGenerator): object_type_filter = "optimizer" -class TestAllOptimizers(OptimizerFixtureGenerator): +class TestAllOptimizers(OptimizerFixtureGenerator, _QuickTester): """Module level tests for all optimizer classes.""" def test_opt_run(self, object_instance): diff --git a/src/hyperactive/tests/test_class_register.py b/src/hyperactive/tests/test_class_register.py new file mode 100644 index 00000000..dfbd1f1f --- /dev/null +++ b/src/hyperactive/tests/test_class_register.py @@ -0,0 +1,94 @@ +# copyright: skpro developers, BSD-3-Clause License (see LICENSE file) +"""Registry and dispatcher for test classes. + +Module does not contain tests, only test utilities. +""" + +__author__ = ["fkiraly"] + +from inspect import isclass + + +def get_test_class_registry(): + """Return test class registry. + + Wrapped in a function to avoid circular imports. + + Returns + ------- + testclass_dict : dict + test class registry + keys are scitypes, values are test classes TestAll[Scitype] + """ + from hyperactive.tests.test_all_objects import ( + TestAllExperiments, + TestAllObjects, + TestAllOptimizers, + ) + + testclass_dict = dict() + # every object in sktime inherits from BaseObject + # "object" tests are run for all objects + testclass_dict["object"] = TestAllObjects + # more specific base classes + # these inherit either from BaseEstimator or BaseObject, + # so also imply estimator and object tests, or only object tests + testclass_dict["experiment"] = TestAllExperiments + testclass_dict["optimizer"] = TestAllOptimizers + + return testclass_dict + + +def get_test_classes_for_obj(obj): + """Get all test classes relevant for an object or estimator. + + Parameters + ---------- + obj : object or estimator, descendant of sktime BaseObject or BaseEstimator + object or estimator for which to get test classes + + Returns + ------- + test_classes : list of test classes + list of test classes relevant for obj + these are references to the actual classes, not strings + if obj was not a descendant of BaseObject or BaseEstimator, returns empty list + """ + from skbase.base import BaseObject + + def is_object(obj): + """Return whether obj is an estimator class or estimator object.""" + if isclass(obj): + return issubclass(obj, BaseObject) + else: + return isinstance(obj, BaseObject) + + # warning: BaseEstimator does not inherit from BaseObject, + # therefore we need to check both + if not is_object(obj): + return [] + + testclass_dict = get_test_class_registry() + + # we always need to run "object" tests + test_clss = [testclass_dict["object"]] + + try: + if isclass(obj): + obj_scitypes = obj.get_class_tag("object_type") + elif hasattr(obj, "get_tag"): + obj_scitypes = obj.get_tag("object_type") + else: + obj_scitypes = [] + except Exception: + obj_scitypes = [] + + if isinstance(obj_scitypes, str): + # if obj_scitypes is a string, convert to list + obj_scitypes = [obj_scitypes] + + for obj_scitype in obj_scitypes: + if obj_scitype in testclass_dict: + test_clss += [testclass_dict[obj_scitype]] + + return test_clss diff --git a/src/hyperactive/utils/__init__.py b/src/hyperactive/utils/__init__.py new file mode 100644 index 00000000..c9c88720 --- /dev/null +++ b/src/hyperactive/utils/__init__.py @@ -0,0 +1,7 @@ +"""Utility functionality.""" + +from hyperactive.utils.estimator_checks import check_estimator + +__all__ = [ + "check_estimator", +] diff --git a/src/hyperactive/utils/estimator_checks.py b/src/hyperactive/utils/estimator_checks.py new file mode 100644 index 00000000..1bc9f793 --- /dev/null +++ b/src/hyperactive/utils/estimator_checks.py @@ -0,0 +1,139 @@ +"""Estimator checker for extension.""" + +__author__ = ["fkiraly"] +__all__ = ["check_estimator"] + +from skbase.utils.dependencies import _check_soft_dependencies + + +def check_estimator( + estimator, + raise_exceptions=False, + tests_to_run=None, + fixtures_to_run=None, + verbose=True, + tests_to_exclude=None, + fixtures_to_exclude=None, +): + """Run all tests on one single estimator. + + Tests that are run on estimator: + + * all tests in test_all_estimators + * all interface compatibility tests from the module of estimator's scitype + + Parameters + ---------- + estimator : estimator class or estimator instance + raise_exceptions : bool, optional, default=False + whether to return exceptions/failures in the results dict, or raise them + + * if False: returns exceptions in returned `results` dict + * if True: raises exceptions as they occur + + tests_to_run : str or list of str, optional. Default = run all tests. + Names (test/function name string) of tests to run. + sub-sets tests that are run to the tests given here. + fixtures_to_run : str or list of str, optional. Default = run all tests. + pytest test-fixture combination codes, which test-fixture combinations to run. + sub-sets tests and fixtures to run to the list given here. + If both tests_to_run and fixtures_to_run are provided, runs the *union*, + i.e., all test-fixture combinations for tests in tests_to_run, + plus all test-fixture combinations in fixtures_to_run. + verbose : str, optional, default=True. + whether to print out informative summary of tests run. + tests_to_exclude : str or list of str, names of tests to exclude. default = None + removes tests that should not be run, after subsetting via tests_to_run. + fixtures_to_exclude : str or list of str, fixtures to exclude. default = None + removes test-fixture combinations that should not be run. + This is done after subsetting via fixtures_to_run. + + Returns + ------- + results : dict of results of the tests in self + keys are test/fixture strings, identical as in pytest, e.g., test[fixture] + entries are the string "PASSED" if the test passed, + or the exception raised if the test did not pass + returned only if all tests pass, or raise_exceptions=False + + Raises + ------ + if raise_exceptions=True, + raises any exception produced by the tests directly + + Examples + -------- + >>> from hyperactive.opt import HillClimbing + >>> from hyperactive.utils import check_estimator + + Running all tests for HillClimbing class, + this uses all instances from get_test_params and compatible scenarios + + >>> results = check_estimator(HillClimbing) + All tests PASSED! + + Running all tests for a specific HillClimbing + this uses the instance that is passed and compatible scenarios + + >>> specific_hill_climbing = HillClimbing.create_test_instance() + >>> results = check_estimator(specific_hill_climbing) + All tests PASSED! + + Running specific test (all fixtures) HillClimbing + + >>> results = check_estimator(HillClimbing, tests_to_run="test_clone") + All tests PASSED! + + {'test_clone[HillClimbing-0]': 'PASSED', + 'test_clone[HillClimbing-1]': 'PASSED'} + + Running one specific test-fixture-combination for ResidualDouble + + >>> check_estimator( + ... HillClimbing, fixtures_to_run="test_clone[HillClimbing-1]" + ... ) + All tests PASSED! + {'test_clone[HillClimbing-1]': 'PASSED'} + """ + msg = ( + "check_estimator is a testing utility for developers, and " + "requires pytest to be present " + "in the python environment, but pytest was not found. " + "pytest is a developer dependency and not included in the base " + "sktime installation. Please run: `pip install pytest` to " + "install the pytest package. " + "To install sktime with all developer dependencies, run:" + " `pip install hyperactive[dev]`" + ) + _check_soft_dependencies("pytest", msg=msg) + + from hyperactive.tests.test_class_register import get_test_classes_for_obj + + test_clss_for_est = get_test_classes_for_obj(estimator) + + results = {} + + for test_cls in test_clss_for_est: + test_cls_results = test_cls().run_tests( + obj=estimator, + raise_exceptions=raise_exceptions, + tests_to_run=tests_to_run, + fixtures_to_run=fixtures_to_run, + tests_to_exclude=tests_to_exclude, + fixtures_to_exclude=fixtures_to_exclude, + ) + results.update(test_cls_results) + + failed_tests = [key for key in results.keys() if results[key] != "PASSED"] + if len(failed_tests) > 0: + msg = failed_tests + msg = ["FAILED: " + x for x in msg] + msg = "\n".join(msg) + else: + msg = "All tests PASSED!" + + if verbose: + # printing is an intended feature, for console usage and interactive debugging + print(msg) # noqa T001 + + return results