diff --git a/aggify/aggify.py b/aggify/aggify.py index d45fc48..a2e3b1c 100644 --- a/aggify/aggify.py +++ b/aggify/aggify.py @@ -117,8 +117,13 @@ def project(self, **kwargs: QueryParams) -> "Aggify": return self @last_out_stage_check - def group(self, expression: Union[str, None] = "id") -> "Aggify": - if expression: + def group(self, expression: Union[str, Dict, List, None] = "id") -> "Aggify": + if isinstance(expression, list): + expression = { + field: f"${self.get_field_name_recursively(field)}" + for field in expression + } + if expression and not isinstance(expression, dict): try: expression = "$" + self.get_field_name_recursively(expression) except InvalidField: @@ -454,7 +459,9 @@ def annotate( # Determine the data type based on the aggregation operator self.pipelines[-1]["$group"].update({annotate_name: {acc: value}}) - self.base_model._fields[annotate_name] = field_type # noqa + base_model_fields = self.base_model._fields # noqa + if not base_model_fields.get(annotate_name, None): + base_model_fields[annotate_name] = field_type return self def __match(self, matches: Dict[str, Any]): diff --git a/tests/test_aggify.py b/tests/test_aggify.py index 983faf0..b73053e 100644 --- a/tests/test_aggify.py +++ b/tests/test_aggify.py @@ -662,3 +662,7 @@ def test_lookup_raw_let(self): "as": "test_name", } assert "test_name" in list(aggify.base_model._fields.keys()) + + def test_group_multi_expressions(self): + thing = list(Aggify(BaseModel).group(["name", "age"])) + assert thing[0]["$group"] == {"_id": {"name": "$name", "age": "$age"}}