From 67e75a201bd85058e2b5a843b061e00cb3856087 Mon Sep 17 00:00:00 2001 From: citoubest <1206539220@qq.com> Date: Sun, 18 Sep 2016 13:37:40 +0800 Subject: [PATCH 1/2] pyspark dataframe agg does not support multiple functions for all columns; change group.agg to support df.groupby(name).agg('max', 'min') like pandas --- python/pyspark/sql/group.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index f2092f9c63054..2d80ce206cd83 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -85,10 +85,25 @@ def agg(self, *exprs): if len(exprs) == 1 and isinstance(exprs[0], dict): jdf = self._jgd.agg(exprs[0]) else: + assert all(isinstance(c,Column) for c in exprs) or all(isinstance(c,str) for c in exprs),\ + "all exprs should be Column or support aggregate function names(str)" # Columns - assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column" - jdf = self._jgd.agg(exprs[0]._jc, - _to_seq(self.sql_ctx._sc, [c._jc for c in exprs[1:]])) + if all(isinstance(c,Column) for c in exprs): + jdf = self._jgd.agg(exprs[0]._jc, + _to_seq(self.sql_ctx._sc, [c._jc for c in exprs[1:]])) + # agg methods + elif all(isinstance(c, str) for c in exprs): + tmp_df = None + keys = [] + for m in exprs: + df = getattr(self, m)() + if tmp_df == None: + tmp_df = df + else: + if len(keys == 0): + keys = [key for key in set(tmp_df.columns).intersection(set(df.columns))] + tmp_df = tmp_df.join(df,keys) + return tmp_df return DataFrame(jdf, self.sql_ctx) @dfapi From 7407bc84b0376c78a7816de349a607cecac1d6f4 Mon Sep 17 00:00:00 2001 From: citoubest <1206539220@qq.com> Date: Sun, 18 Sep 2016 13:39:25 +0800 Subject: [PATCH 2/2] add docstring example for last change --- python/pyspark/sql/group.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index 2d80ce206cd83..4eef242df579c 100644 --- a/python/pyspark/sql/group.py +++ 
b/python/pyspark/sql/group.py @@ -80,6 +80,8 @@ def agg(self, *exprs): >>> from pyspark.sql import functions as F >>> sorted(gdf.agg(F.min(df.age)).collect()) [Row(name=u'Alice', min(age)=2), Row(name=u'Bob', min(age)=5)] + + >>> sorted(gdf.agg('max','min').collect()) """ assert exprs, "exprs should not be empty" if len(exprs) == 1 and isinstance(exprs[0], dict):