Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions aggify/aggify.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,13 @@ def project(self, **kwargs: QueryParams) -> "Aggify":
Returns:
Aggify: Returns an instance of the Aggify class for potential method chaining.
"""

if all([i in kwargs.values() for i in [0, 1]]):
filtered_kwargs = dict(kwargs)
filtered_kwargs.pop("id", None)
if all([i in filtered_kwargs.values() for i in [0, 1]]):
raise InvalidProjection()

# Extract fields to keep and check if _id should be deleted
to_keep_values = {"id"}
delete_id = kwargs.get("id") is not None
projection = {}

# Add missing fields to the base model
Expand All @@ -109,13 +109,10 @@ def project(self, **kwargs: QueryParams) -> "Aggify":
# Remove fields from the base model, except the ones in to_keep_values and possibly _id
if to_keep_values != {"id"}:
keys_for_deletion = self.base_model._fields.keys() - to_keep_values # noqa
if delete_id:
keys_for_deletion.add("id")
for key in keys_for_deletion:
del self.base_model._fields[key] # noqa
# Append the projection stage to the pipelines
self.pipelines.append({"$project": projection})

# Return the instance for method chaining
return self

Expand Down
12 changes: 10 additions & 2 deletions tests/test_aggify.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,14 @@ def test_filtering_and_projection(self):
aggify.filter(age__gte=30).project(name=1, age=1)
assert len(aggify.pipelines) == 2
assert aggify.pipelines[1]["$project"] == {"name": 1, "age": 1}
assert list(aggify.base_model._fields.keys()) == ["name", "age", "id"]

def test_filtering_and_projection_with_deleting_id(self):
aggify = Aggify(BaseModel)
aggify.filter(age__gte=30).project(name=1, age=1, id=1)
aggify.filter(age__gte=30).project(name=1, age=1, id=0)
assert len(aggify.pipelines) == 2
assert aggify.pipelines[1]["$project"] == {"_id": 1, "name": 1, "age": 1}
assert aggify.pipelines[1]["$project"] == {"_id": 0, "name": 1, "age": 1}
assert list(aggify.base_model._fields.keys()) == ["name", "age"]

def test_filtering_and_ordering(self):
aggify = Aggify(BaseModel)
Expand Down Expand Up @@ -636,3 +638,9 @@ def test_project_use_inclusion_and_exclusion_together(self):
with pytest.raises(InvalidProjection):
# noinspection PyUnusedLocal
var = aggify.project(name=0, age=1)

def test_project_add_new_field(self):
aggify = Aggify(BaseModel)
thing = list(aggify.project(test="test", id=0))
assert thing[0]["$project"] == {"test": "test", "_id": 0}
assert list(aggify.base_model._fields.keys()) == ["test"]