Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions aggify/aggify.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@
from aggify.types import QueryParams, CollectionType
from aggify.utilty import (
to_mongo_positive_index,
check_fields_exist,
validate_field_existence,
replace_values_recursive,
convert_match_query,
check_field_exists,
check_field_already_exists,
get_db_field,
)

Expand Down Expand Up @@ -538,7 +538,7 @@ def get_field_name_recursively(self, field: str) -> str:
# Split the field based on double underscores and process each item
for index, item in enumerate(field.split("__")):
# Ensure the field exists at the current level of hierarchy
check_fields_exist(prev_base, [item]) # noqa
validate_field_existence(prev_base, [item]) # noqa

# Append the database field name to the field_name list
field_name.append(get_db_field(prev_base, item))
Expand Down Expand Up @@ -583,7 +583,7 @@ def lookup(
"""

lookup_stages = []
check_field_exists(self.base_model, as_name) # noqa
check_field_already_exists(self.base_model, as_name) # noqa
from_collection_name = from_collection._meta.get("collection") # noqa

if not (let or raw_let) and not (local_field and foreign_field):
Expand Down Expand Up @@ -723,7 +723,7 @@ def replace_root(
{key: mongoengine_fields.IntField() for key, value in merge.items()}
)
else:
new_root = {"$replaceRoot": {"$newRoot": name}}
new_root = {"$replaceRoot": {"newRoot": name}}
self.pipelines.append(new_root)

return self
Expand Down
16 changes: 10 additions & 6 deletions aggify/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from mongoengine.base import TopLevelDocumentMetaclass

from aggify.exceptions import InvalidOperator
from aggify.utilty import get_db_field
from aggify.utilty import get_db_field, get_nested_field_model


class Operators:
Expand Down Expand Up @@ -284,16 +284,20 @@ def compile(self, pipelines: list) -> Dict[str, Dict[str, list]]:
match_query[key] = value
continue

field, operator, *_ = key.split("__")
field, operator, *others = key.split("__")
if (
self.is_base_model_field(field)
and operator not in Operators.ALL_OPERATORS
):
pipelines.append(
Match({key.replace("__", ".", 1): value}, self.base_model).compile(
[]
)
field_db_name = get_db_field(self.base_model, field)

nested_field_name = get_db_field(
get_nested_field_model(self.base_model, field), operator
)
key = (
f"{field_db_name}.{nested_field_name}__" + "__".join(others)
).rstrip("__")
pipelines.append(Match({key: value}, self.base_model).compile([]))
continue

if operator not in Operators.ALL_OPERATORS:
Expand Down
40 changes: 35 additions & 5 deletions aggify/utilty.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from typing import Any, Union, List, Dict

from mongoengine import Document
from aggify.types import CollectionType

from aggify.exceptions import MongoIndexError, InvalidField, AlreadyExistsField
from aggify.types import CollectionType


def to_mongo_positive_index(index: Union[int, slice]) -> slice:
Expand All @@ -22,9 +23,9 @@ def to_mongo_positive_index(index: Union[int, slice]) -> slice:
return index


def check_fields_exist(model: CollectionType, fields_to_check: List[str]) -> None:
def validate_field_existence(model: CollectionType, fields_to_check: List[str]) -> None:
"""
Check if the specified fields exist in a model's fields.
The function checks a list of fields and raises an InvalidField exception if any are missing.

Args:
model: The model containing the fields to check.
Expand Down Expand Up @@ -105,7 +106,7 @@ def convert_match_query(
return d


def check_field_exists(model: CollectionType, field: str) -> None:
def check_field_already_exists(model: CollectionType, field: str) -> None:
"""
Check if a field exists in the given model.

Expand All @@ -116,7 +117,10 @@ def check_field_exists(model: CollectionType, field: str) -> None:
Raises:
AlreadyExistsField: If the field already exists in the model.
"""
if model._fields.get(field): # noqa
if field in [
f.db_field if hasattr(f, "db_field") else k
for k, f in model._fields.items() # noqa
]:
raise AlreadyExistsField(field=field)


Expand All @@ -138,3 +142,29 @@ def get_db_field(model: CollectionType, field: str, add_dollar_sign=False) -> st
return f"${db_field}" if add_dollar_sign else db_field
except AttributeError:
return field


def get_nested_field_model(model: CollectionType, field: str) -> CollectionType:
"""
Retrieves the nested field model for a specified field within a given model.

This function examines the provided model to determine if the specified field is
a nested field. If it is, the function returns the nested field's model.
Otherwise, it returns the original model.

Args:
model (CollectionType): The model to be inspected. This should be a class that
represents a collection or document in a database, typically
in an ORM or ODM framework.
field (str): The name of the field within the model to inspect for nestedness.

Returns:
CollectionType: The model of the nested field if the specified field is nested;
otherwise, returns the original model.

Raises:
KeyError: If the specified field is not found in the model.
"""
if model._fields[field].__dict__.get("__module__"): # noqa
return model
return model._fields[field].__dict__["document_type_obj"] # noqa
8 changes: 6 additions & 2 deletions tests/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ class ParameterTestCase:
),
ParameterTestCase(
compiled_query=(Aggify(PostDocument).replace_root(embedded_field="stat")),
expected_query=[{"$replaceRoot": {"$newRoot": "$stat"}}],
expected_query=[{"$replaceRoot": {"newRoot": "$stat"}}],
),
ParameterTestCase(
compiled_query=(Aggify(PostDocument).replace_with(embedded_field="stat")),
Expand Down Expand Up @@ -525,9 +525,13 @@ class ParameterTestCase:
"localField": "end",
}
},
{"$replaceRoot": {"$newRoot": "$saved_post"}},
{"$replaceRoot": {"newRoot": "$saved_post"}},
],
),
ParameterTestCase(
compiled_query=(Aggify(PostDocument).filter(stat__like_count=2)),
expected_query=[{"$match": {"stat.like_count": 2}}],
),
]


Expand Down