From 161737acbff37c9004446b3c094a75feab7b9d3c Mon Sep 17 00:00:00 2001 From: Alejandro Ponce Date: Wed, 19 Feb 2025 17:00:07 +0200 Subject: [PATCH 1/2] Update the muxing rules to v3 Closes: #1060 Right now the muxing rules are designed to catch globally FIM or Chat requests. This PR extends its functionality to be able to match per file and request, i.e. this PR enables - Chat request of main.py -> model 1 - FIM request of main.py -> model 2 - Any type of v1.py -> model 3 --- ..._1452-5e5cd2288147_update_matcher_types.py | 79 +++++++++++ src/codegate/muxing/models.py | 35 ++++- src/codegate/muxing/router.py | 5 +- src/codegate/muxing/rulematcher.py | 94 ++++++------- tests/muxing/test_rulematcher.py | 126 +++++++----------- 5 files changed, 204 insertions(+), 135 deletions(-) create mode 100644 migrations/versions/2025_02_19_1452-5e5cd2288147_update_matcher_types.py diff --git a/migrations/versions/2025_02_19_1452-5e5cd2288147_update_matcher_types.py b/migrations/versions/2025_02_19_1452-5e5cd2288147_update_matcher_types.py new file mode 100644 index 00000000..15f96977 --- /dev/null +++ b/migrations/versions/2025_02_19_1452-5e5cd2288147_update_matcher_types.py @@ -0,0 +1,79 @@ +"""update matcher types + +Revision ID: 5e5cd2288147 +Revises: 0c3539f66339 +Create Date: 2025-02-19 14:52:39.126196+00:00 + +""" + +from typing import Sequence, Union + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "5e5cd2288147" +down_revision: Union[str, None] = "0c3539f66339" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # Begin transaction + op.execute("BEGIN TRANSACTION;") + + # Update the matcher types. We need to do this every time we change the matcher types. + # in /muxing/models.py + op.execute( + """ + UPDATE muxes + SET matcher_type = 'fim', matcher_blob = '' + WHERE matcher_type = 'request_type_match' AND matcher_blob = 'fim'; + """ + ) + op.execute( + """ + UPDATE muxes + SET matcher_type = 'chat', matcher_blob = '' + WHERE matcher_type = 'request_type_match' AND matcher_blob = 'chat'; + """ + ) + op.execute( + """ + UPDATE muxes + SET matcher_type = 'catch_all' + WHERE matcher_type = 'filename_match' AND matcher_blob != ''; + """ + ) + + # Finish transaction + op.execute("COMMIT;") + + +def downgrade() -> None: + # Begin transaction + op.execute("BEGIN TRANSACTION;") + + op.execute( + """ + UPDATE muxes + SET matcher_blob = 'fim', matcher_type = 'request_type_match' + WHERE matcher_type = 'fim'; + """ + ) + op.execute( + """ + UPDATE muxes + SET matcher_blob = 'chat', matcher_type = 'request_type_match' + WHERE matcher_type = 'chat'; + """ + ) + op.execute( + """ + UPDATE muxes + SET matcher_type = 'filename_match', matcher_blob = 'catch_all' + WHERE matcher_type = 'catch_all'; + """ + ) + + # Finish transaction + op.execute("COMMIT;") diff --git a/src/codegate/muxing/models.py b/src/codegate/muxing/models.py index 4c822485..4a0b1d9c 100644 --- a/src/codegate/muxing/models.py +++ b/src/codegate/muxing/models.py @@ -1,24 +1,33 @@ from enum import Enum -from typing import Optional +from typing import Optional, Self import pydantic from codegate.clients.clients import ClientType +from codegate.db.models import MuxRule as DBMuxRule class MuxMatcherType(str, Enum): """ Represents the different types of matchers we support. + + The 3 rules present match filenames and request types. They're used in conjunction with the + matcher field in the MuxRule model. + E.g. + - catch_all and match: None -> Always match + - fim and match: requests.py -> Match the request if the filename is requests.py and FIM + - chat and match: None -> Match the request if it's a chat request + - chat and match: .js -> Match the request if the filename has a .js extension and is chat + + NOTE: Removing or updating fields from this enum will require a migration. """ # Always match this prompt catch_all = "catch_all" - # Match based on the filename. It will match if there is a filename - # in the request that matches the matcher either extension or full name (*.py or main.py) - filename_match = "filename_match" - # Match based on the request type. It will match if the request type - # matches the matcher (e.g. FIM or chat) - request_type_match = "request_type_match" + # Match based on fim request type. It will match if the request type is fim + fim = "fim" + # Match based on chat request type. It will match if the request type is chat + chat = "chat" class MuxRule(pydantic.BaseModel): @@ -36,6 +45,18 @@ class MuxRule(pydantic.BaseModel): # this depends on the matcher type. matcher: Optional[str] = None + @classmethod + def from_db_mux_rule(cls, db_mux_rule: DBMuxRule) -> Self: + """ + Convert a DBMuxRule to a MuxRule. + """ + return MuxRule( + provider_id=db_mux_rule.id, + model=db_mux_rule.provider_model_name, + matcher_type=db_mux_rule.matcher_type, + matcher=db_mux_rule.matcher_blob, + ) + class ThingToMatchMux(pydantic.BaseModel): """ diff --git a/src/codegate/muxing/router.py b/src/codegate/muxing/router.py index 4231e8e7..bfa9c663 100644 --- a/src/codegate/muxing/router.py +++ b/src/codegate/muxing/router.py @@ -50,8 +50,11 @@ async def _get_model_route( # Try to get a model route for the active workspace model_route = await mux_registry.get_match_for_active_workspace(thing_to_match) return model_route + except rulematcher.MuxMatchingError as e: + logger.exception(f"Error matching rule and getting model route: {e}") + raise HTTPException(detail=str(e), status_code=404) except Exception as e: - logger.error(f"Error getting active workspace muxes: {e}") + logger.exception(f"Error getting active workspace muxes: {e}") raise HTTPException(detail=str(e), status_code=404) def _setup_routes(self): diff --git a/src/codegate/muxing/rulematcher.py b/src/codegate/muxing/rulematcher.py index fb3c1da2..2fbec39b 100644 --- a/src/codegate/muxing/rulematcher.py +++ b/src/codegate/muxing/rulematcher.py @@ -18,6 +18,12 @@ _singleton_lock = Lock() +class MuxMatchingError(Exception): + """An exception for muxing matching errors.""" + + pass + + async def get_muxing_rules_registry(): """Returns a singleton instance of the muxing rules registry.""" @@ -48,9 +54,9 @@ def __init__( class MuxingRuleMatcher(ABC): """Base class for matching muxing rules.""" - def __init__(self, route: ModelRoute, matcher_blob: str): + def __init__(self, route: ModelRoute, mux_rule: mux_models.MuxRule): self._route = route - self._matcher_blob = matcher_blob + self._mux_rule = mux_rule @abstractmethod def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool: @@ -67,32 +73,24 @@ class MuxingMatcherFactory: """Factory for creating muxing matchers.""" @staticmethod - def create(mux_rule: db_models.MuxRule, route: ModelRoute) -> MuxingRuleMatcher: + def create(db_mux_rule: db_models.MuxRule, route: ModelRoute) -> MuxingRuleMatcher: """Create a muxing matcher for the given endpoint and model.""" factory: Dict[mux_models.MuxMatcherType, MuxingRuleMatcher] = { - mux_models.MuxMatcherType.catch_all: CatchAllMuxingRuleMatcher, - mux_models.MuxMatcherType.filename_match: FileMuxingRuleMatcher, - mux_models.MuxMatcherType.request_type_match: RequestTypeMuxingRuleMatcher, + mux_models.MuxMatcherType.catch_all: RequestTypeAndFileMuxingRuleMatcher, + mux_models.MuxMatcherType.fim: RequestTypeAndFileMuxingRuleMatcher, + mux_models.MuxMatcherType.chat: RequestTypeAndFileMuxingRuleMatcher, } try: # Initialize the MuxingRuleMatcher - return factory[mux_rule.matcher_type](route, mux_rule.matcher_blob) + mux_rule = mux_models.MuxRule.from_db_mux_rule(db_mux_rule) + return factory[mux_rule.matcher_type](route, mux_rule) except KeyError: raise ValueError(f"Unknown matcher type: {mux_rule.matcher_type}") -class CatchAllMuxingRuleMatcher(MuxingRuleMatcher): - """A catch all muxing rule matcher.""" - - def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool: - logger.info("Catch all rule matched") - return True - - -class FileMuxingRuleMatcher(MuxingRuleMatcher): - """A file muxing rule matcher.""" +class RequestTypeAndFileMuxingRuleMatcher(MuxingRuleMatcher): def _extract_request_filenames(self, detected_client: ClientType, data: dict) -> set[str]: """ @@ -103,47 +101,51 @@ def _extract_request_filenames(self, detected_client: ClientType, data: dict) -> return body_extractor.extract_unique_filenames(data) except BodyCodeSnippetExtractorError as e: logger.error(f"Error extracting filenames from request: {e}") - return set() + raise MuxMatchingError("Error extracting filenames from request") - def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool: + def _is_matcher_in_filenames(self, detected_client: ClientType, data: dict) -> bool: """ - Retun True if there is a filename in the request that matches the matcher_blob. - The matcher_blob is either an extension (e.g. .py) or a filename (e.g. main.py). + Check if the matcher is in the request filenames. """ - # If there is no matcher_blob, we don't match - if not self._matcher_blob: - return False - filenames_to_match = self._extract_request_filenames( - thing_to_match.client_type, thing_to_match.body + # Empty matcher_blob means we match everything + if not self._mux_rule.matcher: + return True + filenames_to_match = self._extract_request_filenames(detected_client, data) + # _mux_rule.matcher can be a filename or a file extension. We match if any of the filenames + # match the rule. + is_filename_match = any( + self._mux_rule.matcher == filename or filename.endswith(self._mux_rule.matcher) + for filename in filenames_to_match ) - is_filename_match = any(self._matcher_blob in filename for filename in filenames_to_match) - if is_filename_match: - logger.info( - "Filename rule matched", filenames=filenames_to_match, matcher=self._matcher_blob - ) return is_filename_match - -class RequestTypeMuxingRuleMatcher(MuxingRuleMatcher): - """A catch all muxing rule matcher.""" + def _is_request_type_match(self, is_fim_request: bool) -> bool: + """ + Check if the request type matches the MuxMatcherType. + """ + # Catch all rule matches both chat and FIM requests + if self._mux_rule.matcher_type == mux_models.MuxMatcherType.catch_all: + return True + incoming_request_type = "fim" if is_fim_request else "chat" + if incoming_request_type == self._mux_rule.matcher_type: + return True + return False def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool: """ - Return True if the request type matches the matcher_blob. - The matcher_blob is either "fim" or "chat". + Return True if the matcher is in one of the request filenames and + if the request type matches the MuxMatcherType. """ - # If there is no matcher_blob, we don't match - if not self._matcher_blob: - return False - incoming_request_type = "fim" if thing_to_match.is_fim_request else "chat" - is_request_type_match = self._matcher_blob == incoming_request_type - if is_request_type_match: + is_rule_matched = self._is_matcher_in_filenames( + thing_to_match.client_type, thing_to_match.body + ) and self._is_request_type_match(thing_to_match.is_fim_request) + if is_rule_matched: logger.info( - "Request type rule matched", - matcher=self._matcher_blob, - request_type=incoming_request_type, + "Request type and rule matched", + matcher=self._mux_rule.matcher, + is_fim_request=thing_to_match.is_fim_request, ) - return is_request_type_match + return is_rule_matched class MuxingRulesinWorkspaces: diff --git a/tests/muxing/test_rulematcher.py b/tests/muxing/test_rulematcher.py index 4e489799..9cde0179 100644 --- a/tests/muxing/test_rulematcher.py +++ b/tests/muxing/test_rulematcher.py @@ -25,102 +25,66 @@ @pytest.mark.parametrize( - "matcher_blob, thing_to_match", + "matcher, filenames_to_match, expected_bool_filenames", [ - (None, None), - ("fake-matcher-blob", None), - ( - "fake-matcher-blob", - mux_models.ThingToMatchMux( - body={}, - url_request_path="/chat/completions", - is_fim_request=False, - client_type="generic", - ), - ), - ], -) -def test_catch_all(matcher_blob, thing_to_match): - muxing_rule_matcher = rulematcher.CatchAllMuxingRuleMatcher(mocked_route_openai, matcher_blob) - # It should always match - assert muxing_rule_matcher.match(thing_to_match) is True - - -@pytest.mark.parametrize( - "matcher_blob, filenames_to_match, expected_bool", - [ - (None, [], False), # Empty filenames and no blob - (None, ["main.py"], False), # Empty blob + (None, [], True), # Empty filenames and no blob + (None, ["main.py"], True), # Empty blob should match (".py", ["main.py"], True), # Extension match ("main.py", ["main.py"], True), # Full name match (".py", ["main.py", "test.py"], True), # Extension match ("main.py", ["main.py", "test.py"], True), # Full name match ("main.py", ["test.py"], False), # Full name no match (".js", ["main.py", "test.py"], False), # Extension no match + (".ts", ["main.tsx", "test.tsx"], False), # Extension no match + ], +) +@pytest.mark.parametrize( + "is_fim_request, matcher_type, expected_bool_request", + [ + (False, "fim", False), # No match + (True, "fim", True), # Match + (False, "chat", True), # Match + (True, "chat", False), # No match + (True, "catch_all", True), # Match + (False, "catch_all", True), # Match ], ) -def test_file_matcher(matcher_blob, filenames_to_match, expected_bool): - muxing_rule_matcher = rulematcher.FileMuxingRuleMatcher(mocked_route_openai, matcher_blob) +def test_file_matcher( + matcher, + filenames_to_match, + expected_bool_filenames, + is_fim_request, + matcher_type, + expected_bool_request, +): + mux_rule = mux_models.MuxRule( + provider_id="1", + model="fake-gpt", + matcher_type=matcher_type, + matcher=matcher, + ) + muxing_rule_matcher = rulematcher.RequestTypeAndFileMuxingRuleMatcher( + mocked_route_openai, mux_rule + ) # We mock the _extract_request_filenames method to return a list of filenames # The logic to get the correct filenames from snippets is tested in /tests/extract_snippets muxing_rule_matcher._extract_request_filenames = MagicMock(return_value=filenames_to_match) mocked_thing_to_match = mux_models.ThingToMatchMux( body={}, url_request_path="/chat/completions", - is_fim_request=False, + is_fim_request=is_fim_request, client_type="generic", ) - assert muxing_rule_matcher.match(mocked_thing_to_match) == expected_bool - - -@pytest.mark.parametrize( - "matcher_blob, thing_to_match, expected_bool", - [ - (None, None, False), # Empty blob - ( - "fim", - mux_models.ThingToMatchMux( - body={}, - url_request_path="/chat/completions", - is_fim_request=False, - client_type="generic", - ), - False, - ), # No match - ( - "fim", - mux_models.ThingToMatchMux( - body={}, - url_request_path="/chat/completions", - is_fim_request=True, - client_type="generic", - ), - True, - ), # Match - ( - "chat", - mux_models.ThingToMatchMux( - body={}, - url_request_path="/chat/completions", - is_fim_request=True, - client_type="generic", - ), - False, - ), # No match - ( - "chat", - mux_models.ThingToMatchMux( - body={}, - url_request_path="/chat/completions", - is_fim_request=False, - client_type="generic", - ), - True, - ), # Match - ], -) -def test_request_type(matcher_blob, thing_to_match, expected_bool): - muxing_rule_matcher = rulematcher.RequestTypeMuxingRuleMatcher( - mocked_route_openai, matcher_blob + assert ( + muxing_rule_matcher._is_request_type_match(mocked_thing_to_match.is_fim_request) + is expected_bool_request + ) + assert ( + muxing_rule_matcher._is_matcher_in_filenames( + mocked_thing_to_match.client_type, mocked_thing_to_match.body + ) + is expected_bool_filenames + ) + assert muxing_rule_matcher.match(mocked_thing_to_match) is ( + expected_bool_request and expected_bool_filenames ) - assert muxing_rule_matcher.match(thing_to_match) == expected_bool From 34f4f027ba3b9ac53d54a1fe888700f6b87c1ca4 Mon Sep 17 00:00:00 2001 From: Alejandro Ponce Date: Fri, 21 Feb 2025 11:30:24 +0200 Subject: [PATCH 2/2] updated matcher types to preserve old types --- ..._1452-5e5cd2288147_update_matcher_types.py | 18 +--- src/codegate/muxing/models.py | 16 +-- src/codegate/muxing/rulematcher.py | 38 +++++-- tests/muxing/test_rulematcher.py | 101 ++++++++++++++++-- 4 files changed, 136 insertions(+), 37 deletions(-) diff --git a/migrations/versions/2025_02_19_1452-5e5cd2288147_update_matcher_types.py b/migrations/versions/2025_02_19_1452-5e5cd2288147_update_matcher_types.py index 15f96977..138ae7ee 100644 --- a/migrations/versions/2025_02_19_1452-5e5cd2288147_update_matcher_types.py +++ b/migrations/versions/2025_02_19_1452-5e5cd2288147_update_matcher_types.py @@ -26,24 +26,17 @@ def upgrade() -> None: op.execute( """ UPDATE muxes - SET matcher_type = 'fim', matcher_blob = '' + SET matcher_type = 'fim_filename', matcher_blob = '' WHERE matcher_type = 'request_type_match' AND matcher_blob = 'fim'; """ ) op.execute( """ UPDATE muxes - SET matcher_type = 'chat', matcher_blob = '' + SET matcher_type = 'chat_filename', matcher_blob = '' WHERE matcher_type = 'request_type_match' AND matcher_blob = 'chat'; """ ) - op.execute( - """ - UPDATE muxes - SET matcher_type = 'catch_all' - WHERE matcher_type = 'filename_match' AND matcher_blob != ''; - """ - ) # Finish transaction op.execute("COMMIT;") @@ -67,13 +60,6 @@ def downgrade() -> None: WHERE matcher_type = 'chat'; """ ) - op.execute( - """ - UPDATE muxes - SET matcher_type = 'filename_match', matcher_blob = 'catch_all' - WHERE matcher_type = 'catch_all'; - """ - ) # Finish transaction op.execute("COMMIT;") diff --git a/src/codegate/muxing/models.py b/src/codegate/muxing/models.py index 4a0b1d9c..5637c5b8 100644 --- a/src/codegate/muxing/models.py +++ b/src/codegate/muxing/models.py @@ -14,20 +14,24 @@ class MuxMatcherType(str, Enum): The 3 rules present match filenames and request types. They're used in conjunction with the matcher field in the MuxRule model. E.g. - - catch_all and match: None -> Always match - - fim and match: requests.py -> Match the request if the filename is requests.py and FIM - - chat and match: None -> Match the request if it's a chat request - - chat and match: .js -> Match the request if the filename has a .js extension and is chat + - catch_all-> Always match + - filename_match and match: requests.py -> Match the request if the filename is requests.py + - fim_filename and match: main.py -> Match the request if the request type is fim + and the filename is main.py NOTE: Removing or updating fields from this enum will require a migration. + Adding new fields is safe. """ # Always match this prompt catch_all = "catch_all" + # Match based on the filename. It will match if there is a filename + # in the request that matches the matcher either extension or full name (*.py or main.py) + filename_match = "filename_match" # Match based on fim request type. It will match if the request type is fim - fim = "fim" + fim_filename = "fim_filename" # Match based on chat request type. It will match if the request type is chat - chat = "chat" + chat_filename = "chat_filename" class MuxRule(pydantic.BaseModel): diff --git a/src/codegate/muxing/rulematcher.py b/src/codegate/muxing/rulematcher.py index 2fbec39b..247e6c12 100644 --- a/src/codegate/muxing/rulematcher.py +++ b/src/codegate/muxing/rulematcher.py @@ -77,9 +77,10 @@ def create(db_mux_rule: db_models.MuxRule, route: ModelRoute) -> MuxingRuleMatch """Create a muxing matcher for the given endpoint and model.""" factory: Dict[mux_models.MuxMatcherType, MuxingRuleMatcher] = { - mux_models.MuxMatcherType.catch_all: RequestTypeAndFileMuxingRuleMatcher, - mux_models.MuxMatcherType.fim: RequestTypeAndFileMuxingRuleMatcher, - mux_models.MuxMatcherType.chat: RequestTypeAndFileMuxingRuleMatcher, + mux_models.MuxMatcherType.catch_all: CatchAllMuxingRuleMatcher, + mux_models.MuxMatcherType.filename_match: FileMuxingRuleMatcher, + mux_models.MuxMatcherType.fim_filename: RequestTypeAndFileMuxingRuleMatcher, + mux_models.MuxMatcherType.chat_filename: RequestTypeAndFileMuxingRuleMatcher, } try: @@ -90,7 +91,16 @@ def create(db_mux_rule: db_models.MuxRule, route: ModelRoute) -> MuxingRuleMatch raise ValueError(f"Unknown matcher type: {mux_rule.matcher_type}") -class RequestTypeAndFileMuxingRuleMatcher(MuxingRuleMatcher): +class CatchAllMuxingRuleMatcher(MuxingRuleMatcher): + """A catch all muxing rule matcher.""" + + def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool: + logger.info("Catch all rule matched") + return True + + +class FileMuxingRuleMatcher(MuxingRuleMatcher): + """A file muxing rule matcher.""" def _extract_request_filenames(self, detected_client: ClientType, data: dict) -> set[str]: """ @@ -119,14 +129,26 @@ def _is_matcher_in_filenames(self, detected_client: ClientType, data: dict) -> b ) return is_filename_match + def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool: + """ + Return True if the matcher is in one of the request filenames. + """ + is_rule_matched = self._is_matcher_in_filenames( + thing_to_match.client_type, thing_to_match.body + ) + if is_rule_matched: + logger.info("Filename rule matched", matcher=self._mux_rule.matcher) + return is_rule_matched + + +class RequestTypeAndFileMuxingRuleMatcher(FileMuxingRuleMatcher): + """A request type and file muxing rule matcher.""" + def _is_request_type_match(self, is_fim_request: bool) -> bool: """ Check if the request type matches the MuxMatcherType. """ - # Catch all rule matches both chat and FIM requests - if self._mux_rule.matcher_type == mux_models.MuxMatcherType.catch_all: - return True - incoming_request_type = "fim" if is_fim_request else "chat" + incoming_request_type = "fim_filename" if is_fim_request else "chat_filename" if incoming_request_type == self._mux_rule.matcher_type: return True return False diff --git a/tests/muxing/test_rulematcher.py b/tests/muxing/test_rulematcher.py index 9cde0179..7340d983 100644 --- a/tests/muxing/test_rulematcher.py +++ b/tests/muxing/test_rulematcher.py @@ -24,6 +24,66 @@ ) +@pytest.mark.parametrize( + "matcher_blob, thing_to_match", + [ + (None, None), + ("fake-matcher-blob", None), + ( + "fake-matcher-blob", + mux_models.ThingToMatchMux( + body={}, + url_request_path="/chat/completions", + is_fim_request=False, + client_type="generic", + ), + ), + ], +) +def test_catch_all(matcher_blob, thing_to_match): + muxing_rule_matcher = rulematcher.CatchAllMuxingRuleMatcher(mocked_route_openai, matcher_blob) + # It should always match + assert muxing_rule_matcher.match(thing_to_match) is True + + +@pytest.mark.parametrize( + "matcher, filenames_to_match, expected_bool", + [ + (None, [], True), # Empty filenames and no blob + (None, ["main.py"], True), # Empty blob should match + (".py", ["main.py"], True), # Extension match + ("main.py", ["main.py"], True), # Full name match + (".py", ["main.py", "test.py"], True), # Extension match + ("main.py", ["main.py", "test.py"], True), # Full name match + ("main.py", ["test.py"], False), # Full name no match + (".js", ["main.py", "test.py"], False), # Extension no match + (".ts", ["main.tsx", "test.tsx"], False), # Extension no match + ], +) +def test_file_matcher( + matcher, + filenames_to_match, + expected_bool, +): + mux_rule = mux_models.MuxRule( + provider_id="1", + model="fake-gpt", + matcher_type="filename_match", + matcher=matcher, + ) + muxing_rule_matcher = rulematcher.FileMuxingRuleMatcher(mocked_route_openai, mux_rule) + # We mock the _extract_request_filenames method to return a list of filenames + # The logic to get the correct filenames from snippets is tested in /tests/extract_snippets + muxing_rule_matcher._extract_request_filenames = MagicMock(return_value=filenames_to_match) + mocked_thing_to_match = mux_models.ThingToMatchMux( + body={}, + url_request_path="/chat/completions", + is_fim_request=False, + client_type="generic", + ) + assert muxing_rule_matcher.match(mocked_thing_to_match) is expected_bool + + @pytest.mark.parametrize( "matcher, filenames_to_match, expected_bool_filenames", [ @@ -41,15 +101,13 @@ @pytest.mark.parametrize( "is_fim_request, matcher_type, expected_bool_request", [ - (False, "fim", False), # No match - (True, "fim", True), # Match - (False, "chat", True), # Match - (True, "chat", False), # No match - (True, "catch_all", True), # Match - (False, "catch_all", True), # Match + (False, "fim_filename", False), # No match + (True, "fim_filename", True), # Match + (False, "chat_filename", True), # Match + (True, "chat_filename", False), # No match ], ) -def test_file_matcher( +def test_request_file_matcher( matcher, filenames_to_match, expected_bool_filenames, @@ -88,3 +146,32 @@ def test_file_matcher( assert muxing_rule_matcher.match(mocked_thing_to_match) is ( expected_bool_request and expected_bool_filenames ) + + +@pytest.mark.parametrize( + "matcher_type, expected_class", + [ + (mux_models.MuxMatcherType.catch_all, rulematcher.CatchAllMuxingRuleMatcher), + (mux_models.MuxMatcherType.filename_match, rulematcher.FileMuxingRuleMatcher), + (mux_models.MuxMatcherType.fim_filename, rulematcher.RequestTypeAndFileMuxingRuleMatcher), + (mux_models.MuxMatcherType.chat_filename, rulematcher.RequestTypeAndFileMuxingRuleMatcher), + ("invalid_matcher", None), + ], +) +def test_muxing_matcher_factory(matcher_type, expected_class): + mux_rule = db_models.MuxRule( + id="1", + provider_endpoint_id="1", + provider_model_name="fake-gpt", + workspace_id="1", + matcher_type=matcher_type, + matcher_blob="fake-matcher", + priority=1, + ) + if expected_class: + assert isinstance( + rulematcher.MuxingMatcherFactory.create(mux_rule, mocked_route_openai), expected_class + ) + else: + with pytest.raises(ValueError): + rulematcher.MuxingMatcherFactory.create(mux_rule, mocked_route_openai)