From 7c7dbbd2f2ba82cebde60cfcb05b5812979e4ed4 Mon Sep 17 00:00:00 2001 From: Mahdi Haghverdi Date: Sun, 5 Nov 2023 11:36:42 +0330 Subject: [PATCH 01/13] Move types to types.py Signed-off-by: mahdihaghverdi --- aggify/aggify.py | 13 ++++--------- aggify/types.py | 4 +++- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/aggify/aggify.py b/aggify/aggify.py index 8ed4e16..837a5a4 100644 --- a/aggify/aggify.py +++ b/aggify/aggify.py @@ -1,5 +1,5 @@ import functools -from typing import Any, Dict, Type, Union, List, Callable, TypeVar +from typing import Any, Dict, Type, Union, List from mongoengine import Document, EmbeddedDocument, fields from mongoengine.base import TopLevelDocumentMetaclass @@ -13,7 +13,7 @@ OutStageError, InvalidArgument, ) -from aggify.types import QueryParams +from aggify.types import QueryParams, AggifyType, CollectionType from aggify.utilty import ( to_mongo_positive_index, check_fields_exist, @@ -24,11 +24,6 @@ ) -AggifyType = TypeVar('AggifyType', bound=Callable[..., "Aggify"]) -CollectionType = TypeVar('CollectionType', bound=Callable[..., "Document"]) - - - def last_out_stage_check(method: AggifyType) -> AggifyType: """Check if the last stage is $out or not @@ -555,8 +550,8 @@ def lookup( as_name (str): The name of the new field to create. query (list[Q] | Union[Q, None], optional): List of desired queries with Q function or a single query. let (Union[List[str], None], optional): The local field(s) to join on. If provided, localField and foreignField are not used. - local_field (Union[str, None], optional): The local field to join on when let is not provided. - foreign_field (Union[str, None], optional): The foreign field to join on when let is not provided. + local_field (Union[str, None], optional): The local field to join on when `let` is not provided. + foreign_field (Union[str, None], optional): The foreign field to join on when `let` is not provided. Returns: Aggify: An instance of the Aggify class representing a MongoDB lookup pipeline stage. diff --git a/aggify/types.py b/aggify/types.py index a9b3a11..227b9d6 100644 --- a/aggify/types.py +++ b/aggify/types.py @@ -1,5 +1,7 @@ -from typing import Union, Dict +from typing import Union, Dict, TypeVar, Callable from bson import ObjectId QueryParams = Union[int, None, str, bool, float, Dict, ObjectId] +AggifyType = TypeVar('AggifyType', bound=Callable[..., "Aggify"]) +CollectionType = TypeVar('CollectionType', bound=Callable[..., "Document"]) From 159332e43da637014d992d768f38e940ef0e6b49 Mon Sep 17 00:00:00 2001 From: mahdihaghverdi Date: Sun, 5 Nov 2023 16:18:22 +0330 Subject: [PATCH 02/13] Add a test for id deletion in `project` Signed-off-by: mahdihaghverdi --- aggify/aggify.py | 14 ++++++-------- aggify/types.py | 4 ++-- tests/test_aggify.py | 8 +++++++- 3 files changed, 15 insertions(+), 11 deletions(-) diff --git a/aggify/aggify.py b/aggify/aggify.py index 837a5a4..40576ae 100644 --- a/aggify/aggify.py +++ b/aggify/aggify.py @@ -83,25 +83,23 @@ def project(self, **kwargs: QueryParams) -> "Aggify": """ # Extract fields to keep and check if _id should be deleted - to_keep_values = ["id"] - delete_id = kwargs.get("id") == 0 + to_keep_values = {"id"} + delete_id = kwargs.get("id") is not None projection = {} # Add missing fields to the base model for key, value in kwargs.items(): if value == 1: - to_keep_values.append(key) + to_keep_values.add(key) elif key not in self.base_model._fields and isinstance( # noqa kwargs[key], (str, dict) - ): # noqa - to_keep_values.append(key) + ): + to_keep_values.add(key) self.base_model._fields[key] = fields.IntField() # noqa projection[get_db_field(self.base_model, key)] = value # noqa # Remove fields from the base model, except the ones in to_keep_values and possibly _id - keys_for_deletion = set(self.base_model._fields.keys()) - set( # noqa - to_keep_values - ) # noqa + keys_for_deletion = self.base_model._fields.keys() - to_keep_values # noqa if delete_id: keys_for_deletion.add("id") for key in keys_for_deletion: diff --git a/aggify/types.py b/aggify/types.py index 227b9d6..9cd099e 100644 --- a/aggify/types.py +++ b/aggify/types.py @@ -3,5 +3,5 @@ from bson import ObjectId QueryParams = Union[int, None, str, bool, float, Dict, ObjectId] -AggifyType = TypeVar('AggifyType', bound=Callable[..., "Aggify"]) -CollectionType = TypeVar('CollectionType', bound=Callable[..., "Document"]) +AggifyType = TypeVar("AggifyType", bound=Callable[..., "Aggify"]) +CollectionType = TypeVar("CollectionType", bound=Callable[..., "Document"]) diff --git a/tests/test_aggify.py b/tests/test_aggify.py index cc77afb..63d800f 100644 --- a/tests/test_aggify.py +++ b/tests/test_aggify.py @@ -1,5 +1,5 @@ import pytest -from mongoengine import Document, IntField, StringField +from mongoengine import Document, IntField, StringField, UUIDField from aggify import Aggify, Cond, F, Q from aggify.exceptions import ( @@ -46,6 +46,12 @@ def test_filtering_and_projection(self): assert len(aggify.pipelines) == 2 assert aggify.pipelines[1]["$project"] == {"name": 1, "age": 1} + def test_filtering_and_projection_with_deleting_id(self): + aggify = Aggify(BaseModel) + aggify.filter(age__gte=30).project(name=1, age=1, id=0) + assert len(aggify.pipelines) == 2 + assert aggify.pipelines[1]["$project"] == {"_id": 0, "name": 1, "age": 1} + def test_filtering_and_ordering(self): aggify = Aggify(BaseModel) aggify.filter(age__gte=30).order_by("-age") From be301b6c463a390ce8b9fd763b4d62dd9c110210 Mon Sep 17 00:00:00 2001 From: mahdihaghverdi Date: Sun, 5 Nov 2023 16:40:50 +0330 Subject: [PATCH 03/13] refactor the `fields` to `mongoengine_fields` Signed-off-by: mahdihaghverdi --- aggify/aggify.py | 53 ++++++++++++++++++++++++------------------------ 1 file changed, 26 insertions(+), 27 deletions(-) diff --git a/aggify/aggify.py b/aggify/aggify.py index 40576ae..c76f684 100644 --- a/aggify/aggify.py +++ b/aggify/aggify.py @@ -1,7 +1,7 @@ import functools from typing import Any, Dict, Type, Union, List -from mongoengine import Document, EmbeddedDocument, fields +from mongoengine import Document, EmbeddedDocument, fields as mongoengine_fields from mongoengine.base import TopLevelDocumentMetaclass from aggify.compiler import F, Match, Q, Operators, Cond # noqa keep @@ -95,7 +95,7 @@ def project(self, **kwargs: QueryParams) -> "Aggify": kwargs[key], (str, dict) ): to_keep_values.add(key) - self.base_model._fields[key] = fields.IntField() # noqa + self.base_model._fields[key] = mongoengine_fields.IntField() # noqa projection[get_db_field(self.base_model, key)] = value # noqa # Remove fields from the base model, except the ones in to_keep_values and possibly _id @@ -135,8 +135,7 @@ def raw(self, raw_query: dict) -> "Aggify": @last_out_stage_check def add_fields(self, **_fields) -> "Aggify": # noqa - """ - Generates a MongoDB addFields pipeline stage. + """Generates a MongoDB addFields pipeline stage. Args: _fields: A dictionary of field expressions and values. @@ -159,7 +158,7 @@ def add_fields(self, **_fields) -> "Aggify": # noqa else: raise AggifyValueError([str, F, list], type(expression)) # TODO: Should be checked if new field is embedded, create embedded field. - self.base_model._fields[field.replace("$", "")] = fields.IntField() # noqa + self.base_model._fields[field.replace("$", "")] = mongoengine_fields.IntField() # noqa self.pipelines.append(add_fields_stage) return self @@ -379,30 +378,30 @@ def annotate( # Some of the accumulator fields might be false and should be checked. aggregation_mapping: Dict[str, Type] = { - "sum": (fields.FloatField(), "$sum"), - "avg": (fields.FloatField(), "$avg"), - "stdDevPop": (fields.FloatField(), "$stdDevPop"), - "stdDevSamp": (fields.FloatField(), "$stdDevSamp"), - "push": (fields.ListField(), "$push"), - "addToSet": (fields.ListField(), "$addToSet"), - "count": (fields.IntField(), "$count"), - "first": (fields.EmbeddedDocumentField(fields.EmbeddedDocument), "$first"), - "last": (fields.EmbeddedDocumentField(fields.EmbeddedDocument), "$last"), - "max": (fields.DynamicField(), "$max"), - "accumulator": (fields.DynamicField(), "$accumulator"), - "min": (fields.DynamicField(), "$min"), - "median": (fields.DynamicField(), "$median"), - "mergeObjects": (fields.DictField(), "$mergeObjects"), - "top": (fields.EmbeddedDocumentField(fields.EmbeddedDocument), "$top"), + "sum": (mongoengine_fields.FloatField(), "$sum"), + "avg": (mongoengine_fields.FloatField(), "$avg"), + "stdDevPop": (mongoengine_fields.FloatField(), "$stdDevPop"), + "stdDevSamp": (mongoengine_fields.FloatField(), "$stdDevSamp"), + "push": (mongoengine_fields.ListField(), "$push"), + "addToSet": (mongoengine_fields.ListField(), "$addToSet"), + "count": (mongoengine_fields.IntField(), "$count"), + "first": (mongoengine_fields.EmbeddedDocumentField(mongoengine_fields.EmbeddedDocument), "$first"), + "last": (mongoengine_fields.EmbeddedDocumentField(mongoengine_fields.EmbeddedDocument), "$last"), + "max": (mongoengine_fields.DynamicField(), "$max"), + "accumulator": (mongoengine_fields.DynamicField(), "$accumulator"), + "min": (mongoengine_fields.DynamicField(), "$min"), + "median": (mongoengine_fields.DynamicField(), "$median"), + "mergeObjects": (mongoengine_fields.DictField(), "$mergeObjects"), + "top": (mongoengine_fields.EmbeddedDocumentField(mongoengine_fields.EmbeddedDocument), "$top"), "bottom": ( - fields.EmbeddedDocumentField(fields.EmbeddedDocument), + mongoengine_fields.EmbeddedDocumentField(mongoengine_fields.EmbeddedDocument), "$bottom", ), - "topN": (fields.ListField(), "$topN"), - "bottomN": (fields.ListField(), "$bottomN"), - "firstN": (fields.ListField(), "$firstN"), - "lastN": (fields.ListField(), "$lastN"), - "maxN": (fields.ListField(), "$maxN"), + "topN": (mongoengine_fields.ListField(), "$topN"), + "bottomN": (mongoengine_fields.ListField(), "$bottomN"), + "firstN": (mongoengine_fields.ListField(), "$firstN"), + "lastN": (mongoengine_fields.ListField(), "$lastN"), + "maxN": (mongoengine_fields.ListField(), "$maxN"), } try: @@ -620,7 +619,7 @@ def lookup( return self @staticmethod - def get_model_field(model: Type[Document], field: str) -> fields: + def get_model_field(model: Type[Document], field: str) -> mongoengine_fields: """ Get the field definition of a specified field in a MongoDB model. From 47f97d1ca41b87b3c8c90eb57b905c93677c6962 Mon Sep 17 00:00:00 2001 From: mahdihaghverdi Date: Sun, 5 Nov 2023 16:46:31 +0330 Subject: [PATCH 04/13] refactor the `_fields` parameter to `fields` Signed-off-by: mahdihaghverdi --- aggify/aggify.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/aggify/aggify.py b/aggify/aggify.py index c76f684..8863a45 100644 --- a/aggify/aggify.py +++ b/aggify/aggify.py @@ -134,18 +134,18 @@ def raw(self, raw_query: dict) -> "Aggify": return self @last_out_stage_check - def add_fields(self, **_fields) -> "Aggify": # noqa + def add_fields(self, **fields) -> "Aggify": # noqa """Generates a MongoDB addFields pipeline stage. Args: - _fields: A dictionary of field expressions and values. + fields: A dictionary of field expressions and values. Returns: A MongoDB add_fields pipeline stage. """ add_fields_stage = {"$addFields": {}} - for field, expression in _fields.items(): + for field, expression in fields.items(): field = field.replace("__", ".") if isinstance(expression, str): add_fields_stage["$addFields"][field] = {"$literal": expression} From 12a0b713f56ebe87726d39b6e09cfa27c1f17621 Mon Sep 17 00:00:00 2001 From: mahdihaghverdi Date: Sun, 5 Nov 2023 16:48:31 +0330 Subject: [PATCH 05/13] parametrize the `add_fields` tests and cover two more lines Signed-off-by: mahdihaghverdi --- tests/test_aggify.py | 50 ++++++++++++++++++++------------------------ 1 file changed, 23 insertions(+), 27 deletions(-) diff --git a/tests/test_aggify.py b/tests/test_aggify.py index 63d800f..d282dd0 100644 --- a/tests/test_aggify.py +++ b/tests/test_aggify.py @@ -137,35 +137,31 @@ def test_add_field_value_error(self): } aggify.add_fields(**fields) - def test_add_fields_string_literal(self): - aggify = Aggify(BaseModel) - fields = {"new_field_1": "some_string", "new_field_2": "another_string"} - add_fields_stage = aggify.add_fields(**fields) - - expected_stage = { - "$addFields": { - "new_field_1": {"$literal": "some_string"}, - "new_field_2": {"$literal": "another_string"}, - } - } - - assert add_fields_stage.pipelines[0] == expected_stage - - def test_add_fields_with_f_expression(self): + @pytest.mark.parametrize( + ('fields', 'expected'), + ( + ( + {"scores__age": "Mahdi"}, + {"scores.age": {"$literal": "Mahdi"}} + ), + ( + {"new_field": F("existing_field") + 10}, + {"new_field": {"$add": ["$existing_field", 10]}} + ), + ( + {'array': [1, 2, 3, 4]}, + {'array': [1, 2, 3, 4]} + ), + ( + {"cond": Cond(30, "==", 30, "Equal", "Not Equal")}, + {"cond": dict(Cond(30, "==", 30, "Equal", "Not Equal"))} + ) + ) + ) + def test_add_fields(self, fields, expected): aggify = Aggify(BaseModel) - fields = { - "new_field_1": F("existing_field") + 10, - "new_field_2": F("field_a") * F("field_b"), - } add_fields_stage = aggify.add_fields(**fields) - - expected_stage = { - "$addFields": { - "new_field_1": {"$add": ["$existing_field", 10]}, - "new_field_2": {"$multiply": ["$field_a", "$field_b"]}, - } - } - assert add_fields_stage.pipelines[0] == expected_stage + assert add_fields_stage.pipelines[0]['$addFields'] == expected def test_filter_value_error(self): with pytest.raises(AggifyValueError): From d5ed7026e350001b50b75ac81085378a6f5c3967 Mon Sep 17 00:00:00 2001 From: mahdihaghverdi Date: Sun, 5 Nov 2023 19:47:44 +0330 Subject: [PATCH 06/13] Add test for multiple ands Signed-off-by: mahdihaghverdi --- aggify/aggify.py | 1 + aggify/compiler.py | 4 ++-- tests/test_q.py | 23 +++++++++++++++++++++++ 3 files changed, 26 insertions(+), 2 deletions(-) diff --git a/aggify/aggify.py b/aggify/aggify.py index 0bf4d22..9584239 100644 --- a/aggify/aggify.py +++ b/aggify/aggify.py @@ -24,6 +24,7 @@ get_db_field, ) + def last_out_stage_check(method: AggifyType) -> AggifyType: """Check if the last stage is $out or not diff --git a/aggify/compiler.py b/aggify/compiler.py index 71866f1..bb8ae18 100644 --- a/aggify/compiler.py +++ b/aggify/compiler.py @@ -107,7 +107,7 @@ def __iter__(self): yield "$match", self.conditions def __or__(self, other): - if self.conditions.get("$or", None): + if self.conditions.get("$or"): self.conditions["$or"].append(dict(other)["$match"]) combined_conditions = self.conditions @@ -116,7 +116,7 @@ def __or__(self, other): return Q(**combined_conditions) def __and__(self, other): - if self.conditions.get("$and", None): + if self.conditions.get("$and"): self.conditions["$and"].append(dict(other)["$match"]) combined_conditions = self.conditions else: diff --git a/tests/test_q.py b/tests/test_q.py index e3bf06d..6ac105e 100644 --- a/tests/test_q.py +++ b/tests/test_q.py @@ -22,6 +22,29 @@ def test_or_operator_with_multiple_conditions_more_than_rwo(self): } } + def test_and(self): + q1 = Q(name="Mahdi") + q2 = Q(age__gt=20) + q = q1 & q2 + + assert dict(q) == { + "$match": { + "$and": [dict(q1)["$match"], dict(q2)["$match"]] + } + } + + def test_multiple_and(self): + q1 = Q(name="Mahdi") + q2 = Q(age__gt=20) + q3 = Q(age__lt=30) + q = q1 & q2 & q3 + + assert dict(q) == { + "$match": { + "$and": [dict(q1)["$match"], dict(q2)["$match"], dict(q3)['$match']] + } + } + # Test combining NOT operators with AND def test_combine_not_operators_with_and(self): q1 = Q(name="John") From bf74b32e06149d0e7160c3b1d51a9875fbb592ae Mon Sep 17 00:00:00 2001 From: mahdihaghverdi Date: Sun, 5 Nov 2023 20:19:11 +0330 Subject: [PATCH 07/13] Add tests to cover the `if` in dunders of `F` Signed-off-by: mahdihaghverdi --- aggify/compiler.py | 10 +++++----- tests/test_f.py | 32 ++++++++++++++++---------------- 2 files changed, 21 insertions(+), 21 deletions(-) diff --git a/aggify/compiler.py b/aggify/compiler.py index bb8ae18..2e9b40e 100644 --- a/aggify/compiler.py +++ b/aggify/compiler.py @@ -138,7 +138,7 @@ def __init__(self, field: Union[str, Dict[str, list]]): def to_dict(self): return self.field - def __add__(self, other): # TODO: add type for 'other' + def __add__(self, other): if isinstance(other, F): other = other.field @@ -275,15 +275,15 @@ def is_base_model_field(self, field) -> bool: def compile(self, pipelines: list) -> Dict[str, Dict[str, list]]: match_query = {} for key, value in self.matches.items(): + if isinstance(value, F): + if F.is_suitable_for_match(key) is False: + raise InvalidOperator(key) + if "__" not in key: key = get_db_field(self.base_model, key) match_query[key] = value continue - if isinstance(value, F): - if F.is_suitable_for_match(key) is False: - raise InvalidOperator(key) - field, operator, *_ = key.split("__") if ( self.is_base_model_field(field) diff --git a/tests/test_f.py b/tests/test_f.py index fec4371..05cfcbb 100644 --- a/tests/test_f.py +++ b/tests/test_f.py @@ -2,70 +2,70 @@ class TestF: - # Test subtraction using F class def test_subtraction(self): f1 = F("age") f2 = F("income") - f_combined = f1 - f2 - assert f_combined.to_dict() == {"$subtract": ["$age", "$income"]} + f3 = F('x') + f_combined = f1 - f2 - f3 + assert f_combined.to_dict() == {"$subtract": ["$age", "$income", "$x"]} - # Test division using F class def test_division(self): f1 = F("income") f2 = F("expenses") - f_combined = f1 / f2 - assert f_combined.to_dict() == {"$divide": ["$income", "$expenses"]} + f3 = F('x') + f_combined = f1 / f2 / f3 + assert f_combined.to_dict() == {"$divide": ["$income", "$expenses", "$x"]} - # Test multiplication using F class def test_multiplication(self): f1 = F("quantity") f2 = F("price") - f_combined = f1 * f2 - assert f_combined.to_dict() == {"$multiply": ["$quantity", "$price"]} + f3 = F('x') + f_combined = f1 * f2 * f3 + assert f_combined.to_dict() == {"$multiply": ["$quantity", "$price", "$x"]} + + def test_addition(self): + f1 = F('income') + f2 = F('interest') + f3 = F('x') + f = f1 + f2 + f3 + assert f.to_dict() == {"$add": ['$income', "$interest", "$x"]} - # Test addition using F class with a constant def test_addition_with_constant(self): f1 = F("age") constant = 10 f_combined = f1 + constant assert f_combined.to_dict() == {"$add": ["$age", 10]} - # Test subtraction using F class with a constant def test_subtraction_with_constant(self): f1 = F("income") constant = 5000 f_combined = f1 - constant assert f_combined.to_dict() == {"$subtract": ["$income", 5000]} - # Test division using F class with a constant def test_division_with_constant(self): f1 = F("price") constant = 2 f_combined = f1 / constant assert f_combined.to_dict() == {"$divide": ["$price", 2]} - # Test multiplication using F class with a constant def test_multiplication_with_constant(self): f1 = F("quantity") constant = 5 f_combined = f1 * constant assert f_combined.to_dict() == {"$multiply": ["$quantity", 5]} - # Test addition using F class with multiple fields and constants def test_addition_with_multiple_fields(self): f1 = F("age") f2 = F("income") f_combined = f1 + f2 assert f_combined.to_dict() == {"$add": ["$age", "$income"]} - # Test subtraction using F class with multiple fields and constants def test_subtraction_with_multiple_fields(self): f1 = F("income") f2 = F("expenses") f_combined = f1 - f2 assert f_combined.to_dict() == {"$subtract": ["$income", "$expenses"]} - # Test multiplication using F class with multiple fields and constants def test_multiplication_with_multiple_fields(self): f1 = F("quantity") f2 = F("price") From db38b18b5b0a2bcaccf3cca779d867de3ba0c9cc Mon Sep 17 00:00:00 2001 From: mahdihaghverdi Date: Sun, 5 Nov 2023 20:19:55 +0330 Subject: [PATCH 08/13] test returning `False` of `F.is_suitable_for_match` Signed-off-by: mahdihaghverdi --- tests/test_q.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/tests/test_q.py b/tests/test_q.py index 6ac105e..dc7d8dc 100644 --- a/tests/test_q.py +++ b/tests/test_q.py @@ -1,4 +1,8 @@ -from aggify import Q +import pytest + +from aggify import Q, F, Aggify +from aggify.exceptions import InvalidOperator +from tests.test_aggify import BaseModel class TestQ: @@ -66,3 +70,7 @@ def test_combine_not_operators_with_or(self): "$or": [{"$not": [dict(q1)["$match"]]}, {"$not": [dict(q2)["$match"]]}] } } + + def test_unsuitable_key_for_f(self): + with pytest.raises(InvalidOperator): + Q(Aggify(BaseModel).filter(age__gt=20).pipelines, age_gt=F("income") * 2) From 2ab73674a70f4a312327ab1ee898d4bcf9c8ad57 Mon Sep 17 00:00:00 2001 From: mahdihaghverdi Date: Sun, 5 Nov 2023 20:25:50 +0330 Subject: [PATCH 09/13] add tests fot `Match.validate_operator` Signed-off-by: mahdihaghverdi --- aggify/compiler.py | 7 ++++++- tests/test_match.py | 14 ++++++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) create mode 100644 tests/test_match.py diff --git a/aggify/compiler.py b/aggify/compiler.py index 2e9b40e..e9618c6 100644 --- a/aggify/compiler.py +++ b/aggify/compiler.py @@ -248,7 +248,12 @@ def __init__( @staticmethod def validate_operator(key: str): - operator = key.rsplit("__", 1)[1] + _op = key.rsplit("__", 1) + try: + operator = _op[1] + except IndexError: + raise InvalidOperator(_op) from None + if operator not in Operators.COMPARISON_OPERATORS: raise InvalidOperator(operator) diff --git a/tests/test_match.py b/tests/test_match.py new file mode 100644 index 0000000..2e6439e --- /dev/null +++ b/tests/test_match.py @@ -0,0 +1,14 @@ +import pytest + +from aggify.compiler import Match +from aggify.exceptions import InvalidOperator + + +def test_validate_operator_fail(): + with pytest.raises(InvalidOperator): + Match.validate_operator('key_raise') + + +def test_validate_operator_fail_not_in_operators(): + with pytest.raises(InvalidOperator): + Match.validate_operator('key__ge') From e08ca86a82f0d252fed189168af44cbe07fbf0eb Mon Sep 17 00:00:00 2001 From: mahdihaghverdi Date: Sun, 5 Nov 2023 20:42:17 +0330 Subject: [PATCH 10/13] test cov is now 100 percent Signed-off-by: mahdihaghverdi --- aggify/compiler.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/aggify/compiler.py b/aggify/compiler.py index e9618c6..e0de1b4 100644 --- a/aggify/compiler.py +++ b/aggify/compiler.py @@ -82,11 +82,6 @@ def compile_match(self, operator: str, value, field: str): } else: self.match_query[field] = {Operators.ALL_OPERATORS[operator]: value} - else: - # Default behavior - self.match_query[field] = { - Operators.ALL_OPERATORS.get(operator, operator): value - } return self.match_query From 8b5e5e7c0dd5b8deffbbfbc14f8c8ed45d3cc480 Mon Sep 17 00:00:00 2001 From: mahdihaghverdi Date: Sun, 5 Nov 2023 21:03:29 +0330 Subject: [PATCH 11/13] reformat the source code Signed-off-by: mahdihaghverdi --- aggify/aggify.py | 68 ++++++++++++++++++++++++++++---------------- tests/test_aggify.py | 22 ++++++-------- tests/test_f.py | 14 ++++----- tests/test_match.py | 4 +-- tests/test_q.py | 8 ++---- 5 files changed, 62 insertions(+), 54 deletions(-) diff --git a/aggify/aggify.py b/aggify/aggify.py index 9584239..f252926 100644 --- a/aggify/aggify.py +++ b/aggify/aggify.py @@ -165,7 +165,9 @@ def add_fields(self, **fields) -> "Aggify": # noqa else: raise AggifyValueError([str, F, list], type(expression)) # TODO: Should be checked if new field is embedded, create embedded field. - self.base_model._fields[field.replace("$", "")] = mongoengine_fields.IntField() # noqa + self.base_model._fields[ + field.replace("$", "") + ] = mongoengine_fields.IntField() # noqa self.pipelines.append(add_fields_stage) return self @@ -236,7 +238,6 @@ def __to_aggregate(self, query: Dict[str, Any]) -> None: """ for key, value in query.items(): - # Split the key to access the field information. split_query = key.split("__") @@ -380,16 +381,33 @@ def annotate( "push": (mongoengine_fields.ListField(), "$push"), "addToSet": (mongoengine_fields.ListField(), "$addToSet"), "count": (mongoengine_fields.IntField(), "$count"), - "first": (mongoengine_fields.EmbeddedDocumentField(mongoengine_fields.EmbeddedDocument), "$first"), - "last": (mongoengine_fields.EmbeddedDocumentField(mongoengine_fields.EmbeddedDocument), "$last"), + "first": ( + mongoengine_fields.EmbeddedDocumentField( + mongoengine_fields.EmbeddedDocument + ), + "$first", + ), + "last": ( + mongoengine_fields.EmbeddedDocumentField( + mongoengine_fields.EmbeddedDocument + ), + "$last", + ), "max": (mongoengine_fields.DynamicField(), "$max"), "accumulator": (mongoengine_fields.DynamicField(), "$accumulator"), "min": (mongoengine_fields.DynamicField(), "$min"), "median": (mongoengine_fields.DynamicField(), "$median"), "mergeObjects": (mongoengine_fields.DictField(), "$mergeObjects"), - "top": (mongoengine_fields.EmbeddedDocumentField(mongoengine_fields.EmbeddedDocument), "$top"), + "top": ( + mongoengine_fields.EmbeddedDocumentField( + mongoengine_fields.EmbeddedDocument + ), + "$top", + ), "bottom": ( - mongoengine_fields.EmbeddedDocumentField(mongoengine_fields.EmbeddedDocument), + mongoengine_fields.EmbeddedDocumentField( + mongoengine_fields.EmbeddedDocument + ), "$bottom", ), "topN": (mongoengine_fields.ListField(), "$topN"), @@ -535,25 +553,25 @@ def lookup( foreign_field: Union[str, None] = None, ) -> "Aggify": """ - Generates a MongoDB lookup pipeline stage. - - Args: - from_collection (Document): The document representing the collection to perform the lookup on. - as_name (str): The name of the new field to create. - query (list[Q] | Union[Q, None], optional): List of desired queries with Q function or a single query. -<<<<<<< HEAD - let (Union[List[str], None], optional): The local field(s) to join on. If provided, localField and foreignField are not used. - local_field (Union[str, None], optional): The local field to join on when `let` is not provided. - foreign_field (Union[str, None], optional): The foreign field to join on when `let` is not provided. -======= - let (Union[List[str], None], optional): The local field(s) to join on. If provided, - localField and foreignField are not used. - local_field (Union[str, None], optional): The local field to join on when let not provided. - foreign_field (Union[str, None], optional): The foreign field to join on when let not provided. ->>>>>>> main - - Returns: - Aggify: An instance of the Aggify class representing a MongoDB lookup pipeline stage. + Generates a MongoDB lookup pipeline stage. + + Args: + from_collection (Document): The document representing the collection to perform the lookup on. + as_name (str): The name of the new field to create. + query (list[Q] | Union[Q, None], optional): List of desired queries with Q function or a single query. + <<<<<<< HEAD + let (Union[List[str], None], optional): The local field(s) to join on. If provided, localField and foreignField are not used. + local_field (Union[str, None], optional): The local field to join on when `let` is not provided. + foreign_field (Union[str, None], optional): The foreign field to join on when `let` is not provided. + ======= + let (Union[List[str], None], optional): The local field(s) to join on. If provided, + localField and foreignField are not used. + local_field (Union[str, None], optional): The local field to join on when let not provided. + foreign_field (Union[str, None], optional): The foreign field to join on when let not provided. + >>>>>>> main + + Returns: + Aggify: An instance of the Aggify class representing a MongoDB lookup pipeline stage. """ lookup_stages = [] diff --git a/tests/test_aggify.py b/tests/test_aggify.py index 4e2acb4..f9af0de 100644 --- a/tests/test_aggify.py +++ b/tests/test_aggify.py @@ -143,30 +143,24 @@ def test_add_field_value_error(self): aggify.add_fields(**fields) @pytest.mark.parametrize( - ('fields', 'expected'), + ("fields", "expected"), ( - ( - {"scores__age": "Mahdi"}, - {"scores.age": {"$literal": "Mahdi"}} - ), + ({"scores__age": "Mahdi"}, {"scores.age": {"$literal": "Mahdi"}}), ( {"new_field": F("existing_field") + 10}, - {"new_field": {"$add": ["$existing_field", 10]}} - ), - ( - {'array': [1, 2, 3, 4]}, - {'array': [1, 2, 3, 4]} + {"new_field": {"$add": ["$existing_field", 10]}}, ), + ({"array": [1, 2, 3, 4]}, {"array": [1, 2, 3, 4]}), ( {"cond": Cond(30, "==", 30, "Equal", "Not Equal")}, - {"cond": dict(Cond(30, "==", 30, "Equal", "Not Equal"))} - ) - ) + {"cond": dict(Cond(30, "==", 30, "Equal", "Not Equal"))}, + ), + ), ) def test_add_fields(self, fields, expected): aggify = Aggify(BaseModel) add_fields_stage = aggify.add_fields(**fields) - assert add_fields_stage.pipelines[0]['$addFields'] == expected + assert add_fields_stage.pipelines[0]["$addFields"] == expected def test_filter_value_error(self): with pytest.raises(AggifyValueError): diff --git a/tests/test_f.py b/tests/test_f.py index 05cfcbb..f69d230 100644 --- a/tests/test_f.py +++ b/tests/test_f.py @@ -5,30 +5,30 @@ class TestF: def test_subtraction(self): f1 = F("age") f2 = F("income") - f3 = F('x') + f3 = F("x") f_combined = f1 - f2 - f3 assert f_combined.to_dict() == {"$subtract": ["$age", "$income", "$x"]} def test_division(self): f1 = F("income") f2 = F("expenses") - f3 = F('x') + f3 = F("x") f_combined = f1 / f2 / f3 assert f_combined.to_dict() == {"$divide": ["$income", "$expenses", "$x"]} def test_multiplication(self): f1 = F("quantity") f2 = F("price") - f3 = F('x') + f3 = F("x") f_combined = f1 * f2 * f3 assert f_combined.to_dict() == {"$multiply": ["$quantity", "$price", "$x"]} def test_addition(self): - f1 = F('income') - f2 = F('interest') - f3 = F('x') + f1 = F("income") + f2 = F("interest") + f3 = F("x") f = f1 + f2 + f3 - assert f.to_dict() == {"$add": ['$income', "$interest", "$x"]} + assert f.to_dict() == {"$add": ["$income", "$interest", "$x"]} def test_addition_with_constant(self): f1 = F("age") diff --git a/tests/test_match.py b/tests/test_match.py index 2e6439e..6e5f4ca 100644 --- a/tests/test_match.py +++ b/tests/test_match.py @@ -6,9 +6,9 @@ def test_validate_operator_fail(): with pytest.raises(InvalidOperator): - Match.validate_operator('key_raise') + Match.validate_operator("key_raise") def test_validate_operator_fail_not_in_operators(): with pytest.raises(InvalidOperator): - Match.validate_operator('key__ge') + Match.validate_operator("key__ge") diff --git a/tests/test_q.py b/tests/test_q.py index dc7d8dc..4d69969 100644 --- a/tests/test_q.py +++ b/tests/test_q.py @@ -31,11 +31,7 @@ def test_and(self): q2 = Q(age__gt=20) q = q1 & q2 - assert dict(q) == { - "$match": { - "$and": [dict(q1)["$match"], dict(q2)["$match"]] - } - } + assert dict(q) == {"$match": {"$and": [dict(q1)["$match"], dict(q2)["$match"]]}} def test_multiple_and(self): q1 = Q(name="Mahdi") @@ -45,7 +41,7 @@ def test_multiple_and(self): assert dict(q) == { "$match": { - "$and": [dict(q1)["$match"], dict(q2)["$match"], dict(q3)['$match']] + "$and": [dict(q1)["$match"], dict(q2)["$match"], dict(q3)["$match"]] } } From 7f0f59dc731422a67555ac5f52119bd4f96f87c1 Mon Sep 17 00:00:00 2001 From: mahdihaghverdi Date: Sun, 5 Nov 2023 21:05:03 +0330 Subject: [PATCH 12/13] fix the linting error Signed-off-by: mahdihaghverdi --- aggify/aggify.py | 12 +++++++----- aggify/types.py | 4 +--- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/aggify/aggify.py b/aggify/aggify.py index f252926..a48b5b7 100644 --- a/aggify/aggify.py +++ b/aggify/aggify.py @@ -1,5 +1,5 @@ import functools -from typing import Any, Dict, Type, Union, List +from typing import Any, Dict, Type, Union, List, TypeVar, Callable from mongoengine import Document, EmbeddedDocument, fields as mongoengine_fields from mongoengine.base import TopLevelDocumentMetaclass @@ -13,8 +13,7 @@ OutStageError, InvalidArgument, ) - -from aggify.types import QueryParams, AggifyType, CollectionType +from aggify.types import QueryParams from aggify.utilty import ( to_mongo_positive_index, check_fields_exist, @@ -24,6 +23,9 @@ get_db_field, ) +AggifyType = TypeVar("AggifyType", bound=Callable[..., "Aggify"]) +CollectionType = TypeVar("CollectionType", bound=Callable[..., "Document"]) + def last_out_stage_check(method: AggifyType) -> AggifyType: """Check if the last stage is $out or not @@ -251,8 +253,8 @@ def __to_aggregate(self, query: Dict[str, Any]) -> None: or "document_type_obj" not in join_field.__dict__ # Check whether this field is a join field or not. or issubclass( - join_field.document_type, EmbeddedDocument # noqa - ) # Check whether this field is embedded field or not + join_field.document_type, EmbeddedDocument # noqa + ) # Check whether this field is embedded field or not or len(split_query) == 1 or (len(split_query) == 2 and split_query[1] in Operators.ALL_OPERATORS) ): diff --git a/aggify/types.py b/aggify/types.py index 9cd099e..a9b3a11 100644 --- a/aggify/types.py +++ b/aggify/types.py @@ -1,7 +1,5 @@ -from typing import Union, Dict, TypeVar, Callable +from typing import Union, Dict from bson import ObjectId QueryParams = Union[int, None, str, bool, float, Dict, ObjectId] -AggifyType = TypeVar("AggifyType", bound=Callable[..., "Aggify"]) -CollectionType = TypeVar("CollectionType", bound=Callable[..., "Document"]) From 1642f61a23076f0e4b1d5cf252a5836f7e222b19 Mon Sep 17 00:00:00 2001 From: mahdihaghverdi Date: Sun, 5 Nov 2023 21:07:19 +0330 Subject: [PATCH 13/13] fix importing errors Signed-off-by: mahdihaghverdi --- aggify/utilty.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/aggify/utilty.py b/aggify/utilty.py index ea056ce..1fceb69 100644 --- a/aggify/utilty.py +++ b/aggify/utilty.py @@ -3,7 +3,6 @@ from mongoengine import Document from aggify.exceptions import MongoIndexError, InvalidField, AlreadyExistsField -from aggify.types import CollectionType def to_mongo_positive_index(index: Union[int, slice]) -> slice: @@ -23,7 +22,7 @@ def to_mongo_positive_index(index: Union[int, slice]) -> slice: return index -def check_fields_exist(model: CollectionType, fields_to_check: List[str]) -> None: +def check_fields_exist(model, fields_to_check: List[str]) -> None: """ Check if the specified fields exist in a model's fields. @@ -106,7 +105,7 @@ def convert_match_query( return d -def check_field_exists(model: CollectionType, field: str) -> None: +def check_field_exists(model, field: str) -> None: """ Check if a field exists in the given model. @@ -121,7 +120,7 @@ def check_field_exists(model: CollectionType, field: str) -> None: raise AlreadyExistsField(field=field) -def get_db_field(model: CollectionType, field: str, add_dollar_sign=False) -> str: +def get_db_field(model, field: str, add_dollar_sign=False) -> str: """ Get the database field name for a given field in the model.