From fd79e034dfbef14c272a9420446794b64743dc20 Mon Sep 17 00:00:00 2001 From: MohammadMahdi Khalaj Date: Sat, 11 Nov 2023 16:06:22 +0330 Subject: [PATCH 1/5] Remove `$` from `newRoot` --- aggify/aggify.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aggify/aggify.py b/aggify/aggify.py index 5712d23..8cc416c 100644 --- a/aggify/aggify.py +++ b/aggify/aggify.py @@ -723,7 +723,7 @@ def replace_root( {key: mongoengine_fields.IntField() for key, value in merge.items()} ) else: - new_root = {"$replaceRoot": {"$newRoot": name}} + new_root = {"$replaceRoot": {"newRoot": name}} self.pipelines.append(new_root) return self From 6ee00c9fd1c957f640ab81eeefed4063fdc54c83 Mon Sep 17 00:00:00 2001 From: MohammadMahdi Khalaj Date: Sat, 11 Nov 2023 16:13:11 +0330 Subject: [PATCH 2/5] Add `get_nested_field_model` --- aggify/utilty.py | 34 ++++++++++++++++++++++++++++++++-- 1 file changed, 32 insertions(+), 2 deletions(-) diff --git a/aggify/utilty.py b/aggify/utilty.py index 2cc8017..5b59343 100644 --- a/aggify/utilty.py +++ b/aggify/utilty.py @@ -1,8 +1,9 @@ from typing import Any, Union, List, Dict from mongoengine import Document -from aggify.types import CollectionType + from aggify.exceptions import MongoIndexError, InvalidField, AlreadyExistsField +from aggify.types import CollectionType def to_mongo_positive_index(index: Union[int, slice]) -> slice: @@ -116,7 +117,10 @@ def check_field_exists(model: CollectionType, field: str) -> None: Raises: AlreadyExistsField: If the field already exists in the model. """ - if model._fields.get(field): # noqa + if field in [ + f.db_field if hasattr(f, "db_field") else k + for k, f in model._fields.items() # noqa + ]: raise AlreadyExistsField(field=field) @@ -138,3 +142,29 @@ def get_db_field(model: CollectionType, field: str, add_dollar_sign=False) -> st return f"${db_field}" if add_dollar_sign else db_field except AttributeError: return field + + +def get_nested_field_model(model: CollectionType, field: str) -> CollectionType: + """ + Retrieves the nested field model for a specified field within a given model. + + This function examines the provided model to determine if the specified field is + a nested field. If it is, the function returns the nested field's model. + Otherwise, it returns the original model. + + Args: + model (CollectionType): The model to be inspected. This should be a class that + represents a collection or document in a database, typically + in an ORM or ODM framework. + field (str): The name of the field within the model to inspect for nestedness. + + Returns: + CollectionType: The model of the nested field if the specified field is nested; + otherwise, returns the original model. + + Raises: + KeyError: If the specified field is not found in the model. + """ + if model._fields[field].__dict__.get("__module__"): # noqa + return model + return model._fields[field].__dict__["document_type_obj"] # noqa From d5b2ed6324a704f2c129075a326b84e7b0414a81 Mon Sep 17 00:00:00 2001 From: MohammadMahdi Khalaj Date: Sat, 11 Nov 2023 16:14:23 +0330 Subject: [PATCH 3/5] Get nested fields db_name in `filter()` --- aggify/compiler.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/aggify/compiler.py b/aggify/compiler.py index 17a40a9..0f6752a 100644 --- a/aggify/compiler.py +++ b/aggify/compiler.py @@ -4,7 +4,7 @@ from mongoengine.base import TopLevelDocumentMetaclass from aggify.exceptions import InvalidOperator -from aggify.utilty import get_db_field +from aggify.utilty import get_db_field, get_nested_field_model class Operators: @@ -284,16 +284,20 @@ def compile(self, pipelines: list) -> Dict[str, Dict[str, list]]: match_query[key] = value continue - field, operator, *_ = key.split("__") + field, operator, *others = key.split("__") if ( self.is_base_model_field(field) and operator not in Operators.ALL_OPERATORS ): - pipelines.append( - Match({key.replace("__", ".", 1): value}, self.base_model).compile( - [] - ) + field_db_name = get_db_field(self.base_model, field) + + nested_field_name = get_db_field( + get_nested_field_model(self.base_model, field), operator ) + key = ( + f"{field_db_name}.{nested_field_name}__" + "__".join(others) + ).rstrip("__") + pipelines.append(Match({key: value}, self.base_model).compile([])) continue if operator not in Operators.ALL_OPERATORS: From 7c7681e2df155e1525c1bb6efe4bf1716a3343ed Mon Sep 17 00:00:00 2001 From: MohammadMahdi Khalaj Date: Sat, 11 Nov 2023 16:14:49 +0330 Subject: [PATCH 4/5] Update tests base on new changes --- tests/test_query.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/test_query.py b/tests/test_query.py index 86185b1..6a9d2f4 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -274,7 +274,7 @@ class ParameterTestCase: ), ParameterTestCase( compiled_query=(Aggify(PostDocument).replace_root(embedded_field="stat")), - expected_query=[{"$replaceRoot": {"$newRoot": "$stat"}}], + expected_query=[{"$replaceRoot": {"newRoot": "$stat"}}], ), ParameterTestCase( compiled_query=(Aggify(PostDocument).replace_with(embedded_field="stat")), @@ -525,9 +525,13 @@ class ParameterTestCase: "localField": "end", } }, - {"$replaceRoot": {"$newRoot": "$saved_post"}}, + {"$replaceRoot": {"newRoot": "$saved_post"}}, ], ), + ParameterTestCase( + compiled_query=(Aggify(PostDocument).filter(stat__like_count=2)), + expected_query=[{"$match": {"stat.like_count": 2}}], + ), ] From b9f5d33dbb3c341da5ee6f8ddc0eaa92079231ca Mon Sep 17 00:00:00 2001 From: MohammadMahdi Khalaj Date: Sat, 11 Nov 2023 16:35:50 +0330 Subject: [PATCH 5/5] Change some function names --- aggify/aggify.py | 8 ++++---- aggify/utilty.py | 6 +++--- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/aggify/aggify.py b/aggify/aggify.py index 8cc416c..53e0f5e 100644 --- a/aggify/aggify.py +++ b/aggify/aggify.py @@ -17,10 +17,10 @@ from aggify.types import QueryParams, CollectionType from aggify.utilty import ( to_mongo_positive_index, - check_fields_exist, + validate_field_existence, replace_values_recursive, convert_match_query, - check_field_exists, + check_field_already_exists, get_db_field, ) @@ -538,7 +538,7 @@ def get_field_name_recursively(self, field: str) -> str: # Split the field based on double underscores and process each item for index, item in enumerate(field.split("__")): # Ensure the field exists at the current level of hierarchy - check_fields_exist(prev_base, [item]) # noqa + validate_field_existence(prev_base, [item]) # noqa # Append the database field name to the field_name list field_name.append(get_db_field(prev_base, item)) @@ -583,7 +583,7 @@ def lookup( """ lookup_stages = [] - check_field_exists(self.base_model, as_name) # noqa + check_field_already_exists(self.base_model, as_name) # noqa from_collection_name = from_collection._meta.get("collection") # noqa if not (let or raw_let) and not (local_field and foreign_field): diff --git a/aggify/utilty.py b/aggify/utilty.py index 5b59343..70d9a35 100644 --- a/aggify/utilty.py +++ b/aggify/utilty.py @@ -23,9 +23,9 @@ def to_mongo_positive_index(index: Union[int, slice]) -> slice: return index -def check_fields_exist(model: CollectionType, fields_to_check: List[str]) -> None: +def validate_field_existence(model: CollectionType, fields_to_check: List[str]) -> None: """ - Check if the specified fields exist in a model's fields. + The function checks a list of fields and raises an InvalidField exception if any are missing. Args: model: The model containing the fields to check. @@ -106,7 +106,7 @@ def convert_match_query( return d -def check_field_exists(model: CollectionType, field: str) -> None: +def check_field_already_exists(model: CollectionType, field: str) -> None: """ Check if a field exists in the given model.