diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index c37c338a134f3..280a443523563 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1689,6 +1689,10 @@ using the call `toPandas()` and when creating a Spark DataFrame from a Pandas Da `createDataFrame(pandas_df)`. To use Arrow when executing these calls, users need to first set the Spark configuration 'spark.sql.execution.arrow.enabled' to 'true'. This is disabled by default. +In addition, optimizations enabled by 'spark.sql.execution.arrow.enabled' could fall back automatically +to a non-Arrow implementation if an error occurs before the actual computation within Spark. +This can be controlled by 'spark.sql.execution.arrow.fallback.enabled'. +
{% include_example dataframe_with_arrow python/sql/arrow.py %} @@ -1800,6 +1804,7 @@ working with timestamps in `pandas_udf`s to get the best performance, see ## Upgrading From Spark SQL 2.3 to 2.4 - Since Spark 2.4, Spark maximizes the usage of a vectorized ORC reader for ORC files by default. To do that, `spark.sql.orc.impl` and `spark.sql.orc.filterPushdown` change their default values to `native` and `true` respectively. + - In PySpark, when Arrow optimization is enabled, previously `toPandas` just failed when Arrow optimization was unable to be used whereas `createDataFrame` from Pandas DataFrame allowed the fallback to non-optimization. Now, both `toPandas` and `createDataFrame` from Pandas DataFrame allow the fallback by default, which can be switched off by `spark.sql.execution.arrow.fallback.enabled`. ## Upgrading From Spark SQL 2.2 to 2.3 diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index f37777e13ee12..a24b9e1baf596 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1986,55 +1986,91 @@ def toPandas(self): timezone = None if self.sql_ctx.getConf("spark.sql.execution.arrow.enabled", "false").lower() == "true": + use_arrow = True try: - from pyspark.sql.types import _check_dataframe_convert_date, \ - _check_dataframe_localize_timestamps, to_arrow_schema + from pyspark.sql.types import to_arrow_schema from pyspark.sql.utils import require_minimum_pyarrow_version + require_minimum_pyarrow_version() - import pyarrow to_arrow_schema(self.schema) - tables = self._collectAsArrow() - if tables: - table = pyarrow.concat_tables(tables) - pdf = table.to_pandas() - pdf = _check_dataframe_convert_date(pdf, self.schema) - return _check_dataframe_localize_timestamps(pdf, timezone) - else: - return pd.DataFrame.from_records([], columns=self.columns) except Exception as e: - msg = ( - "Note: toPandas attempted Arrow optimization because " - "'spark.sql.execution.arrow.enabled' is set to true. 
Please set it to false " - "to disable this.") - raise RuntimeError("%s\n%s" % (_exception_message(e), msg)) - else: - pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns) - dtype = {} + if self.sql_ctx.getConf("spark.sql.execution.arrow.fallback.enabled", "true") \ + .lower() == "true": + msg = ( + "toPandas attempted Arrow optimization because " + "'spark.sql.execution.arrow.enabled' is set to true; however, " + "failed by the reason below:\n %s\n" + "Attempts non-optimization as " + "'spark.sql.execution.arrow.fallback.enabled' is set to " + "true." % _exception_message(e)) + warnings.warn(msg) + use_arrow = False + else: + msg = ( + "toPandas attempted Arrow optimization because " + "'spark.sql.execution.arrow.enabled' is set to true; however, " + "failed by the reason below:\n %s\n" + "For fallback to non-optimization automatically, please set true to " + "'spark.sql.execution.arrow.fallback.enabled'." % _exception_message(e)) + raise RuntimeError(msg) + + # Try to use Arrow optimization when the schema is supported and the required version + # of PyArrow is found, if 'spark.sql.execution.arrow.enabled' is enabled. + if use_arrow: + try: + from pyspark.sql.types import _check_dataframe_convert_date, \ + _check_dataframe_localize_timestamps + import pyarrow + + tables = self._collectAsArrow() + if tables: + table = pyarrow.concat_tables(tables) + pdf = table.to_pandas() + pdf = _check_dataframe_convert_date(pdf, self.schema) + return _check_dataframe_localize_timestamps(pdf, timezone) + else: + return pd.DataFrame.from_records([], columns=self.columns) + except Exception as e: + # We might have to allow fallback here as well but multiple Spark jobs can + # be executed. So, simply fail in this case for now. 
+ msg = ( + "toPandas attempted Arrow optimization because " + "'spark.sql.execution.arrow.enabled' is set to true; however, " + "failed unexpectedly:\n %s\n" + "Note that 'spark.sql.execution.arrow.fallback.enabled' does " + "not have an effect in such failure in the middle of " + "computation." % _exception_message(e)) + raise RuntimeError(msg) + + # Below is toPandas without Arrow optimization. + pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns) + + dtype = {} + for field in self.schema: + pandas_type = _to_corrected_pandas_type(field.dataType) + # SPARK-21766: if an integer field is nullable and has null values, it can be + # inferred by pandas as float column. Once we convert the column with NaN back + # to integer type e.g., np.int16, we will hit exception. So we use the inferred + # float type, not the corrected type from the schema in this case. + if pandas_type is not None and \ + not(isinstance(field.dataType, IntegralType) and field.nullable and + pdf[field.name].isnull().any()): + dtype[field.name] = pandas_type + + for f, t in dtype.items(): + pdf[f] = pdf[f].astype(t, copy=False) + + if timezone is None: + return pdf + else: + from pyspark.sql.types import _check_series_convert_timestamps_local_tz for field in self.schema: - pandas_type = _to_corrected_pandas_type(field.dataType) - # SPARK-21766: if an integer field is nullable and has null values, it can be - # inferred by pandas as float column. Once we convert the column with NaN back - # to integer type e.g., np.int16, we will hit exception. So we use the inferred - # float type, not the corrected type from the schema in this case. 
- if pandas_type is not None and \ - not(isinstance(field.dataType, IntegralType) and field.nullable and - pdf[field.name].isnull().any()): - dtype[field.name] = pandas_type - - for f, t in dtype.items(): - pdf[f] = pdf[f].astype(t, copy=False) - - if timezone is None: - return pdf - else: - from pyspark.sql.types import _check_series_convert_timestamps_local_tz - for field in self.schema: - # TODO: handle nested timestamps, such as ArrayType(TimestampType())? - if isinstance(field.dataType, TimestampType): - pdf[field.name] = \ - _check_series_convert_timestamps_local_tz(pdf[field.name], timezone) - return pdf + # TODO: handle nested timestamps, such as ArrayType(TimestampType())? + if isinstance(field.dataType, TimestampType): + pdf[field.name] = \ + _check_series_convert_timestamps_local_tz(pdf[field.name], timezone) + return pdf def _collectAsArrow(self): """ diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index b3af9b82953f3..215bb3e5c5173 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -666,8 +666,26 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr try: return self._create_from_pandas_with_arrow(data, schema, timezone) except Exception as e: - warnings.warn("Arrow will not be used in createDataFrame: %s" % str(e)) - # Fallback to create DataFrame without arrow if raise some exception + from pyspark.util import _exception_message + + if self.conf.get("spark.sql.execution.arrow.fallback.enabled", "true") \ + .lower() == "true": + msg = ( + "createDataFrame attempted Arrow optimization because " + "'spark.sql.execution.arrow.enabled' is set to true; however, " + "failed by the reason below:\n %s\n" + "Attempts non-optimization as " + "'spark.sql.execution.arrow.fallback.enabled' is set to " + "true." 
% _exception_message(e)) + warnings.warn(msg) + else: + msg = ( + "createDataFrame attempted Arrow optimization because " + "'spark.sql.execution.arrow.enabled' is set to true; however, " + "failed by the reason below:\n %s\n" + "For fallback to non-optimization automatically, please set true to " + "'spark.sql.execution.arrow.fallback.enabled'." % _exception_message(e)) + raise RuntimeError(msg) data = self._convert_from_pandas(data, schema, timezone) if isinstance(schema, StructType): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 19653072ea316..5bfc91176508a 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -32,7 +32,9 @@ import datetime import array import ctypes +import warnings import py4j +from contextlib import contextmanager try: import xmlrunner @@ -48,12 +50,13 @@ else: import unittest +from pyspark.util import _exception_message + _pandas_requirement_message = None try: from pyspark.sql.utils import require_minimum_pandas_version require_minimum_pandas_version() except ImportError as e: - from pyspark.util import _exception_message # If Pandas version requirement is not satisfied, skip related tests. _pandas_requirement_message = _exception_message(e) @@ -62,7 +65,6 @@ from pyspark.sql.utils import require_minimum_pyarrow_version require_minimum_pyarrow_version() except ImportError as e: - from pyspark.util import _exception_message # If Arrow version requirement is not satisfied, skip related tests. _pyarrow_requirement_message = _exception_message(e) @@ -195,6 +197,28 @@ def tearDownClass(cls): ReusedPySparkTestCase.tearDownClass() cls.spark.stop() + @contextmanager + def sql_conf(self, pairs): + """ + A convenient context manager to test some configuration specific logic. This sets + `value` to the configuration `key` and then restores it back when it exits. + """ + assert isinstance(pairs, dict), "pairs should be a dictionary." 
+ + keys = pairs.keys() + new_values = pairs.values() + old_values = [self.spark.conf.get(key, None) for key in keys] + for key, new_value in zip(keys, new_values): + self.spark.conf.set(key, new_value) + try: + yield + finally: + for key, old_value in zip(keys, old_values): + if old_value is None: + self.spark.conf.unset(key) + else: + self.spark.conf.set(key, old_value) + def assertPandasEqual(self, expected, result): msg = ("DataFrames are not equal: " + "\n\nExpected:\n%s\n%s" % (expected, expected.dtypes) + @@ -3458,6 +3482,8 @@ def setUpClass(cls): cls.spark.conf.set("spark.sql.session.timeZone", tz) cls.spark.conf.set("spark.sql.execution.arrow.enabled", "true") + # Disable fallback by default to easily detect the failures. + cls.spark.conf.set("spark.sql.execution.arrow.fallback.enabled", "false") cls.schema = StructType([ StructField("1_str_t", StringType(), True), StructField("2_int_t", IntegerType(), True), @@ -3493,20 +3519,30 @@ def create_pandas_data_frame(self): data_dict["4_float_t"] = np.float32(data_dict["4_float_t"]) return pd.DataFrame(data=data_dict) - def test_unsupported_datatype(self): + def test_toPandas_fallback_enabled(self): + import pandas as pd + + with self.sql_conf({"spark.sql.execution.arrow.fallback.enabled": True}): + schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)]) + df = self.spark.createDataFrame([({u'a': 1},)], schema=schema) + with QuietTest(self.sc): + with warnings.catch_warnings(record=True) as warns: + pdf = df.toPandas() + # Catch and check the last UserWarning. 
+ user_warns = [ + warn.message for warn in warns if isinstance(warn.message, UserWarning)] + self.assertTrue(len(user_warns) > 0) + self.assertTrue( + "Attempts non-optimization" in _exception_message(user_warns[-1])) + self.assertPandasEqual(pdf, pd.DataFrame({u'map': [{u'a': 1}]})) + + def test_toPandas_fallback_disabled(self): schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)]) df = self.spark.createDataFrame([(None,)], schema=schema) with QuietTest(self.sc): with self.assertRaisesRegexp(Exception, 'Unsupported type'): df.toPandas() - df = self.spark.createDataFrame([(None,)], schema="a binary") - with QuietTest(self.sc): - with self.assertRaisesRegexp( - Exception, - 'Unsupported type.*\nNote: toPandas attempted Arrow optimization because'): - df.toPandas() - def test_null_conversion(self): df_null = self.spark.createDataFrame([tuple([None for _ in range(len(self.data[0]))])] + self.data) @@ -3625,7 +3661,7 @@ def test_createDataFrame_with_incorrect_schema(self): pdf = self.create_pandas_data_frame() wrong_schema = StructType(list(reversed(self.schema))) with QuietTest(self.sc): - with self.assertRaisesRegexp(TypeError, ".*field.*can.not.accept.*type"): + with self.assertRaisesRegexp(RuntimeError, ".*No cast.*string.*timestamp.*"): self.spark.createDataFrame(pdf, schema=wrong_schema) def test_createDataFrame_with_names(self): @@ -3650,7 +3686,7 @@ def test_createDataFrame_column_name_encoding(self): def test_createDataFrame_with_single_data_type(self): import pandas as pd with QuietTest(self.sc): - with self.assertRaisesRegexp(TypeError, ".*IntegerType.*tuple"): + with self.assertRaisesRegexp(RuntimeError, ".*IntegerType.*not supported.*"): self.spark.createDataFrame(pd.DataFrame({"a": [1]}), schema="int") def test_createDataFrame_does_not_modify_input(self): @@ -3705,6 +3741,30 @@ def test_createDataFrame_with_int_col_names(self): self.assertEqual(pdf_col_names, df.columns) self.assertEqual(pdf_col_names, df_arrow.columns) + 
def test_createDataFrame_fallback_enabled(self): + import pandas as pd + + with QuietTest(self.sc): + with self.sql_conf({"spark.sql.execution.arrow.fallback.enabled": True}): + with warnings.catch_warnings(record=True) as warns: + df = self.spark.createDataFrame( + pd.DataFrame([[{u'a': 1}]]), "a: map") + # Catch and check the last UserWarning. + user_warns = [ + warn.message for warn in warns if isinstance(warn.message, UserWarning)] + self.assertTrue(len(user_warns) > 0) + self.assertTrue( + "Attempts non-optimization" in _exception_message(user_warns[-1])) + self.assertEqual(df.collect(), [Row(a={u'a': 1})]) + + def test_createDataFrame_fallback_disabled(self): + import pandas as pd + + with QuietTest(self.sc): + with self.assertRaisesRegexp(Exception, 'Unsupported type'): + self.spark.createDataFrame( + pd.DataFrame([[{u'a': 1}]]), "a: map") + # Regression test for SPARK-23314 def test_timestamp_dst(self): import pandas as pd diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index ce3f94618edeb..3f96112659c11 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1058,7 +1058,7 @@ object SQLConf { .intConf .createWithDefault(100) - val ARROW_EXECUTION_ENABLE = + val ARROW_EXECUTION_ENABLED = buildConf("spark.sql.execution.arrow.enabled") .doc("When true, make use of Apache Arrow for columnar data transfers. 
Currently available " + "for use with pyspark.sql.DataFrame.toPandas, and " + @@ -1068,6 +1068,13 @@ object SQLConf { .booleanConf .createWithDefault(false) + val ARROW_FALLBACK_ENABLED = + buildConf("spark.sql.execution.arrow.fallback.enabled") + .doc("When true, optimizations enabled by 'spark.sql.execution.arrow.enabled' will " + + "fallback automatically to non-optimized implementations if an error occurs.") + .booleanConf + .createWithDefault(true) + val ARROW_EXECUTION_MAX_RECORDS_PER_BATCH = buildConf("spark.sql.execution.arrow.maxRecordsPerBatch") .doc("When using Apache Arrow, limit the maximum number of records that can be written " + @@ -1518,7 +1525,9 @@ class SQLConf extends Serializable with Logging { def rangeExchangeSampleSizePerPartition: Int = getConf(RANGE_EXCHANGE_SAMPLE_SIZE_PER_PARTITION) - def arrowEnable: Boolean = getConf(ARROW_EXECUTION_ENABLE) + def arrowEnabled: Boolean = getConf(ARROW_EXECUTION_ENABLED) + + def arrowFallbackEnabled: Boolean = getConf(ARROW_FALLBACK_ENABLED) def arrowMaxRecordsPerBatch: Int = getConf(ARROW_EXECUTION_MAX_RECORDS_PER_BATCH)