From 4cb0fb572569e7d95fefa0f5c9e13a0c562f86f8 Mon Sep 17 00:00:00 2001 From: MohammadMahdi Khalaj Date: Sat, 11 Nov 2023 17:01:22 +0330 Subject: [PATCH] Validate fields in lookup with local and foreign fields --- aggify/aggify.py | 20 +++++++++++--------- tests/test_query.py | 4 ++-- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/aggify/aggify.py b/aggify/aggify.py index 53e0f5e..d45fc48 100644 --- a/aggify/aggify.py +++ b/aggify/aggify.py @@ -517,7 +517,9 @@ def __combine_sequential_matches(self) -> List[Dict[str, Union[dict, Any]]]: return merged_pipeline # check_fields_exist(self.base_model, let) # noqa - def get_field_name_recursively(self, field: str) -> str: + def get_field_name_recursively( + self, field: str, base: Union[CollectionType, None] = None + ) -> str: """ Recursively fetch the field name by following the hierarchy indicated by the field parameter. @@ -533,19 +535,19 @@ def get_field_name_recursively(self, field: str) -> str: """ field_name = [] - prev_base = self.base_model + base = self.base_model if not base else base # Split the field based on double underscores and process each item for index, item in enumerate(field.split("__")): # Ensure the field exists at the current level of hierarchy - validate_field_existence(prev_base, [item]) # noqa + validate_field_existence(base, [item]) # noqa # Append the database field name to the field_name list - field_name.append(get_db_field(prev_base, item)) + field_name.append(get_db_field(base, item)) # Move to the next level in the model hierarchy - prev_base = self.get_model_field(prev_base, item) - prev_base = prev_base.__dict__.get("document_type_obj", prev_base) + base = self.get_model_field(base, item) + base = base.__dict__.get("document_type_obj", base) # Join the entire hierarchy using dots and return return ".".join(field_name) @@ -594,9 +596,9 @@ def lookup( lookup_stage = { "$lookup": { "from": from_collection_name, - "localField": get_db_field(self.base_model, local_field), # noqa - "foreignField": get_db_field( - from_collection, foreign_field # noqa + "localField": self.get_field_name_recursively(local_field), # noqa + "foreignField": self.get_field_name_recursively( + base=from_collection, field=foreign_field # noqa ), "as": as_name, } diff --git a/tests/test_query.py b/tests/test_query.py index 6a9d2f4..014b6fb 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -510,7 +510,7 @@ class ParameterTestCase: Aggify(PostDocument) .lookup( PostDocument, - local_field="end", + local_field="stat", foreign_field="id", as_name="saved_post", ) @@ -522,7 +522,7 @@ class ParameterTestCase: "as": "saved_post", "foreignField": "_id", "from": "post_document", - "localField": "end", + "localField": "stat", } }, {"$replaceRoot": {"newRoot": "$saved_post"}},