Skip to content

Commit 923d975

Browse files
amroidcpcloud
authored and committed
FEAT: Spark tests
Author: amroid <amr@dt3.org> Closes ibis-project#1830 from amroid/spark-tests and squashes the following commits: 7430d44 [amroid] Small changes from PR comments 54f948c [amroid] Tests now pass 6bae62c [amroid] Add changes from PR review 8cb80ae [amroid] Correct nullable behavior for spark to ibis type translation, fix tests for Spark backend ebd8f08 [amroid] Break up spark to ibis type conversion into separate dt.dtype registered functions 1083348 [amroid] Merge branch 'master' of https://github.com/ibis-project/ibis into spark-tests 8c398ca [amroid] incorporated suggestions for spark-client 87a2ece [amroid] changed imports to not fail CI c1b122f [amroid] added string concat 2aa6d7a [amroid] test_temporal now mostly works with Spark backend with some known issues, added floor option in convert_unit in util.py 45e80c6 [amroid] test_param now works with Spark backend 2ea6ac8 [amroid] test_numeric now works with Spark backend, changed impala compiler implementation of _number_literal_format a8ecb0e [amroid] test_column, test_generic now work with Spark backend (test_column required no changes) f934e40 [amroid] test_client now works with Spark backend, fixed mistake in test_sql in test_client.py c96cb58 [amroid] test_client now works with Spark backend, fixed mistake in test_sql in test_client.py 3b30726 [amroid] added pyspark>=2.4.3 to requirements-3.x-dev.yml files 0560ada [amroid] test_array now works with Spark backend, fixed Spark table creation and tests bc9fff7 [amroid] test_aggregation now works with Spark backend, changed base compiler rewrite for any, notany, all, notall to use max and min instead of sum 54422ef [amroid] test_string now works with Spark backend, added xfails and xpasses to regex tests in test_string 0d7bb67 [amroid] fixed SparkUnion bug d55c64a [amroid] Merge branch 'spark-client' into spark-tests dc3608b [amroid] fixed PR changes 78f8fc1 [amroid] added Spark subclass of Backend 66067d9 [amroid] added kwargs for SparkClient and 
SparkContext bbb7244 [amroid] added spark client, compiler, some unit tests
1 parent 35d3f4d commit 923d975

18 files changed

Lines changed: 584 additions & 119 deletions

conftest.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33

44
import pytest
55

6+
import ibis
7+
68
collect_ignore = ['setup.py']
79

810

@@ -18,3 +20,116 @@ def data_directory():
1820
pytest.skip('test data directory not found')
1921

