diff --git a/aggify/aggify.py b/aggify/aggify.py index 28db214..5422bae 100644 --- a/aggify/aggify.py +++ b/aggify/aggify.py @@ -116,12 +116,10 @@ def project(self, **kwargs: QueryParams) -> "Aggify": @last_out_stage_check def group(self, expression: Union[str, None] = "id") -> "Aggify": if expression: - check_fields_exist(self.base_model, [expression]) - expression = ( - get_db_field(self.base_model, expression, add_dollar_sign=True) - if expression - else None - ) + try: + expression = "$" + self.get_field_name_recursively(expression) + except InvalidField: + pass self.pipelines.append({"$group": {"_id": expression}}) return self diff --git a/tests/test_aggify.py b/tests/test_aggify.py index 043411d..e49ea9f 100644 --- a/tests/test_aggify.py +++ b/tests/test_aggify.py @@ -613,3 +613,7 @@ def test_aggify_get_item_slice_negative_start(self): with pytest.raises(MongoIndexError): # noinspection PyUnusedLocal var = aggify.filter(name=1)[slice(-5, -1)] + + def test_group_invalid_field(self): + thing = list(Aggify(BaseModel).group("invalid").annotate("name", "first", 2)) + assert thing[0]["$group"] == {"_id": "invalid", "name": {"$first": 2}} diff --git a/tests/test_query.py b/tests/test_query.py index 0bce2c7..61ef1b2 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -456,6 +456,16 @@ class ParameterTestCase: ), expected_query=[{"$group": {"_id": "$owner_id", "sss": {"$first": "sss"}}}], ), + ParameterTestCase( + compiled_query=( + Aggify(PostDocument) + .group("stat__like_count") + .annotate("sss", "first", "sss") + ), + expected_query=[ + {"$group": {"_id": "$stat.like_count", "sss": {"$first": "sss"}}} + ], + ), ]