diff --git a/python/tvm/meta_schedule/search_strategy/search_strategy.py b/python/tvm/meta_schedule/search_strategy/search_strategy.py index 75b45cf424c3..cfb45dafdeb2 100644 --- a/python/tvm/meta_schedule/search_strategy/search_strategy.py +++ b/python/tvm/meta_schedule/search_strategy/search_strategy.py @@ -87,6 +87,21 @@ class SearchStrategy(Object): ], ] + def __new__(cls, *args, **kwargs): # pylint: disable=unused-argument + """Prevent direct instantiation of abstract SearchStrategy class. + + SearchStrategy is an abstract class and cannot be directly instantiated. + Use SearchStrategy.create() or a concrete subclass instead. + """ + if cls is SearchStrategy: + raise TypeError( + "Cannot instantiate abstract class SearchStrategy. " + "Use SearchStrategy.create() with a valid strategy type " + "(e.g., 'evolutionary', 'replay-trace', 'replay-func') " + "or use a concrete subclass instead." + ) + return super().__new__(cls) # pylint: disable=no-value-for-parameter + def _initialize_with_tune_context(self, context: "TuneContext") -> None: """Initialize the search strategy with tuning context. diff --git a/python/tvm/meta_schedule/tune_context.py b/python/tvm/meta_schedule/tune_context.py index c3f496265a97..35a8d468a75c 100644 --- a/python/tvm/meta_schedule/tune_context.py +++ b/python/tvm/meta_schedule/tune_context.py @@ -123,6 +123,15 @@ def __init__( if search_strategy is not None: if not isinstance(search_strategy, SearchStrategy): search_strategy = SearchStrategy.create(search_strategy) + # Additional check: ensure it's not the abstract SearchStrategy class itself + # Use type() for exact type check (not isinstance which would match subclasses) + elif type(search_strategy) is SearchStrategy: # pylint: disable=unidiomatic-typecheck + raise TypeError( + "Cannot use abstract SearchStrategy class directly. " + "Use SearchStrategy.create() with a valid strategy type " + "(e.g., 'evolutionary', 'replay-trace', 'replay-func') " + "or use a concrete subclass instead." + ) if logger is None: logger = get_logger(__name__) if not isinstance(num_threads, int): diff --git a/tests/python/meta_schedule/test_meta_schedule_search_strategy.py b/tests/python/meta_schedule/test_meta_schedule_search_strategy.py index 29c20ced0488..04a6e187a6a7 100644 --- a/tests/python/meta_schedule/test_meta_schedule_search_strategy.py +++ b/tests/python/meta_schedule/test_meta_schedule_search_strategy.py @@ -306,9 +306,40 @@ def __str__(self) -> str: assert candidates is None +def test_search_strategy_abstract_class_instantiation(): + """Test that directly instantiating abstract SearchStrategy raises TypeError instead of segfault.""" + from tvm.meta_schedule import SearchStrategy + from tvm.target import Target + from tvm.meta_schedule import TuneContext + + # Test that direct instantiation raises TypeError + # This prevents segfault when SearchStrategy() is called directly + with pytest.raises(TypeError, match="Cannot instantiate abstract class SearchStrategy"): + SearchStrategy() + + # Test that TuneContext with SearchStrategy() raises TypeError + # The error should occur when trying to create SearchStrategy() instance in the function call + # Since SearchStrategy() fails in __new__, it will fail before TuneContext.__init__ is called + with pytest.raises(TypeError, match="Cannot instantiate abstract class SearchStrategy"): + # This will fail when evaluating SearchStrategy() as an argument + TuneContext( + mod=Matmul, # Use the existing Matmul module from the test file + target=Target("llvm"), + search_strategy=SearchStrategy(), # This should fail in __new__ before reaching TuneContext + ) + + # Test that SearchStrategy.create() works correctly + strategy = SearchStrategy.create("evolutionary") + assert strategy is not None + assert isinstance(strategy, SearchStrategy) + # Verify it's not the abstract class itself + assert type(strategy) is not SearchStrategy + + if __name__ == "__main__": test_meta_schedule_replay_func(ms.search_strategy.ReplayFunc) test_meta_schedule_replay_func(ms.search_strategy.ReplayTrace) test_meta_schedule_evolutionary_search() test_meta_schedule_evolutionary_search_early_stop() test_meta_schedule_evolutionary_search_fail_init_population() + test_search_strategy_abstract_class_instantiation()