diff --git a/migrations/versions/2025_02_19_1452-5e5cd2288147_update_matcher_types.py b/migrations/versions/2025_02_19_1452-5e5cd2288147_update_matcher_types.py new file mode 100644 index 00000000..138ae7ee --- /dev/null +++ b/migrations/versions/2025_02_19_1452-5e5cd2288147_update_matcher_types.py @@ -0,0 +1,65 @@ +"""update matcher types + +Revision ID: 5e5cd2288147 +Revises: 0c3539f66339 +Create Date: 2025-02-19 14:52:39.126196+00:00 + +""" + +from typing import Sequence, Union + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "5e5cd2288147" +down_revision: Union[str, None] = "0c3539f66339" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # Begin transaction + op.execute("BEGIN TRANSACTION;") + + # Update the matcher types. We need to do this every time we change the matcher types. + # in /muxing/models.py + op.execute( + """ + UPDATE muxes + SET matcher_type = 'fim_filename', matcher_blob = '' + WHERE matcher_type = 'request_type_match' AND matcher_blob = 'fim'; + """ + ) + op.execute( + """ + UPDATE muxes + SET matcher_type = 'chat_filename', matcher_blob = '' + WHERE matcher_type = 'request_type_match' AND matcher_blob = 'chat'; + """ + ) + + # Finish transaction + op.execute("COMMIT;") + + +def downgrade() -> None: + # Begin transaction + op.execute("BEGIN TRANSACTION;") + + op.execute( + """ + UPDATE muxes + SET matcher_blob = 'fim', matcher_type = 'request_type_match' + WHERE matcher_type = 'fim'; + """ + ) + op.execute( + """ + UPDATE muxes + SET matcher_blob = 'chat', matcher_type = 'request_type_match' + WHERE matcher_type = 'chat'; + """ + ) + + # Finish transaction + op.execute("COMMIT;") diff --git a/src/codegate/muxing/models.py b/src/codegate/muxing/models.py index 4c822485..5637c5b8 100644 --- a/src/codegate/muxing/models.py +++ b/src/codegate/muxing/models.py @@ -1,14 +1,26 @@ from enum import Enum -from typing import Optional +from typing import Optional, Self import pydantic from codegate.clients.clients import ClientType +from codegate.db.models import MuxRule as DBMuxRule class MuxMatcherType(str, Enum): """ Represents the different types of matchers we support. + + The 3 rules present match filenames and request types. They're used in conjunction with the + matcher field in the MuxRule model. + E.g. + - catch_all-> Always match + - filename_match and match: requests.py -> Match the request if the filename is requests.py + - fim_filename and match: main.py -> Match the request if the request type is fim + and the filename is main.py + + NOTE: Removing or updating fields from this enum will require a migration. + Adding new fields is safe. """ # Always match this prompt @@ -16,9 +28,10 @@ class MuxMatcherType(str, Enum): # Match based on the filename. It will match if there is a filename # in the request that matches the matcher either extension or full name (*.py or main.py) filename_match = "filename_match" - # Match based on the request type. It will match if the request type - # matches the matcher (e.g. FIM or chat) - request_type_match = "request_type_match" + # Match based on fim request type. It will match if the request type is fim + fim_filename = "fim_filename" + # Match based on chat request type. It will match if the request type is chat + chat_filename = "chat_filename" class MuxRule(pydantic.BaseModel): @@ -36,6 +49,18 @@ class MuxRule(pydantic.BaseModel): # this depends on the matcher type. matcher: Optional[str] = None + @classmethod + def from_db_mux_rule(cls, db_mux_rule: DBMuxRule) -> Self: + """ + Convert a DBMuxRule to a MuxRule. + """ + return MuxRule( + provider_id=db_mux_rule.id, + model=db_mux_rule.provider_model_name, + matcher_type=db_mux_rule.matcher_type, + matcher=db_mux_rule.matcher_blob, + ) + class ThingToMatchMux(pydantic.BaseModel): """ diff --git a/src/codegate/muxing/router.py b/src/codegate/muxing/router.py index 4231e8e7..bfa9c663 100644 --- a/src/codegate/muxing/router.py +++ b/src/codegate/muxing/router.py @@ -50,8 +50,11 @@ async def _get_model_route( # Try to get a model route for the active workspace model_route = await mux_registry.get_match_for_active_workspace(thing_to_match) return model_route + except rulematcher.MuxMatchingError as e: + logger.exception(f"Error matching rule and getting model route: {e}") + raise HTTPException(detail=str(e), status_code=404) except Exception as e: - logger.error(f"Error getting active workspace muxes: {e}") + logger.exception(f"Error getting active workspace muxes: {e}") raise HTTPException(detail=str(e), status_code=404) def _setup_routes(self): diff --git a/src/codegate/muxing/rulematcher.py b/src/codegate/muxing/rulematcher.py index fb3c1da2..247e6c12 100644 --- a/src/codegate/muxing/rulematcher.py +++ b/src/codegate/muxing/rulematcher.py @@ -18,6 +18,12 @@ _singleton_lock = Lock() +class MuxMatchingError(Exception): + """An exception for muxing matching errors.""" + + pass + + async def get_muxing_rules_registry(): """Returns a singleton instance of the muxing rules registry.""" @@ -48,9 +54,9 @@ def __init__( class MuxingRuleMatcher(ABC): """Base class for matching muxing rules.""" - def __init__(self, route: ModelRoute, matcher_blob: str): + def __init__(self, route: ModelRoute, mux_rule: mux_models.MuxRule): self._route = route - self._matcher_blob = matcher_blob + self._mux_rule = mux_rule @abstractmethod def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool: @@ -67,18 +73,20 @@ class MuxingMatcherFactory: """Factory for creating muxing matchers.""" @staticmethod - def create(mux_rule: db_models.MuxRule, route: ModelRoute) -> MuxingRuleMatcher: + def create(db_mux_rule: db_models.MuxRule, route: ModelRoute) -> MuxingRuleMatcher: """Create a muxing matcher for the given endpoint and model.""" factory: Dict[mux_models.MuxMatcherType, MuxingRuleMatcher] = { mux_models.MuxMatcherType.catch_all: CatchAllMuxingRuleMatcher, mux_models.MuxMatcherType.filename_match: FileMuxingRuleMatcher, - mux_models.MuxMatcherType.request_type_match: RequestTypeMuxingRuleMatcher, + mux_models.MuxMatcherType.fim_filename: RequestTypeAndFileMuxingRuleMatcher, + mux_models.MuxMatcherType.chat_filename: RequestTypeAndFileMuxingRuleMatcher, } try: # Initialize the MuxingRuleMatcher - return factory[mux_rule.matcher_type](route, mux_rule.matcher_blob) + mux_rule = mux_models.MuxRule.from_db_mux_rule(db_mux_rule) + return factory[mux_rule.matcher_type](route, mux_rule) except KeyError: raise ValueError(f"Unknown matcher type: {mux_rule.matcher_type}") @@ -103,47 +111,63 @@ def _extract_request_filenames(self, detected_client: ClientType, data: dict) -> return body_extractor.extract_unique_filenames(data) except BodyCodeSnippetExtractorError as e: logger.error(f"Error extracting filenames from request: {e}") - return set() + raise MuxMatchingError("Error extracting filenames from request") + + def _is_matcher_in_filenames(self, detected_client: ClientType, data: dict) -> bool: + """ + Check if the matcher is in the request filenames. + """ + # Empty matcher_blob means we match everything + if not self._mux_rule.matcher: + return True + filenames_to_match = self._extract_request_filenames(detected_client, data) + # _mux_rule.matcher can be a filename or a file extension. We match if any of the filenames + # match the rule. + is_filename_match = any( + self._mux_rule.matcher == filename or filename.endswith(self._mux_rule.matcher) + for filename in filenames_to_match + ) + return is_filename_match def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool: """ - Retun True if there is a filename in the request that matches the matcher_blob. - The matcher_blob is either an extension (e.g. .py) or a filename (e.g. main.py). + Return True if the matcher is in one of the request filenames. """ - # If there is no matcher_blob, we don't match - if not self._matcher_blob: - return False - filenames_to_match = self._extract_request_filenames( + is_rule_matched = self._is_matcher_in_filenames( thing_to_match.client_type, thing_to_match.body ) - is_filename_match = any(self._matcher_blob in filename for filename in filenames_to_match) - if is_filename_match: - logger.info( - "Filename rule matched", filenames=filenames_to_match, matcher=self._matcher_blob - ) - return is_filename_match + if is_rule_matched: + logger.info("Filename rule matched", matcher=self._mux_rule.matcher) + return is_rule_matched -class RequestTypeMuxingRuleMatcher(MuxingRuleMatcher): - """A catch all muxing rule matcher.""" +class RequestTypeAndFileMuxingRuleMatcher(FileMuxingRuleMatcher): + """A request type and file muxing rule matcher.""" + + def _is_request_type_match(self, is_fim_request: bool) -> bool: + """ + Check if the request type matches the MuxMatcherType. + """ + incoming_request_type = "fim_filename" if is_fim_request else "chat_filename" + if incoming_request_type == self._mux_rule.matcher_type: + return True + return False def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool: """ - Return True if the request type matches the matcher_blob. - The matcher_blob is either "fim" or "chat". + Return True if the matcher is in one of the request filenames and + if the request type matches the MuxMatcherType. """ - # If there is no matcher_blob, we don't match - if not self._matcher_blob: - return False - incoming_request_type = "fim" if thing_to_match.is_fim_request else "chat" - is_request_type_match = self._matcher_blob == incoming_request_type - if is_request_type_match: + is_rule_matched = self._is_matcher_in_filenames( + thing_to_match.client_type, thing_to_match.body + ) and self._is_request_type_match(thing_to_match.is_fim_request) + if is_rule_matched: logger.info( - "Request type rule matched", - matcher=self._matcher_blob, - request_type=incoming_request_type, + "Request type and rule matched", + matcher=self._mux_rule.matcher, + is_fim_request=thing_to_match.is_fim_request, ) - return is_request_type_match + return is_rule_matched class MuxingRulesinWorkspaces: diff --git a/tests/muxing/test_rulematcher.py b/tests/muxing/test_rulematcher.py index 4e489799..7340d983 100644 --- a/tests/muxing/test_rulematcher.py +++ b/tests/muxing/test_rulematcher.py @@ -47,20 +47,31 @@ def test_catch_all(matcher_blob, thing_to_match): @pytest.mark.parametrize( - "matcher_blob, filenames_to_match, expected_bool", + "matcher, filenames_to_match, expected_bool", [ - (None, [], False), # Empty filenames and no blob - (None, ["main.py"], False), # Empty blob + (None, [], True), # Empty filenames and no blob + (None, ["main.py"], True), # Empty blob should match (".py", ["main.py"], True), # Extension match ("main.py", ["main.py"], True), # Full name match (".py", ["main.py", "test.py"], True), # Extension match ("main.py", ["main.py", "test.py"], True), # Full name match ("main.py", ["test.py"], False), # Full name no match (".js", ["main.py", "test.py"], False), # Extension no match + (".ts", ["main.tsx", "test.tsx"], False), # Extension no match ], ) -def test_file_matcher(matcher_blob, filenames_to_match, expected_bool): - muxing_rule_matcher = rulematcher.FileMuxingRuleMatcher(mocked_route_openai, matcher_blob) +def test_file_matcher( + matcher, + filenames_to_match, + expected_bool, +): + mux_rule = mux_models.MuxRule( + provider_id="1", + model="fake-gpt", + matcher_type="filename_match", + matcher=matcher, + ) + muxing_rule_matcher = rulematcher.FileMuxingRuleMatcher(mocked_route_openai, mux_rule) # We mock the _extract_request_filenames method to return a list of filenames # The logic to get the correct filenames from snippets is tested in /tests/extract_snippets muxing_rule_matcher._extract_request_filenames = MagicMock(return_value=filenames_to_match) @@ -70,57 +81,97 @@ def test_file_matcher(matcher_blob, filenames_to_match, expected_bool): is_fim_request=False, client_type="generic", ) - assert muxing_rule_matcher.match(mocked_thing_to_match) == expected_bool + assert muxing_rule_matcher.match(mocked_thing_to_match) is expected_bool @pytest.mark.parametrize( - "matcher_blob, thing_to_match, expected_bool", + "matcher, filenames_to_match, expected_bool_filenames", [ - (None, None, False), # Empty blob - ( - "fim", - mux_models.ThingToMatchMux( - body={}, - url_request_path="/chat/completions", - is_fim_request=False, - client_type="generic", - ), - False, - ), # No match - ( - "fim", - mux_models.ThingToMatchMux( - body={}, - url_request_path="/chat/completions", - is_fim_request=True, - client_type="generic", - ), - True, - ), # Match - ( - "chat", - mux_models.ThingToMatchMux( - body={}, - url_request_path="/chat/completions", - is_fim_request=True, - client_type="generic", - ), - False, - ), # No match - ( - "chat", - mux_models.ThingToMatchMux( - body={}, - url_request_path="/chat/completions", - is_fim_request=False, - client_type="generic", - ), - True, - ), # Match + (None, [], True), # Empty filenames and no blob + (None, ["main.py"], True), # Empty blob should match + (".py", ["main.py"], True), # Extension match + ("main.py", ["main.py"], True), # Full name match + (".py", ["main.py", "test.py"], True), # Extension match + ("main.py", ["main.py", "test.py"], True), # Full name match + ("main.py", ["test.py"], False), # Full name no match + (".js", ["main.py", "test.py"], False), # Extension no match + (".ts", ["main.tsx", "test.tsx"], False), # Extension no match + ], +) +@pytest.mark.parametrize( + "is_fim_request, matcher_type, expected_bool_request", + [ + (False, "fim_filename", False), # No match + (True, "fim_filename", True), # Match + (False, "chat_filename", True), # Match + (True, "chat_filename", False), # No match + ], +) +def test_request_file_matcher( + matcher, + filenames_to_match, + expected_bool_filenames, + is_fim_request, + matcher_type, + expected_bool_request, +): + mux_rule = mux_models.MuxRule( + provider_id="1", + model="fake-gpt", + matcher_type=matcher_type, + matcher=matcher, + ) + muxing_rule_matcher = rulematcher.RequestTypeAndFileMuxingRuleMatcher( + mocked_route_openai, mux_rule + ) + # We mock the _extract_request_filenames method to return a list of filenames + # The logic to get the correct filenames from snippets is tested in /tests/extract_snippets + muxing_rule_matcher._extract_request_filenames = MagicMock(return_value=filenames_to_match) + mocked_thing_to_match = mux_models.ThingToMatchMux( + body={}, + url_request_path="/chat/completions", + is_fim_request=is_fim_request, + client_type="generic", + ) + assert ( + muxing_rule_matcher._is_request_type_match(mocked_thing_to_match.is_fim_request) + is expected_bool_request + ) + assert ( + muxing_rule_matcher._is_matcher_in_filenames( + mocked_thing_to_match.client_type, mocked_thing_to_match.body + ) + is expected_bool_filenames + ) + assert muxing_rule_matcher.match(mocked_thing_to_match) is ( + expected_bool_request and expected_bool_filenames + ) + + +@pytest.mark.parametrize( + "matcher_type, expected_class", + [ + (mux_models.MuxMatcherType.catch_all, rulematcher.CatchAllMuxingRuleMatcher), + (mux_models.MuxMatcherType.filename_match, rulematcher.FileMuxingRuleMatcher), + (mux_models.MuxMatcherType.fim_filename, rulematcher.RequestTypeAndFileMuxingRuleMatcher), + (mux_models.MuxMatcherType.chat_filename, rulematcher.RequestTypeAndFileMuxingRuleMatcher), + ("invalid_matcher", None), ], ) -def test_request_type(matcher_blob, thing_to_match, expected_bool): - muxing_rule_matcher = rulematcher.RequestTypeMuxingRuleMatcher( - mocked_route_openai, matcher_blob +def test_muxing_matcher_factory(matcher_type, expected_class): + mux_rule = db_models.MuxRule( + id="1", + provider_endpoint_id="1", + provider_model_name="fake-gpt", + workspace_id="1", + matcher_type=matcher_type, + matcher_blob="fake-matcher", + priority=1, ) - assert muxing_rule_matcher.match(thing_to_match) == expected_bool + if expected_class: + assert isinstance( + rulematcher.MuxingMatcherFactory.create(mux_rule, mocked_route_openai), expected_class + ) + else: + with pytest.raises(ValueError): + rulematcher.MuxingMatcherFactory.create(mux_rule, mocked_route_openai)