2022
return datadir
23+
24+
25+
@pytest.fixture(scope='session')
def spark_client_testing(data_directory):
    """Session-scoped Spark client with the ibis test tables registered.

    Reads the CSV test datasets with explicit schemas, builds a handful of
    in-memory DataFrames exercising struct/array/map types, and exposes each
    dataset as a Spark temporary view on the returned client's session.
    """
    pytest.importorskip('pyspark')

    import pyspark.sql.types as pt

    client = ibis.spark.connect()

    df_functional_alltypes = client._session.read.csv(
        path=str(data_directory / 'functional_alltypes.csv'),
        schema=pt.StructType([
            pt.StructField('index', pt.IntegerType(), True),
            pt.StructField('Unnamed: 0', pt.IntegerType(), True),
            pt.StructField('id', pt.IntegerType(), True),
            # cast below, Spark can't read 0/1 as bool
            pt.StructField('bool_col', pt.ByteType(), True),
            pt.StructField('tinyint_col', pt.ByteType(), True),
            pt.StructField('smallint_col', pt.ShortType(), True),
            pt.StructField('int_col', pt.IntegerType(), True),
            pt.StructField('bigint_col', pt.LongType(), True),
            pt.StructField('float_col', pt.FloatType(), True),
            pt.StructField('double_col', pt.DoubleType(), True),
            pt.StructField('date_string_col', pt.StringType(), True),
            pt.StructField('string_col', pt.StringType(), True),
            pt.StructField('timestamp_col', pt.TimestampType(), True),
            pt.StructField('year', pt.IntegerType(), True),
            pt.StructField('month', pt.IntegerType(), True),
        ]),
        mode='FAILFAST',
        header=True,
    )
    df_functional_alltypes = df_functional_alltypes.withColumn(
        "bool_col", df_functional_alltypes["bool_col"].cast("boolean"))
    df_functional_alltypes.createOrReplaceTempView('functional_alltypes')

    df_batting = client._session.read.csv(
        path=str(data_directory / 'batting.csv'),
        schema=pt.StructType([
            pt.StructField('playerID', pt.StringType(), True),
            pt.StructField('yearID', pt.IntegerType(), True),
            pt.StructField('stint', pt.IntegerType(), True),
            pt.StructField('teamID', pt.StringType(), True),
            pt.StructField('lgID', pt.StringType(), True),
            pt.StructField('G', pt.IntegerType(), True),
            pt.StructField('AB', pt.DoubleType(), True),
            pt.StructField('R', pt.DoubleType(), True),
            pt.StructField('H', pt.DoubleType(), True),
            pt.StructField('X2B', pt.DoubleType(), True),
            pt.StructField('X3B', pt.DoubleType(), True),
            pt.StructField('HR', pt.DoubleType(), True),
            pt.StructField('RBI', pt.DoubleType(), True),
            pt.StructField('SB', pt.DoubleType(), True),
            pt.StructField('CS', pt.DoubleType(), True),
            pt.StructField('BB', pt.DoubleType(), True),
            pt.StructField('SO', pt.DoubleType(), True),
            pt.StructField('IBB', pt.DoubleType(), True),
            pt.StructField('HBP', pt.DoubleType(), True),
            pt.StructField('SH', pt.DoubleType(), True),
            pt.StructField('SF', pt.DoubleType(), True),
            pt.StructField('GIDP', pt.DoubleType(), True),
        ]),
        header=True,
    )
    df_batting.createOrReplaceTempView('batting')

    df_awards_players = client._session.read.csv(
        path=str(data_directory / 'awards_players.csv'),
        schema=pt.StructType([
            pt.StructField('playerID', pt.StringType(), True),
            pt.StructField('awardID', pt.StringType(), True),
            pt.StructField('yearID', pt.IntegerType(), True),
            pt.StructField('lgID', pt.StringType(), True),
            pt.StructField('tie', pt.StringType(), True),
            pt.StructField('notes', pt.StringType(), True),
        ]),
        header=True,
    )
    df_awards_players.createOrReplaceTempView('awards_players')

    # Small in-memory tables for simple and nested-type tests.
    df_simple = client._session.createDataFrame([(1, 'a')], ['foo', 'bar'])
    df_simple.createOrReplaceTempView('simple')

    df_struct = client._session.createDataFrame(
        [((1, 2, 'a'),)],
        ['struct_col']
    )
    df_struct.createOrReplaceTempView('struct')

    df_nested_types = client._session.createDataFrame(
        [
            (
                [1, 2],
                [[3, 4], [5, 6]],
                {'a': [[2, 4], [3, 5]]},
            )
        ],
        [
            'list_of_ints',
            'list_of_list_of_ints',
            'map_string_list_of_list_of_ints'
        ]
    )
    df_nested_types.createOrReplaceTempView('nested_types')

    df_complicated = client._session.createDataFrame(
        [({(1, 3): [[2, 4], [3, 5]]},)],
        ['map_tuple_list_of_list_of_ints']
    )
    df_complicated.createOrReplaceTempView('complicated')

    return client

ibis/impala/compiler.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import datetime
22
import itertools
3+
import math
34
from io import StringIO
45
from operator import add, mul, sub
56
from typing import Optional
@@ -631,10 +632,18 @@ def _string_literal_format(translator, expr):
631632

