Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions python/tvm/meta_schedule/search_strategy/search_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,21 @@ class SearchStrategy(Object):
],
]

def __new__(cls, *args, **kwargs): # pylint: disable=unused-argument
"""Prevent direct instantiation of abstract SearchStrategy class.

SearchStrategy is an abstract class and cannot be directly instantiated.
Use SearchStrategy.create() or a concrete subclass instead.
"""
if cls is SearchStrategy:
raise TypeError(
"Cannot instantiate abstract class SearchStrategy. "
"Use SearchStrategy.create() with a valid strategy type "
"(e.g., 'evolutionary', 'replay-trace', 'replay-func') "
"or use a concrete subclass instead."
)
return super().__new__(cls) # pylint: disable=no-value-for-parameter

def _initialize_with_tune_context(self, context: "TuneContext") -> None:
"""Initialize the search strategy with tuning context.

Expand Down
9 changes: 9 additions & 0 deletions python/tvm/meta_schedule/tune_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,15 @@ def __init__(
if search_strategy is not None:
if not isinstance(search_strategy, SearchStrategy):
search_strategy = SearchStrategy.create(search_strategy)
# Additional check: ensure it's not the abstract SearchStrategy class itself
# Use type() for exact type check (not isinstance which would match subclasses)
elif type(search_strategy) is SearchStrategy: # pylint: disable=unidiomatic-typecheck
raise TypeError(
"Cannot use abstract SearchStrategy class directly. "
"Use SearchStrategy.create() with a valid strategy type "
"(e.g., 'evolutionary', 'replay-trace', 'replay-func') "
"or use a concrete subclass instead."
)
if logger is None:
logger = get_logger(__name__)
if not isinstance(num_threads, int):
Expand Down
31 changes: 31 additions & 0 deletions tests/python/meta_schedule/test_meta_schedule_search_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,9 +306,40 @@ def __str__(self) -> str:
assert candidates is None


def test_search_strategy_abstract_class_instantiation():
"""Test that directly instantiating abstract SearchStrategy raises TypeError instead of segfault."""
from tvm.meta_schedule import SearchStrategy
from tvm.target import Target
from tvm.meta_schedule import TuneContext

# Test that direct instantiation raises TypeError
# This prevents segfault when SearchStrategy() is called directly
with pytest.raises(TypeError, match="Cannot instantiate abstract class SearchStrategy"):
SearchStrategy()

# Test that TuneContext with SearchStrategy() raises TypeError
# The error should occur when trying to create SearchStrategy() instance in the function call
# Since SearchStrategy() fails in __new__, it will fail before TuneContext.__init__ is called
with pytest.raises(TypeError, match="Cannot instantiate abstract class SearchStrategy"):
# This will fail when evaluating SearchStrategy() as an argument
TuneContext(
mod=Matmul, # Use the existing Matmul module from the test file
target=Target("llvm"),
search_strategy=SearchStrategy(), # This should fail in __new__ before reaching TuneContext
)

# Test that SearchStrategy.create() works correctly
strategy = SearchStrategy.create("evolutionary")
assert strategy is not None
assert isinstance(strategy, SearchStrategy)
# Verify it's not the abstract class itself
assert type(strategy) is not SearchStrategy


if __name__ == "__main__":
test_meta_schedule_replay_func(ms.search_strategy.ReplayFunc)
test_meta_schedule_replay_func(ms.search_strategy.ReplayTrace)
test_meta_schedule_evolutionary_search()
test_meta_schedule_evolutionary_search_early_stop()
test_meta_schedule_evolutionary_search_fail_init_population()
test_search_strategy_abstract_class_instantiation()