diff --git a/aggify/aggify.py b/aggify/aggify.py index d403623..5b16734 100644 --- a/aggify/aggify.py +++ b/aggify/aggify.py @@ -84,13 +84,13 @@ def project(self, **kwargs: QueryParams) -> "Aggify": Returns: Aggify: Returns an instance of the Aggify class for potential method chaining. """ - - if all([i in kwargs.values() for i in [0, 1]]): + filtered_kwargs = dict(kwargs) + filtered_kwargs.pop("id", None) + if all([i in filtered_kwargs.values() for i in [0, 1]]): raise InvalidProjection() # Extract fields to keep and check if _id should be deleted to_keep_values = {"id"} - delete_id = kwargs.get("id") is not None projection = {} # Add missing fields to the base model @@ -109,13 +109,10 @@ def project(self, **kwargs: QueryParams) -> "Aggify": # Remove fields from the base model, except the ones in to_keep_values and possibly _id if to_keep_values != {"id"}: keys_for_deletion = self.base_model._fields.keys() - to_keep_values # noqa - if delete_id: - keys_for_deletion.add("id") for key in keys_for_deletion: del self.base_model._fields[key] # noqa # Append the projection stage to the pipelines self.pipelines.append({"$project": projection}) - # Return the instance for method chaining return self diff --git a/tests/test_aggify.py b/tests/test_aggify.py index f931deb..cc71fcf 100644 --- a/tests/test_aggify.py +++ b/tests/test_aggify.py @@ -51,12 +51,14 @@ def test_filtering_and_projection(self): aggify.filter(age__gte=30).project(name=1, age=1) assert len(aggify.pipelines) == 2 assert aggify.pipelines[1]["$project"] == {"name": 1, "age": 1} + assert list(aggify.base_model._fields.keys()) == ["name", "age", "id"] def test_filtering_and_projection_with_deleting_id(self): aggify = Aggify(BaseModel) - aggify.filter(age__gte=30).project(name=1, age=1, id=1) + aggify.filter(age__gte=30).project(name=1, age=1, id=0) assert len(aggify.pipelines) == 2 - assert aggify.pipelines[1]["$project"] == {"_id": 1, "name": 1, "age": 1} + assert aggify.pipelines[1]["$project"] == {"_id": 0, "name": 1, "age": 1} + assert list(aggify.base_model._fields.keys()) == ["name", "age"] def test_filtering_and_ordering(self): aggify = Aggify(BaseModel) @@ -636,3 +638,9 @@ def test_project_use_inclusion_and_exclusion_together(self): with pytest.raises(InvalidProjection): # noinspection PyUnusedLocal var = aggify.project(name=0, age=1) + + def test_project_add_new_field(self): + aggify = Aggify(BaseModel) + thing = list(aggify.project(test="test", id=0)) + assert thing[0]["$project"] == {"test": "test", "_id": 0} + assert list(aggify.base_model._fields.keys()) == ["test"]