Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 76 additions & 54 deletions aggify/aggify.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import functools
from typing import Any, Dict, Type, Union, List, Callable, TypeVar
from typing import Any, Dict, Type, Union, List, TypeVar, Callable

from mongoengine import Document, EmbeddedDocument, fields
from mongoengine import Document, EmbeddedDocument, fields as mongoengine_fields
from mongoengine.base import TopLevelDocumentMetaclass

from aggify.compiler import F, Match, Q, Operators, Cond # noqa keep
Expand All @@ -13,7 +13,7 @@
OutStageError,
InvalidArgument,
)
from aggify.types import QueryParams, CollectionType
from aggify.types import QueryParams
from aggify.utilty import (
to_mongo_positive_index,
check_fields_exist,
Expand All @@ -24,6 +24,7 @@
)

AggifyType = TypeVar("AggifyType", bound=Callable[..., "Aggify"])
CollectionType = TypeVar("CollectionType", bound=Callable[..., "Document"])


def last_out_stage_check(method: AggifyType) -> AggifyType:
Expand Down Expand Up @@ -85,25 +86,23 @@ def project(self, **kwargs: QueryParams) -> "Aggify":
"""

# Extract fields to keep and check if _id should be deleted
to_keep_values = ["id"]
delete_id = kwargs.get("id") == 0
to_keep_values = {"id"}
delete_id = kwargs.get("id") is not None
projection = {}

# Add missing fields to the base model
for key, value in kwargs.items():
if value == 1:
to_keep_values.append(key)
to_keep_values.add(key)
elif key not in self.base_model._fields and isinstance( # noqa
kwargs[key], (str, dict)
): # noqa
to_keep_values.append(key)
self.base_model._fields[key] = fields.IntField() # noqa
):
to_keep_values.add(key)
self.base_model._fields[key] = mongoengine_fields.IntField() # noqa
projection[get_db_field(self.base_model, key)] = value # noqa

# Remove fields from the base model, except the ones in to_keep_values and possibly _id
keys_for_deletion = set(self.base_model._fields.keys()) - set( # noqa
to_keep_values
) # noqa
keys_for_deletion = self.base_model._fields.keys() - to_keep_values # noqa
if delete_id:
keys_for_deletion.add("id")
for key in keys_for_deletion:
Expand Down Expand Up @@ -144,19 +143,18 @@ def raw(self, raw_query: dict) -> "Aggify":
return self

@last_out_stage_check
def add_fields(self, **_fields) -> "Aggify": # noqa
"""
Generates a MongoDB addFields pipeline stage.
def add_fields(self, **fields) -> "Aggify": # noqa
"""Generates a MongoDB addFields pipeline stage.

Args:
_fields: A dictionary of field expressions and values.
fields: A dictionary of field expressions and values.

Returns:
A MongoDB add_fields pipeline stage.
"""
add_fields_stage = {"$addFields": {}}

for field, expression in _fields.items():
for field, expression in fields.items():
field = field.replace("__", ".")
if isinstance(expression, str):
add_fields_stage["$addFields"][field] = {"$literal": expression}
Expand All @@ -169,7 +167,9 @@ def add_fields(self, **_fields) -> "Aggify": # noqa
else:
raise AggifyValueError([str, F, list], type(expression))
# TODO: Should be checked if new field is embedded, create embedded field.
self.base_model._fields[field.replace("$", "")] = fields.IntField() # noqa
self.base_model._fields[
field.replace("$", "")
] = mongoengine_fields.IntField() # noqa