632633
def _number_literal_format(translator, expr):
633634
value = expr.op().value
634-
formatted = repr(value)
635635

636-
if formatted in {'nan', 'inf', '-inf'}:
637-
return "CAST({!r} AS DOUBLE)".format(formatted)
636+
if math.isfinite(value):
637+
formatted = repr(value)
638+
else:
639+
if math.isnan(value):
640+
formatted_val = 'NaN'
641+
elif math.isinf(value):
642+
if value > 0:
643+
formatted_val = 'Infinity'
644+
else:
645+
formatted_val = '-Infinity'
646+
formatted = "CAST({!r} AS DOUBLE)".format(formatted_val)
638647

639648
return formatted
640649

ibis/impala/tests/test_exprs.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -357,10 +357,10 @@ def test_any_all(self):
357357
bool_expr = t.f == 0
358358

359359
cases = [
360-
(bool_expr.any(), 'sum(`f` = 0) > 0'),
361-
(-bool_expr.any(), 'sum(`f` = 0) = 0'),
362-
(bool_expr.all(), 'sum(`f` = 0) = count(*)'),
363-
(-bool_expr.all(), 'sum(`f` = 0) < count(*)'),
360+
(bool_expr.any(), 'max(`f` = 0)'),
361+
(-bool_expr.any(), 'max(`f` = 0) = FALSE'),
362+
(bool_expr.all(), 'min(`f` = 0)'),
363+
(-bool_expr.all(), 'min(`f` = 0) = FALSE'),
364364
]
365365
self._check_expr_cases(cases)
366366

ibis/spark/api.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from ibis.spark.client import SparkClient
2+
from ibis.spark.compiler import dialect # noqa: F401
23

34

