diff --git a/aggify/aggify.py b/aggify/aggify.py index d0ef553..a48b5b7 100644 --- a/aggify/aggify.py +++ b/aggify/aggify.py @@ -1,7 +1,7 @@ import functools -from typing import Any, Dict, Type, Union, List, Callable, TypeVar +from typing import Any, Dict, Type, Union, List, TypeVar, Callable -from mongoengine import Document, EmbeddedDocument, fields +from mongoengine import Document, EmbeddedDocument, fields as mongoengine_fields from mongoengine.base import TopLevelDocumentMetaclass from aggify.compiler import F, Match, Q, Operators, Cond # noqa keep @@ -13,7 +13,7 @@ OutStageError, InvalidArgument, ) -from aggify.types import QueryParams, CollectionType +from aggify.types import QueryParams from aggify.utilty import ( to_mongo_positive_index, check_fields_exist, @@ -24,6 +24,7 @@ ) AggifyType = TypeVar("AggifyType", bound=Callable[..., "Aggify"]) +CollectionType = TypeVar("CollectionType", bound=Callable[..., "Document"]) def last_out_stage_check(method: AggifyType) -> AggifyType: @@ -85,25 +86,23 @@ def project(self, **kwargs: QueryParams) -> "Aggify": """ # Extract fields to keep and check if _id should be deleted - to_keep_values = ["id"] - delete_id = kwargs.get("id") == 0 + to_keep_values = {"id"} + delete_id = kwargs.get("id") is not None projection = {} # Add missing fields to the base model for key, value in kwargs.items(): if value == 1: - to_keep_values.append(key) + to_keep_values.add(key) elif key not in self.base_model._fields and isinstance( # noqa kwargs[key], (str, dict) - ): # noqa - to_keep_values.append(key) - self.base_model._fields[key] = fields.IntField() # noqa + ): + to_keep_values.add(key) + self.base_model._fields[key] = mongoengine_fields.IntField() # noqa projection[get_db_field(self.base_model, key)] = value # noqa # Remove fields from the base model, except the ones in to_keep_values and possibly _id - keys_for_deletion = set(self.base_model._fields.keys()) - set( # noqa - to_keep_values - ) # noqa + keys_for_deletion = 
self.base_model._fields.keys() - to_keep_values # noqa if delete_id: keys_for_deletion.add("id") for key in keys_for_deletion: @@ -144,19 +143,18 @@ def raw(self, raw_query: dict) -> "Aggify": return self @last_out_stage_check - def add_fields(self, **_fields) -> "Aggify": # noqa - """ - Generates a MongoDB addFields pipeline stage. + def add_fields(self, **fields) -> "Aggify": # noqa + """Generates a MongoDB addFields pipeline stage. Args: - _fields: A dictionary of field expressions and values. + fields: A dictionary of field expressions and values. Returns: A MongoDB add_fields pipeline stage. """ add_fields_stage = {"$addFields": {}} - for field, expression in _fields.items(): + for field, expression in fields.items(): field = field.replace("__", ".") if isinstance(expression, str): add_fields_stage["$addFields"][field] = {"$literal": expression} @@ -169,7 +167,9 @@ def add_fields(self, **_fields) -> "Aggify": # noqa else: raise AggifyValueError([str, F, list], type(expression)) # TODO: Should be checked if new field is embedded, create embedded field. - self.base_model._fields[field.replace("$", "")] = fields.IntField() # noqa + self.base_model._fields[ + field.replace("$", "") + ] = mongoengine_fields.IntField() # noqa self.pipelines.append(add_fields_stage) return self @@ -240,7 +240,6 @@ def __to_aggregate(self, query: Dict[str, Any]) -> None: """ for key, value in query.items(): - # Split the key to access the field information. split_query = key.split("__") @@ -254,8 +253,8 @@ def __to_aggregate(self, query: Dict[str, Any]) -> None: or "document_type_obj" not in join_field.__dict__ # Check whether this field is a join field or not. 
or issubclass( - join_field.document_type, EmbeddedDocument # noqa - ) # Check whether this field is embedded field or not + join_field.document_type, EmbeddedDocument # noqa + ) # Check whether this field is embedded field or not or len(split_query) == 1 or (len(split_query) == 2 and split_query[1] in Operators.ALL_OPERATORS) ): @@ -377,30 +376,47 @@ def annotate( # Some of the accumulator fields might be false and should be checked. aggregation_mapping: Dict[str, Type] = { - "sum": (fields.FloatField(), "$sum"), - "avg": (fields.FloatField(), "$avg"), - "stdDevPop": (fields.FloatField(), "$stdDevPop"), - "stdDevSamp": (fields.FloatField(), "$stdDevSamp"), # noqa - "push": (fields.ListField(), "$push"), - "addToSet": (fields.ListField(), "$addToSet"), - "count": (fields.IntField(), "$count"), - "first": (fields.EmbeddedDocumentField(fields.EmbeddedDocument), "$first"), - "last": (fields.EmbeddedDocumentField(fields.EmbeddedDocument), "$last"), - "max": (fields.DynamicField(), "$max"), - "accumulator": (fields.DynamicField(), "$accumulator"), - "min": (fields.DynamicField(), "$min"), - "median": (fields.DynamicField(), "$median"), - "mergeObjects": (fields.DictField(), "$mergeObjects"), - "top": (fields.EmbeddedDocumentField(fields.EmbeddedDocument), "$top"), + "sum": (mongoengine_fields.FloatField(), "$sum"), + "avg": (mongoengine_fields.FloatField(), "$avg"), + "stdDevPop": (mongoengine_fields.FloatField(), "$stdDevPop"), + "stdDevSamp": (mongoengine_fields.FloatField(), "$stdDevSamp"), + "push": (mongoengine_fields.ListField(), "$push"), + "addToSet": (mongoengine_fields.ListField(), "$addToSet"), + "count": (mongoengine_fields.IntField(), "$count"), + "first": ( + mongoengine_fields.EmbeddedDocumentField( + mongoengine_fields.EmbeddedDocument + ), + "$first", + ), + "last": ( + mongoengine_fields.EmbeddedDocumentField( + mongoengine_fields.EmbeddedDocument + ), + "$last", + ), + "max": (mongoengine_fields.DynamicField(), "$max"), + "accumulator": 
(mongoengine_fields.DynamicField(), "$accumulator"), + "min": (mongoengine_fields.DynamicField(), "$min"), + "median": (mongoengine_fields.DynamicField(), "$median"), + "mergeObjects": (mongoengine_fields.DictField(), "$mergeObjects"), + "top": ( + mongoengine_fields.EmbeddedDocumentField( + mongoengine_fields.EmbeddedDocument + ), + "$top", + ), "bottom": ( - fields.EmbeddedDocumentField(fields.EmbeddedDocument), + mongoengine_fields.EmbeddedDocumentField( + mongoengine_fields.EmbeddedDocument + ), "$bottom", ), - "topN": (fields.ListField(), "$topN"), - "bottomN": (fields.ListField(), "$bottomN"), - "firstN": (fields.ListField(), "$firstN"), - "lastN": (fields.ListField(), "$lastN"), - "maxN": (fields.ListField(), "$maxN"), + "topN": (mongoengine_fields.ListField(), "$topN"), + "bottomN": (mongoengine_fields.ListField(), "$bottomN"), + "firstN": (mongoengine_fields.ListField(), "$firstN"), + "lastN": (mongoengine_fields.ListField(), "$lastN"), + "maxN": (mongoengine_fields.ListField(), "$maxN"), } try: @@ -539,19 +555,18 @@ def lookup( foreign_field: Union[str, None] = None, ) -> "Aggify": """ - Generates a MongoDB lookup pipeline stage. + Generates a MongoDB lookup pipeline stage. - Args: - from_collection (Document): The document representing the collection to perform the lookup on. - as_name (str): The name of the new field to create. - query (list[Q] | Union[Q, None], optional): List of desired queries with Q function or a single query. - let (Union[List[str], None], optional): The local field(s) to join on. If provided, - localField and foreignField are not used. - local_field (Union[str, None], optional): The local field to join on when let not provided. - foreign_field (Union[str, None], optional): The foreign field to join on when let not provided. + Args: + from_collection (Document): The document representing the collection to perform the lookup on. + as_name (str): The name of the new field to create. 
+ query (list[Q] | Union[Q, None], optional): List of desired queries with Q function or a single query. + let (Union[List[str], None], optional): The local field(s) to join on. If provided, localField and foreignField are not used. + local_field (Union[str, None], optional): The local field to join on when `let` is not provided. + foreign_field (Union[str, None], optional): The foreign field to join on when `let` is not provided. - Returns: - Aggify: An instance of the Aggify class representing a MongoDB lookup pipeline stage. + Returns: + Aggify: An instance of the Aggify class representing a MongoDB lookup pipeline stage. """ lookup_stages = [] @@ -617,7 +632,7 @@ def lookup( return self @staticmethod - def get_model_field(model: CollectionType, field: str) -> fields: + def get_model_field(model: Type[Document], field: str) -> mongoengine_fields: """ Get the field definition of a specified field in a MongoDB model. 
diff --git a/aggify/compiler.py b/aggify/compiler.py index 71866f1..e0de1b4 100644 --- a/aggify/compiler.py +++ b/aggify/compiler.py @@ -82,11 +82,6 @@ def compile_match(self, operator: str, value, field: str): } else: self.match_query[field] = {Operators.ALL_OPERATORS[operator]: value} - else: - # Default behavior - self.match_query[field] = { - Operators.ALL_OPERATORS.get(operator, operator): value - } return self.match_query @@ -107,7 +102,7 @@ def __iter__(self): yield "$match", self.conditions def __or__(self, other): - if self.conditions.get("$or", None): + if self.conditions.get("$or"): self.conditions["$or"].append(dict(other)["$match"]) combined_conditions = self.conditions @@ -116,7 +111,7 @@ def __or__(self, other): return Q(**combined_conditions) def __and__(self, other): - if self.conditions.get("$and", None): + if self.conditions.get("$and"): self.conditions["$and"].append(dict(other)["$match"]) combined_conditions = self.conditions else: @@ -138,7 +133,7 @@ def __init__(self, field: Union[str, Dict[str, list]]): def to_dict(self): return self.field - def __add__(self, other): # TODO: add type for 'other' + def __add__(self, other): if isinstance(other, F): other = other.field @@ -248,7 +243,12 @@ def __init__( @staticmethod def validate_operator(key: str): - operator = key.rsplit("__", 1)[1] + _op = key.rsplit("__", 1) + try: + operator = _op[1] + except IndexError: + raise InvalidOperator(_op) from None + if operator not in Operators.COMPARISON_OPERATORS: raise InvalidOperator(operator) @@ -275,15 +275,15 @@ def is_base_model_field(self, field) -> bool: def compile(self, pipelines: list) -> Dict[str, Dict[str, list]]: match_query = {} for key, value in self.matches.items(): + if isinstance(value, F): + if F.is_suitable_for_match(key) is False: + raise InvalidOperator(key) + if "__" not in key: key = get_db_field(self.base_model, key) match_query[key] = value continue - if isinstance(value, F): - if F.is_suitable_for_match(key) is False: - raise 
InvalidOperator(key) - field, operator, *_ = key.split("__") if ( self.is_base_model_field(field) diff --git a/aggify/types.py b/aggify/types.py index db405d9..a9b3a11 100644 --- a/aggify/types.py +++ b/aggify/types.py @@ -1,7 +1,5 @@ -from typing import Union, Dict, TypeVar, Callable -from mongoengine import Document +from typing import Union, Dict + from bson import ObjectId QueryParams = Union[int, None, str, bool, float, Dict, ObjectId] - -CollectionType = TypeVar("CollectionType", bound=Callable[..., Document]) diff --git a/aggify/utilty.py b/aggify/utilty.py index ea056ce..1fceb69 100644 --- a/aggify/utilty.py +++ b/aggify/utilty.py @@ -3,7 +3,6 @@ from mongoengine import Document from aggify.exceptions import MongoIndexError, InvalidField, AlreadyExistsField -from aggify.types import CollectionType def to_mongo_positive_index(index: Union[int, slice]) -> slice: @@ -23,7 +22,7 @@ def to_mongo_positive_index(index: Union[int, slice]) -> slice: return index -def check_fields_exist(model: CollectionType, fields_to_check: List[str]) -> None: +def check_fields_exist(model, fields_to_check: List[str]) -> None: """ Check if the specified fields exist in a model's fields. @@ -106,7 +105,7 @@ def convert_match_query( return d -def check_field_exists(model: CollectionType, field: str) -> None: +def check_field_exists(model, field: str) -> None: """ Check if a field exists in the given model. @@ -121,7 +120,7 @@ def check_field_exists(model: CollectionType, field: str) -> None: raise AlreadyExistsField(field=field) -def get_db_field(model: CollectionType, field: str, add_dollar_sign=False) -> str: +def get_db_field(model, field: str, add_dollar_sign=False) -> str: """ Get the database field name for a given field in the model. 
diff --git a/tests/test_aggify.py b/tests/test_aggify.py index 8b1735e..f9af0de 100644 --- a/tests/test_aggify.py +++ b/tests/test_aggify.py @@ -1,5 +1,5 @@ import pytest -from mongoengine import Document, IntField, StringField +from mongoengine import Document, IntField, StringField, UUIDField from aggify import Aggify, Cond, F, Q from aggify.exceptions import ( @@ -51,6 +51,12 @@ def test_filtering_and_projection(self): assert len(aggify.pipelines) == 2 assert aggify.pipelines[1]["$project"] == {"name": 1, "age": 1} + def test_filtering_and_projection_with_deleting_id(self): + aggify = Aggify(BaseModel) + aggify.filter(age__gte=30).project(name=1, age=1, id=0) + assert len(aggify.pipelines) == 2 + assert aggify.pipelines[1]["$project"] == {"_id": 0, "name": 1, "age": 1} + def test_filtering_and_ordering(self): aggify = Aggify(BaseModel) aggify.filter(age__gte=30).order_by("-age") @@ -136,35 +142,25 @@ def test_add_field_value_error(self): } aggify.add_fields(**fields) - def test_add_fields_string_literal(self): - aggify = Aggify(BaseModel) - fields = {"new_field_1": "some_string", "new_field_2": "another_string"} - add_fields_stage = aggify.add_fields(**fields) - - expected_stage = { - "$addFields": { - "new_field_1": {"$literal": "some_string"}, - "new_field_2": {"$literal": "another_string"}, - } - } - - assert add_fields_stage.pipelines[0] == expected_stage - - def test_add_fields_with_f_expression(self): + @pytest.mark.parametrize( + ("fields", "expected"), + ( + ({"scores__age": "Mahdi"}, {"scores.age": {"$literal": "Mahdi"}}), + ( + {"new_field": F("existing_field") + 10}, + {"new_field": {"$add": ["$existing_field", 10]}}, + ), + ({"array": [1, 2, 3, 4]}, {"array": [1, 2, 3, 4]}), + ( + {"cond": Cond(30, "==", 30, "Equal", "Not Equal")}, + {"cond": dict(Cond(30, "==", 30, "Equal", "Not Equal"))}, + ), + ), + ) + def test_add_fields(self, fields, expected): aggify = Aggify(BaseModel) - fields = { - "new_field_1": F("existing_field") + 10, - "new_field_2": 
F("field_a") * F("field_b"), - } add_fields_stage = aggify.add_fields(**fields) - - expected_stage = { - "$addFields": { - "new_field_1": {"$add": ["$existing_field", 10]}, - "new_field_2": {"$multiply": ["$field_a", "$field_b"]}, - } - } - assert add_fields_stage.pipelines[0] == expected_stage + assert add_fields_stage.pipelines[0]["$addFields"] == expected def test_filter_value_error(self): with pytest.raises(AggifyValueError): diff --git a/tests/test_f.py b/tests/test_f.py index fec4371..f69d230 100644 --- a/tests/test_f.py +++ b/tests/test_f.py @@ -2,70 +2,70 @@ class TestF: - # Test subtraction using F class def test_subtraction(self): f1 = F("age") f2 = F("income") - f_combined = f1 - f2 - assert f_combined.to_dict() == {"$subtract": ["$age", "$income"]} + f3 = F("x") + f_combined = f1 - f2 - f3 + assert f_combined.to_dict() == {"$subtract": ["$age", "$income", "$x"]} - # Test division using F class def test_division(self): f1 = F("income") f2 = F("expenses") - f_combined = f1 / f2 - assert f_combined.to_dict() == {"$divide": ["$income", "$expenses"]} + f3 = F("x") + f_combined = f1 / f2 / f3 + assert f_combined.to_dict() == {"$divide": ["$income", "$expenses", "$x"]} - # Test multiplication using F class def test_multiplication(self): f1 = F("quantity") f2 = F("price") - f_combined = f1 * f2 - assert f_combined.to_dict() == {"$multiply": ["$quantity", "$price"]} + f3 = F("x") + f_combined = f1 * f2 * f3 + assert f_combined.to_dict() == {"$multiply": ["$quantity", "$price", "$x"]} + + def test_addition(self): + f1 = F("income") + f2 = F("interest") + f3 = F("x") + f = f1 + f2 + f3 + assert f.to_dict() == {"$add": ["$income", "$interest", "$x"]} - # Test addition using F class with a constant def test_addition_with_constant(self): f1 = F("age") constant = 10 f_combined = f1 + constant assert f_combined.to_dict() == {"$add": ["$age", 10]} - # Test subtraction using F class with a constant def test_subtraction_with_constant(self): f1 = F("income") constant = 
5000 f_combined = f1 - constant assert f_combined.to_dict() == {"$subtract": ["$income", 5000]} - # Test division using F class with a constant def test_division_with_constant(self): f1 = F("price") constant = 2 f_combined = f1 / constant assert f_combined.to_dict() == {"$divide": ["$price", 2]} - # Test multiplication using F class with a constant def test_multiplication_with_constant(self): f1 = F("quantity") constant = 5 f_combined = f1 * constant assert f_combined.to_dict() == {"$multiply": ["$quantity", 5]} - # Test addition using F class with multiple fields and constants def test_addition_with_multiple_fields(self): f1 = F("age") f2 = F("income") f_combined = f1 + f2 assert f_combined.to_dict() == {"$add": ["$age", "$income"]} - # Test subtraction using F class with multiple fields and constants def test_subtraction_with_multiple_fields(self): f1 = F("income") f2 = F("expenses") f_combined = f1 - f2 assert f_combined.to_dict() == {"$subtract": ["$income", "$expenses"]} - # Test multiplication using F class with multiple fields and constants def test_multiplication_with_multiple_fields(self): f1 = F("quantity") f2 = F("price") diff --git a/tests/test_match.py b/tests/test_match.py new file mode 100644 index 0000000..6e5f4ca --- /dev/null +++ b/tests/test_match.py @@ -0,0 +1,14 @@ +import pytest + +from aggify.compiler import Match +from aggify.exceptions import InvalidOperator + + +def test_validate_operator_fail(): + with pytest.raises(InvalidOperator): + Match.validate_operator("key_raise") + + +def test_validate_operator_fail_not_in_operators(): + with pytest.raises(InvalidOperator): + Match.validate_operator("key__ge") diff --git a/tests/test_q.py b/tests/test_q.py index e3bf06d..4d69969 100644 --- a/tests/test_q.py +++ b/tests/test_q.py @@ -1,4 +1,8 @@ -from aggify import Q +import pytest + +from aggify import Q, F, Aggify +from aggify.exceptions import InvalidOperator +from tests.test_aggify import BaseModel class TestQ: @@ -22,6 +26,25 @@ def 
test_or_operator_with_multiple_conditions_more_than_rwo(self): } } + def test_and(self): + q1 = Q(name="Mahdi") + q2 = Q(age__gt=20) + q = q1 & q2 + + assert dict(q) == {"$match": {"$and": [dict(q1)["$match"], dict(q2)["$match"]]}} + + def test_multiple_and(self): + q1 = Q(name="Mahdi") + q2 = Q(age__gt=20) + q3 = Q(age__lt=30) + q = q1 & q2 & q3 + + assert dict(q) == { + "$match": { + "$and": [dict(q1)["$match"], dict(q2)["$match"], dict(q3)["$match"]] + } + } + # Test combining NOT operators with AND def test_combine_not_operators_with_and(self): q1 = Q(name="John") @@ -43,3 +66,7 @@ def test_combine_not_operators_with_or(self): "$or": [{"$not": [dict(q1)["$match"]]}, {"$not": [dict(q2)["$match"]]}] } } + + def test_unsuitable_key_for_f(self): + with pytest.raises(InvalidOperator): + Q(Aggify(BaseModel).filter(age__gt=20).pipelines, age_gt=F("income") * 2)