self.pipelines.append(add_fields_stage)
return self
Expand Down Expand Up @@ -240,7 +240,6 @@ def __to_aggregate(self, query: Dict[str, Any]) -> None:
"""

for key, value in query.items():

# Split the key to access the field information.
split_query = key.split("__")

Expand All @@ -254,8 +253,8 @@ def __to_aggregate(self, query: Dict[str, Any]) -> None:
or "document_type_obj"
not in join_field.__dict__ # Check whether this field is a join field or not.
or issubclass(
join_field.document_type, EmbeddedDocument # noqa
) # Check whether this field is embedded field or not
join_field.document_type, EmbeddedDocument # noqa
) # Check whether this field is embedded field or not
or len(split_query) == 1
or (len(split_query) == 2 and split_query[1] in Operators.ALL_OPERATORS)
):
Expand Down Expand Up @@ -377,30 +376,47 @@ def annotate(

# Some of the accumulator fields might be false and should be checked.
aggregation_mapping: Dict[str, Type] = {
"sum": (fields.FloatField(), "$sum"),
"avg": (fields.FloatField(), "$avg"),
"stdDevPop": (fields.FloatField(), "$stdDevPop"),
"stdDevSamp": (fields.FloatField(), "$stdDevSamp"), # noqa
"push": (fields.ListField(), "$push"),
"addToSet": (fields.ListField(), "$addToSet"),
"count": (fields.IntField(), "$count"),
"first": (fields.EmbeddedDocumentField(fields.EmbeddedDocument), "$first"),
"last": (fields.EmbeddedDocumentField(fields.EmbeddedDocument), "$last"),
"max": (fields.DynamicField(), "$max"),
"accumulator": (fields.DynamicField(), "$accumulator"),
"min": (fields.DynamicField(), "$min"),
"median": (fields.DynamicField(), "$median"),
"mergeObjects": (fields.DictField(), "$mergeObjects"),
"top": (fields.EmbeddedDocumentField(fields.EmbeddedDocument), "$top"),
"sum": (mongoengine_fields.FloatField(), "$sum"),
"avg": (mongoengine_fields.FloatField(), "$avg"),
"stdDevPop": (mongoengine_fields.FloatField(), "$stdDevPop"),
"stdDevSamp": (mongoengine_fields.FloatField(), "$stdDevSamp"),
"push": (mongoengine_fields.ListField(), "$push"),
"addToSet": (mongoengine_fields.ListField(), "$addToSet"),
"count": (mongoengine_fields.IntField(), "$count"),
"first": (
mongoengine_fields.EmbeddedDocumentField(
mongoengine_fields.EmbeddedDocument
),
"$first",
),
"last": (
mongoengine_fields.EmbeddedDocumentField(
mongoengine_fields.EmbeddedDocument
),
"$last",
),
"max": (mongoengine_fields.DynamicField(), "$max"),
"accumulator": (mongoengine_fields.DynamicField(), "$accumulator"),
"min": (mongoengine_fields.DynamicField(), "$min"),
"median": (mongoengine_fields.DynamicField(), "$median"),
"mergeObjects": (mongoengine_fields.DictField(), "$mergeObjects"),
"top": (
mongoengine_fields.EmbeddedDocumentField(
mongoengine_fields.EmbeddedDocument
),
"$top",
),
"bottom": (
fields.EmbeddedDocumentField(fields.EmbeddedDocument),
mongoengine_fields.EmbeddedDocumentField(
mongoengine_fields.EmbeddedDocument
),
"$bottom",
),
"topN": (fields.ListField(), "$topN"),
"bottomN": (fields.ListField(), "$bottomN"),
"firstN": (fields.ListField(), "$firstN"),
"lastN": (fields.ListField(), "$lastN"),
"maxN": (fields.ListField(), "$maxN"),
"topN": (mongoengine_fields.ListField(), "$topN"),
"bottomN": (mongoengine_fields.ListField(), "$bottomN"),
"firstN": (mongoengine_fields.ListField(), "$firstN"),
"lastN": (mongoengine_fields.ListField(), "$lastN"),
"maxN": (mongoengine_fields.ListField(), "$maxN"),
}

try:
Expand Down Expand Up @@ -539,19 +555,25 @@ def lookup(
foreign_field: Union[str, None] = None,
) -> "Aggify":
"""
Generates a MongoDB lookup pipeline stage.
Generates a MongoDB lookup pipeline stage.

Args:
from_collection (Document): The document representing the collection to perform the lookup on.
as_name (str): The name of the new field to create.
query (list[Q] | Union[Q, None], optional): List of desired queries with Q function or a single query.
let (Union[List[str], None], optional): The local field(s) to join on. If provided,
localField and foreignField are not used.
local_field (Union[str, None], optional): The local field to join on when let not provided.
foreign_field (Union[str, None], optional): The foreign field to join on when let not provided.
Args:
from_collection (Document): The document representing the collection to perform the lookup on.
as_name (str): The name of the new field to create.
query (list[Q] | Union[Q, None], optional): List of desired queries with Q function or a single query.
<<<<<<< HEAD
let (Union[List[str], None], optional): The local field(s) to join on. If provided, localField and foreignField are not used.
local_field (Union[str, None], optional): The local field to join on when `let` is not provided.
foreign_field (Union[str, None], optional): The foreign field to join on when `let` is not provided.
=======
let (Union[List[str], None], optional): The local field(s) to join on. If provided,
localField and foreignField are not used.
local_field (Union[str, None], optional): The local field to join on when let not provided.
foreign_field (Union[str, None], optional): The foreign field to join on when let not provided.
>>>>>>> main