45
def connect(**kwargs):
@@ -9,4 +10,10 @@ def connect(**kwargs):
910
"""
1011
client = SparkClient(**kwargs)
1112

13+
# Spark internally stores timestamps as UTC values, and timestamp data that
14+
# is brought in without a specified time zone is converted as local time to
15+
# UTC with microsecond resolution.
16+
# https://spark.apache.org/docs/latest/sql-pyspark-pandas-with-arrow.html#timestamp-with-time-zone-semantics
17+
client._session.conf.set('spark.sql.session.timeZone', 'UTC')
18+
1219
return client

ibis/spark/client.py

Lines changed: 68 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
1-
from collections import OrderedDict
2-
31
import pyspark as ps
42
import pyspark.sql.types as pt
53
import regex as re
4+
from pkg_resources import parse_version
65

76
import ibis.common as com
87
import ibis.expr.datatypes as dt
@@ -11,56 +10,75 @@
1110
from ibis.client import Database, Query, SQLClient
1211
from ibis.spark import compiler as comp
1312

14-
_DTYPE_TO_IBIS_TYPE = {
15-
pt.NullType : dt.null,
16-
pt.StringType : dt.string,
17-
pt.BinaryType : dt.binary,
18-
pt.BooleanType : dt.boolean,
19-
pt.DateType : dt.date,
20-
pt.TimestampType : dt.timestamp,
21-
pt.DoubleType : dt.double,
22-
pt.FloatType : dt.float,
23-
pt.ByteType : dt.int8,
24-
pt.IntegerType : dt.int32,
25-
pt.LongType : dt.int64,
26-
pt.ShortType : dt.int16,
13+
# Maps each simple (non-parametric) pyspark type class to the corresponding
# ibis dtype class; parametric types are converted by dedicated registrations.
_SPARK_DTYPE_TO_IBIS_DTYPE = {
    pt.NullType: dt.Null,
    pt.StringType: dt.String,
    pt.BinaryType: dt.Binary,
    pt.BooleanType: dt.Boolean,
    pt.DateType: dt.Date,
    pt.DoubleType: dt.Double,
    pt.FloatType: dt.Float,
    pt.ByteType: dt.Int8,
    pt.IntegerType: dt.Int32,
    pt.LongType: dt.Int64,
    pt.ShortType: dt.Int16,
}
2827

2928

3029
@dt.dtype.register(pt.DataType)
def spark_dtype_to_ibis_dtype(spark_type_obj, nullable=True):
    """Convert a simple Spark SQL type object to an ibis type object.

    Parametric Spark types (timestamp, decimal, array, map, struct) have
    their own, more specific ``dt.dtype`` registrations in this module.
    """
    klass = _SPARK_DTYPE_TO_IBIS_DTYPE.get(type(spark_type_obj))
    return klass(nullable=nullable)
34+
35+
36+
@dt.dtype.register(pt.TimestampType)
def spark_timestamp_dtype_to_ibis_dtype(spark_type_obj, nullable=True):
    """Convert a Spark timestamp type to an ibis Timestamp dtype."""
    return dt.Timestamp(nullable=nullable)
39+
40+
41+
@dt.dtype.register(pt.DecimalType)
def spark_decimal_dtype_to_ibis_dtype(spark_type_obj, nullable=True):
    """Convert a Spark decimal type, preserving its precision and scale."""
    return dt.Decimal(
        spark_type_obj.precision, spark_type_obj.scale, nullable=nullable
    )
46+
47+
48+
@dt.dtype.register(pt.ArrayType)
def spark_array_dtype_to_ibis_dtype(spark_type_obj, nullable=True):
    """Convert a Spark array type to an ibis Array dtype.

    Element nullability follows the Spark type's ``containsNull`` flag.
    """
    element_type = dt.dtype(
        spark_type_obj.elementType, nullable=spark_type_obj.containsNull
    )
    return dt.Array(element_type, nullable=nullable)
55+
56+
57+
@dt.dtype.register(pt.MapType)
def spark_map_dtype_to_ibis_dtype(spark_type_obj, nullable=True):
    """Convert a Spark map type to an ibis Map dtype.

    Value nullability follows the Spark type's ``valueContainsNull`` flag.
    """
    keys = dt.dtype(spark_type_obj.keyType)
    values = dt.dtype(
        spark_type_obj.valueType, nullable=spark_type_obj.valueContainsNull
    )
    return dt.Map(keys, values, nullable=nullable)
65+
66+
67+
@dt.dtype.register(pt.StructType)
def spark_struct_dtype_to_ibis_dtype(spark_type_obj, nullable=True):
    """Convert a Spark struct type, preserving each field's nullability."""
    field_types = [
        dt.dtype(field.dataType, nullable=field.nullable)
        for field in spark_type_obj.fields
    ]
    return dt.Struct(spark_type_obj.names, field_types, nullable=nullable)
5673

5774

5875
@sch.infer.register(ps.sql.dataframe.DataFrame)
def spark_dataframe_schema(df):
    """Infer the ibis schema of a Spark SQL `DataFrame` object."""
    # df.schema is a pt.StructType, so dt.dtype yields a dt.Struct whose
    # names/types pair up to form the table schema.
    struct = dt.dtype(df.schema)
    return sch.schema(struct.names, struct.types)
6482

6583

6684
class SparkCursor:
@@ -180,6 +198,10 @@ def current_database(self):
180198
def _get_table_schema(self, table_name):
181199
return self.get_schema(table_name)
182200

201+
def _get_schema_using_query(self, query):
202+
cur = self._execute(query, results=True)
203+
return spark_dataframe_schema(cur.query)
204+
183205
def list_tables(self, like=None, database=None):
184206
"""
185207
List tables in the current (or indicated) database. Like the SHOW
@@ -274,3 +296,7 @@ def get_schema(self, table_name, database=None):
274296
df = self._session.table(table_name)
275297

276298
return sch.infer(df)
299+
300+
@property
301+
def version(self):
302+
return parse_version(ps.__version__)

0 commit comments

Comments
 (0)