diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py
index f2092f9c63054..4eef242df579c 100644
--- a/python/pyspark/sql/group.py
+++ b/python/pyspark/sql/group.py
@@ -80,15 +80,35 @@ def agg(self, *exprs):
         >>> from pyspark.sql import functions as F
         >>> sorted(gdf.agg(F.min(df.age)).collect())
         [Row(name=u'Alice', min(age)=2), Row(name=u'Bob', min(age)=5)]
+
+        >>> sorted(gdf.agg('max', 'min').collect())  # doctest: +SKIP
         """
         assert exprs, "exprs should not be empty"
         if len(exprs) == 1 and isinstance(exprs[0], dict):
             jdf = self._jgd.agg(exprs[0])
         else:
+            assert (all(isinstance(c, Column) for c in exprs) or
+                    all(isinstance(c, str) for c in exprs)), \
+                "all exprs should be Column or support aggregate function names(str)"
             # Columns
-            assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column"
-            jdf = self._jgd.agg(exprs[0]._jc,
-                                _to_seq(self.sql_ctx._sc, [c._jc for c in exprs[1:]]))
+            if all(isinstance(c, Column) for c in exprs):
+                jdf = self._jgd.agg(exprs[0]._jc,
+                                    _to_seq(self.sql_ctx._sc, [c._jc for c in exprs[1:]]))
+            # agg methods
+            elif all(isinstance(c, str) for c in exprs):
+                # Each name is looked up as a no-argument method on this object; the
+                # per-method results are joined on the columns the first two share.
+                tmp_df = None
+                keys = []
+                for m in exprs:
+                    df = getattr(self, m)()
+                    if tmp_df is None:
+                        tmp_df = df
+                    else:
+                        if not keys:
+                            keys = [key for key in tmp_df.columns if key in df.columns]
+                        tmp_df = tmp_df.join(df, keys)
+                return tmp_df
         return DataFrame(jdf, self.sql_ctx)
 
     @dfapi