Returns:
Aggify: An instance of the Aggify class representing a MongoDB lookup pipeline stage.
Returns:
Aggify: An instance of the Aggify class representing a MongoDB lookup pipeline stage.
"""

lookup_stages = []
Expand Down Expand Up @@ -617,7 +639,7 @@ def lookup(
return self

@staticmethod
def get_model_field(model: CollectionType, field: str) -> fields:
def get_model_field(model: Type[Document], field: str) -> mongoengine_fields:
"""
Get the field definition of a specified field in a MongoDB model.

Expand Down
26 changes: 13 additions & 13 deletions aggify/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,6 @@ def compile_match(self, operator: str, value, field: str):
}
else:
self.match_query[field] = {Operators.ALL_OPERATORS[operator]: value}
else:
# Default behavior
self.match_query[field] = {
Operators.ALL_OPERATORS.get(operator, operator): value
}

return self.match_query

Expand All @@ -107,7 +102,7 @@ def __iter__(self):
yield "$match", self.conditions

def __or__(self, other):
if self.conditions.get("$or", None):
if self.conditions.get("$or"):
self.conditions["$or"].append(dict(other)["$match"])
combined_conditions = self.conditions

Expand All @@ -116,7 +111,7 @@ def __or__(self, other):
return Q(**combined_conditions)

def __and__(self, other):
if self.conditions.get("$and", None):
if self.conditions.get("$and"):
self.conditions["$and"].append(dict(other)["$match"])
combined_conditions = self.conditions
else:
Expand All @@ -138,7 +133,7 @@ def __init__(self, field: Union[str, Dict[str, list]]):
def to_dict(self):
return self.field

def __add__(self, other): # TODO: add type for 'other'
def __add__(self, other):
if isinstance(other, F):
other = other.field

Expand Down Expand Up @@ -248,7 +243,12 @@ def __init__(

@staticmethod
def validate_operator(key: str):
operator = key.rsplit("__", 1)[1]
_op = key.rsplit("__", 1)
try:
operator = _op[1]
except IndexError:
raise InvalidOperator(_op) from None

if operator not in Operators.COMPARISON_OPERATORS:
raise InvalidOperator(operator)

Expand All @@ -275,15 +275,15 @@ def is_base_model_field(self, field) -> bool:
def compile(self, pipelines: list) -> Dict[str, Dict[str, list]]:
match_query = {}
for key, value in self.matches.items():
if isinstance(value, F):
if F.is_suitable_for_match(key) is False:
raise InvalidOperator(key)

if "__" not in key:
key = get_db_field(self.base_model, key)
match_query[key] = value
continue

if isinstance(value, F):
if F.is_suitable_for_match(key) is False:
raise InvalidOperator(key)

field, operator, *_ = key.split("__")
if (
self.is_base_model_field(field)
Expand Down
6 changes: 2 additions & 4 deletions aggify/types.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from typing import Union, Dict, TypeVar, Callable
from mongoengine import Document
from typing import Union, Dict

from bson import ObjectId

QueryParams = Union[int, None, str, bool, float, Dict, ObjectId]

CollectionType = TypeVar("CollectionType", bound=Callable[..., Document])
7 changes: 3 additions & 4 deletions aggify/utilty.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from mongoengine import Document

from aggify.exceptions import MongoIndexError, InvalidField, AlreadyExistsField
from aggify.types import CollectionType


def to_mongo_positive_index(index: Union[int, slice]) -> slice:
Expand All @@ -23,7 +22,7 @@ def to_mongo_positive_index(index: Union[int, slice]) -> slice:
return index


def check_fields_exist(model: CollectionType, fields_to_check: List[str]) -> None:
def check_fields_exist(model, fields_to_check: List[str]) -> None:
"""
Check if the specified fields exist in a model's fields.

Expand Down Expand Up @@ -106,7 +105,7 @@ def convert_match_query(
return d


def check_field_exists(model: CollectionType, field: str) -> None:
def check_field_exists(model, field: str) -> None:
"""
Check if a field exists in the given model.

Expand All @@ -121,7 +120,7 @@ def check_field_exists(model: CollectionType, field: str) -> None:
raise AlreadyExistsField(field=field)


def get_db_field(model: CollectionType, field: str, add_dollar_sign=False) -> str:
def get_db_field(model, field: str, add_dollar_sign=False) -> str:
"""
Get the database field name for a given field in the model.

Expand Down
Loading