Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 29 additions & 9 deletions aggify/aggify.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ def order_by(self, *order_fields: Union[str, List[str]]) -> "Aggify":
@last_out_stage_check
def raw(self, raw_query: dict) -> "Aggify":
self.pipelines.append(raw_query)
self.pipelines = self.__combine_sequential_matches()
return self

@last_out_stage_check
Expand Down Expand Up @@ -558,6 +559,7 @@ def lookup(
let: Union[List[str], None] = None,
local_field: Union[str, None] = None,
foreign_field: Union[str, None] = None,
raw_let: Union[Dict, None] = None,
) -> "Aggify":
"""
Generates a MongoDB lookup pipeline stage.
Expand All @@ -574,6 +576,7 @@ def lookup(
localField and foreignField are not used.
local_field (Union[str, None], optional): The local field to join on when let not provided.
foreign_field (Union[str, None], optional): The foreign field to join on when let not provided.
raw_let (Union[Dict, None]): raw let

Returns:
Aggify: An instance of the Aggify class representing a MongoDB lookup pipeline stage.
Expand All @@ -583,11 +586,11 @@ def lookup(
check_field_exists(self.base_model, as_name) # noqa
from_collection_name = from_collection._meta.get("collection") # noqa

if not let and not (local_field and foreign_field):
if not (let or raw_let) and not (local_field and foreign_field):
raise InvalidArgument(
expected_list=[["local_field", "foreign_field"], "let"]
expected_list=[["local_field", "foreign_field"], ["let", "raw_let"]]
)
elif not let:
elif not (let or raw_let):
lookup_stage = {
"$lookup": {
"from": from_collection_name,
Expand All @@ -602,11 +605,16 @@ def lookup(
if not query:
raise InvalidArgument(expected_list=["query"])

if let is None:
let = []

let_dict = {
field: f"${get_db_field(self.base_model, self.get_field_name_recursively(field))}" # noqa
for field in let
}

let = list(raw_let.keys()) if let is [] else let

for q in query:
# Construct the match stage for each query
if isinstance(q, Q):
Expand All @@ -625,6 +633,8 @@ def lookup(
)

# Append the lookup stage with multiple match stages to the pipeline
if raw_let:
let_dict.update(raw_let)
lookup_stage = {
"$lookup": {
"from": from_collection_name,
Expand Down Expand Up @@ -675,13 +685,17 @@ def _replace_base(self, embedded_field) -> str:
InvalidEmbeddedField: If the specified embedded field is not found or is not of the correct type.
"""
model_field = self.get_model_field(self.base_model, embedded_field) # noqa

field_name = get_db_field(self.base_model, embedded_field)
if "__module__" in model_field.__dict__:
self.base_model._fields = (
model_field._fields
) # load new fields into old model
return f"${field_name}"
if not hasattr(model_field, "document_type") or not issubclass(
model_field.document_type, EmbeddedDocument
):
raise InvalidEmbeddedField(field=embedded_field)

return f"${model_field.db_field}"
return f"${field_name}"

@last_out_stage_check
def replace_root(
Expand All @@ -703,10 +717,13 @@ def replace_root(
"""
name = self._replace_base(embedded_field)

if not merge:
new_root = {"$replaceRoot": {"$newRoot": name}}
else:
if merge:
new_root = {"$replaceRoot": {"newRoot": {"$mergeObjects": [merge, name]}}}
self.base_model._fields.update( # noqa
{key: mongoengine_fields.IntField() for key, value in merge.items()}
)
else:
new_root = {"$replaceRoot": {"$newRoot": name}}
self.pipelines.append(new_root)

return self
Expand Down Expand Up @@ -735,6 +752,9 @@ def replace_with(
new_root = {"$replaceWith": name}
else:
new_root = {"$replaceWith": {"$mergeObjects": [merge, name]}}
self.base_model._fields.update( # noqa
{key: mongoengine_fields.IntField() for key, value in merge.items()}
)
self.pipelines.append(new_root)

return self
Expand Down
18 changes: 18 additions & 0 deletions tests/test_aggify.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,3 +644,21 @@ def test_project_add_new_field(self):
thing = list(aggify.project(test="test", id=0))
assert thing[0]["$project"] == {"test": "test", "_id": 0}
assert list(aggify.base_model._fields.keys()) == ["test"]

def test_lookup_raw_let(self):
aggify = Aggify(BaseModel)
thing = list(
aggify.lookup(
BaseModel,
raw_let={"test": "$name"},
query=[Q(name__exact="$$test")],
as_name="test_name",
)
)
assert thing[0]["$lookup"] == {
"from": None,
"let": {"test": "$name"},
"pipeline": [{"$match": {"$expr": {"$eq": ["$name", "$$test"]}}}],
"as": "test_name",
}
assert "test_name" in list(aggify.base_model._fields.keys())
62 changes: 62 additions & 0 deletions tests/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,68 @@ class ParameterTestCase:
{"$group": {"_id": "$stat.like_count", "sss": {"$first": "sss"}}}
],
),
ParameterTestCase(
compiled_query=(
Aggify(PostDocument).lookup(
AccountDocument,
let=["caption"],
raw_let={
"latest_story_id": {"$last": {"$slice": ["$owner.story", -1]}}
},
query=[
Q(end__exact="caption") & Q(start__exact="$$latest_story_id._id")
],
as_name="is_seen",
)
),
expected_query=[
{
"$lookup": {
"from": "account",
"let": {
"caption": "$caption",
"latest_story_id": {"$last": {"$slice": ["$owner.story", -1]}},
},
"pipeline": [
{
"$match": {
"$expr": {
"$and": [
{"$eq": ["$end", "$$caption"]},
{"$eq": ["$start", "$$latest_story_id._id"]},
]
}
}
}
],
"as": "is_seen",
}
}
],
),
ParameterTestCase(
compiled_query=(
Aggify(PostDocument)
.lookup(
PostDocument,
local_field="end",
foreign_field="id",
as_name="saved_post",
)
.replace_root(embedded_field="saved_post")
),
expected_query=[
{
"$lookup": {
"as": "saved_post",
"foreignField": "_id",
"from": "post_document",
"localField": "end",
}
},
{"$replaceRoot": {"$newRoot": "$saved_post"}},
],
),
]


Expand Down