From 4f1105166631fca2ccfb7cadf0b64dba4268f4e6 Mon Sep 17 00:00:00 2001 From: Jiayu Liu Date: Sat, 16 Oct 2021 18:09:49 +0800 Subject: [PATCH] add approx distinct test --- python/src/functions.rs | 2 ++ python/tests/test_aggregation.py | 47 ++++++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+) create mode 100644 python/tests/test_aggregation.py diff --git a/python/src/functions.rs b/python/src/functions.rs index cecf28d2e7780..22a5ce44bdc11 100644 --- a/python/src/functions.rs +++ b/python/src/functions.rs @@ -224,6 +224,7 @@ define_unary_function!(avg); define_unary_function!(min); define_unary_function!(max); define_unary_function!(count); +define_unary_function!(approx_distinct); #[pyclass(name = "Volatility", module = "datafusion.functions")] #[derive(Clone)] @@ -323,6 +324,7 @@ pub fn init(module: &PyModule) -> PyResult<()> { module.add_class::()?; module.add_function(wrap_pyfunction!(abs, module)?)?; module.add_function(wrap_pyfunction!(acos, module)?)?; + module.add_function(wrap_pyfunction!(approx_distinct, module)?)?; module.add_function(wrap_pyfunction!(array, module)?)?; module.add_function(wrap_pyfunction!(ascii, module)?)?; module.add_function(wrap_pyfunction!(asin, module)?)?; diff --git a/python/tests/test_aggregation.py b/python/tests/test_aggregation.py new file mode 100644 index 0000000000000..f0996f9e06d9f --- /dev/null +++ b/python/tests/test_aggregation.py @@ -0,0 +1,47 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pyarrow as pa +import pytest +from datafusion import ExecutionContext +from datafusion import functions as f + + +@pytest.fixture +def df(): + ctx = ExecutionContext() + + # create a RecordBatch and a new DataFrame from it + batch = pa.RecordBatch.from_arrays( + [pa.array([1, 2, 3]), pa.array([4, 4, 6])], + names=["a", "b"], + ) + return ctx.create_dataframe([[batch]]) + + +def test_built_in_aggregation(df): + col_a = f.col("a") + col_b = f.col("b") + df = df.aggregate( + [], + [f.max(col_a), f.min(col_a), f.count(col_a), f.approx_distinct(col_b)], + ) + result = df.collect()[0] + assert result.column(0) == pa.array([3]) + assert result.column(1) == pa.array([1]) + assert result.column(2) == pa.array([3], type=pa.uint64()) + assert result.column(3) == pa.array([2], type=pa.uint64())