diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md
index c37c338a134f3..280a443523563 100644
--- a/docs/sql-programming-guide.md
+++ b/docs/sql-programming-guide.md
@@ -1689,6 +1689,10 @@ using the call `toPandas()` and when creating a Spark DataFrame from a Pandas Da
`createDataFrame(pandas_df)`. To use Arrow when executing these calls, users need to first set
the Spark configuration 'spark.sql.execution.arrow.enabled' to 'true'. This is disabled by default.
+In addition, optimizations enabled by 'spark.sql.execution.arrow.enabled' could fallback automatically
+to non-Arrow optimization implementation if an error occurs before the actual computation within Spark.
+This can be controlled by 'spark.sql.execution.arrow.fallback.enabled'.
+
{% include_example dataframe_with_arrow python/sql/arrow.py %}
@@ -1800,6 +1804,7 @@ working with timestamps in `pandas_udf`s to get the best performance, see
## Upgrading From Spark SQL 2.3 to 2.4
- Since Spark 2.4, Spark maximizes the usage of a vectorized ORC reader for ORC files by default. To do that, `spark.sql.orc.impl` and `spark.sql.orc.filterPushdown` change their default values to `native` and `true` respectively.
+ - In PySpark, when Arrow optimization is enabled, previously `toPandas` just failed when Arrow optimization is unabled to be used whereas `createDataFrame` from Pandas DataFrame allowed the fallback to non-optimization. Now, both `toPandas` and `createDataFrame` from Pandas DataFrame allow the fallback by default, which can be switched off by `spark.sql.execution.arrow.fallback.enabled`.
## Upgrading From Spark SQL 2.2 to 2.3
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index f37777e13ee12..a24b9e1baf596 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -1986,55 +1986,91 @@ def toPandas(self):
timezone = None
if self.sql_ctx.getConf("spark.sql.execution.arrow.enabled", "false").lower() == "true":
+ use_arrow = True
try:
- from pyspark.sql.types import _check_dataframe_convert_date, \
- _check_dataframe_localize_timestamps, to_arrow_schema
+ from pyspark.sql.types import to_arrow_schema
from pyspark.sql.utils import require_minimum_pyarrow_version
+
require_minimum_pyarrow_version()
- import pyarrow
to_arrow_schema(self.schema)
- tables = self._collectAsArrow()
- if tables:
- table = pyarrow.concat_tables(tables)
- pdf = table.to_pandas()
- pdf = _check_dataframe_convert_date(pdf, self.schema)
- return _check_dataframe_localize_timestamps(pdf, timezone)
- else:
- return pd.DataFrame.from_records([], columns=self.columns)
except Exception as e:
- msg = (
- "Note: toPandas attempted Arrow optimization because "
- "'spark.sql.execution.arrow.enabled' is set to true. Please set it to false "
- "to disable this.")
- raise RuntimeError("%s\n%s" % (_exception_message(e), msg))
- else:
- pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns)
- dtype = {}
+ if self.sql_ctx.getConf("spark.sql.execution.arrow.fallback.enabled", "true") \
+ .lower() == "true":
+ msg = (
+ "toPandas attempted Arrow optimization because "
+ "'spark.sql.execution.arrow.enabled' is set to true; however, "
+ "failed by the reason below:\n %s\n"
+ "Attempts non-optimization as "
+ "'spark.sql.execution.arrow.fallback.enabled' is set to "
+ "true." % _exception_message(e))
+ warnings.warn(msg)
+ use_arrow = False
+ else:
+ msg = (
+ "toPandas attempted Arrow optimization because "
+ "'spark.sql.execution.arrow.enabled' is set to true; however, "
+ "failed by the reason below:\n %s\n"
+ "For fallback to non-optimization automatically, please set true to "
+ "'spark.sql.execution.arrow.fallback.enabled'." % _exception_message(e))
+ raise RuntimeError(msg)
+
+ # Try to use Arrow optimization when the schema is supported and the required version
+ # of PyArrow is found, if 'spark.sql.execution.arrow.enabled' is enabled.
+ if use_arrow:
+ try:
+ from pyspark.sql.types import _check_dataframe_convert_date, \
+ _check_dataframe_localize_timestamps
+ import pyarrow
+
+ tables = self._collectAsArrow()
+ if tables:
+ table = pyarrow.concat_tables(tables)
+ pdf = table.to_pandas()
+ pdf = _check_dataframe_convert_date(pdf, self.schema)
+ return _check_dataframe_localize_timestamps(pdf, timezone)
+ else:
+ return pd.DataFrame.from_records([], columns=self.columns)
+ except Exception as e:
+ # We might have to allow fallback here as well but multiple Spark jobs can
+ # be executed. So, simply fail in this case for now.
+ msg = (
+ "toPandas attempted Arrow optimization because "
+ "'spark.sql.execution.arrow.enabled' is set to true; however, "
+ "failed unexpectedly:\n %s\n"
+ "Note that 'spark.sql.execution.arrow.fallback.enabled' does "
+ "not have an effect in such failure in the middle of "
+ "computation." % _exception_message(e))
+ raise RuntimeError(msg)
+
+ # Below is toPandas without Arrow optimization.
+ pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns)
+
+ dtype = {}
+ for field in self.schema:
+ pandas_type = _to_corrected_pandas_type(field.dataType)
+ # SPARK-21766: if an integer field is nullable and has null values, it can be
+ # inferred by pandas as float column. Once we convert the column with NaN back
+ # to integer type e.g., np.int16, we will hit exception. So we use the inferred
+ # float type, not the corrected type from the schema in this case.
+ if pandas_type is not None and \
+ not(isinstance(field.dataType, IntegralType) and field.nullable and
+ pdf[field.name].isnull().any()):
+ dtype[field.name] = pandas_type
+
+ for f, t in dtype.items():
+ pdf[f] = pdf[f].astype(t, copy=False)
+
+ if timezone is None:
+ return pdf
+ else:
+ from pyspark.sql.types import _check_series_convert_timestamps_local_tz
for field in self.schema:
- pandas_type = _to_corrected_pandas_type(field.dataType)
- # SPARK-21766: if an integer field is nullable and has null values, it can be
- # inferred by pandas as float column. Once we convert the column with NaN back
- # to integer type e.g., np.int16, we will hit exception. So we use the inferred
- # float type, not the corrected type from the schema in this case.
- if pandas_type is not None and \
- not(isinstance(field.dataType, IntegralType) and field.nullable and
- pdf[field.name].isnull().any()):
- dtype[field.name] = pandas_type
-
- for f, t in dtype.items():
- pdf[f] = pdf[f].astype(t, copy=False)
-
- if timezone is None:
- return pdf
- else:
- from pyspark.sql.types import _check_series_convert_timestamps_local_tz
- for field in self.schema:
- # TODO: handle nested timestamps, such as ArrayType(TimestampType())?
- if isinstance(field.dataType, TimestampType):
- pdf[field.name] = \
- _check_series_convert_timestamps_local_tz(pdf[field.name], timezone)
- return pdf
+ # TODO: handle nested timestamps, such as ArrayType(TimestampType())?
+ if isinstance(field.dataType, TimestampType):
+ pdf[field.name] = \
+ _check_series_convert_timestamps_local_tz(pdf[field.name], timezone)
+ return pdf
def _collectAsArrow(self):
"""
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index b3af9b82953f3..215bb3e5c5173 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -666,8 +666,26 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr
try:
return self._create_from_pandas_with_arrow(data, schema, timezone)
except Exception as e:
- warnings.warn("Arrow will not be used in createDataFrame: %s" % str(e))
- # Fallback to create DataFrame without arrow if raise some exception
+ from pyspark.util import _exception_message
+
+ if self.conf.get("spark.sql.execution.arrow.fallback.enabled", "true") \
+ .lower() == "true":
+ msg = (
+ "createDataFrame attempted Arrow optimization because "
+ "'spark.sql.execution.arrow.enabled' is set to true; however, "
+ "failed by the reason below:\n %s\n"
+ "Attempts non-optimization as "
+ "'spark.sql.execution.arrow.fallback.enabled' is set to "
+ "true." % _exception_message(e))
+ warnings.warn(msg)
+ else:
+ msg = (
+ "createDataFrame attempted Arrow optimization because "
+ "'spark.sql.execution.arrow.enabled' is set to true; however, "
+ "failed by the reason below:\n %s\n"
+ "For fallback to non-optimization automatically, please set true to "
+ "'spark.sql.execution.arrow.fallback.enabled'." % _exception_message(e))
+ raise RuntimeError(msg)
data = self._convert_from_pandas(data, schema, timezone)
if isinstance(schema, StructType):
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 19653072ea316..5bfc91176508a 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -32,7 +32,9 @@
import datetime
import array
import ctypes
+import warnings
import py4j
+from contextlib import contextmanager
try:
import xmlrunner
@@ -48,12 +50,13 @@
else:
import unittest
+from pyspark.util import _exception_message
+
_pandas_requirement_message = None
try:
from pyspark.sql.utils import require_minimum_pandas_version
require_minimum_pandas_version()
except ImportError as e:
- from pyspark.util import _exception_message
# If Pandas version requirement is not satisfied, skip related tests.
_pandas_requirement_message = _exception_message(e)
@@ -62,7 +65,6 @@
from pyspark.sql.utils import require_minimum_pyarrow_version
require_minimum_pyarrow_version()
except ImportError as e:
- from pyspark.util import _exception_message
# If Arrow version requirement is not satisfied, skip related tests.
_pyarrow_requirement_message = _exception_message(e)
@@ -195,6 +197,28 @@ def tearDownClass(cls):
ReusedPySparkTestCase.tearDownClass()
cls.spark.stop()
+ @contextmanager
+ def sql_conf(self, pairs):
+ """
+ A convenient context manager to test some configuration specific logic. This sets
+ `value` to the configuration `key` and then restores it back when it exits.
+ """
+ assert isinstance(pairs, dict), "pairs should be a dictionary."
+
+ keys = pairs.keys()
+ new_values = pairs.values()
+ old_values = [self.spark.conf.get(key, None) for key in keys]
+ for key, new_value in zip(keys, new_values):
+ self.spark.conf.set(key, new_value)
+ try:
+ yield
+ finally:
+ for key, old_value in zip(keys, old_values):
+ if old_value is None:
+ self.spark.conf.unset(key)
+ else:
+ self.spark.conf.set(key, old_value)
+
def assertPandasEqual(self, expected, result):
msg = ("DataFrames are not equal: " +
"\n\nExpected:\n%s\n%s" % (expected, expected.dtypes) +
@@ -3458,6 +3482,8 @@ def setUpClass(cls):
cls.spark.conf.set("spark.sql.session.timeZone", tz)
cls.spark.conf.set("spark.sql.execution.arrow.enabled", "true")
+ # Disable fallback by default to easily detect the failures.
+ cls.spark.conf.set("spark.sql.execution.arrow.fallback.enabled", "false")
cls.schema = StructType([
StructField("1_str_t", StringType(), True),
StructField("2_int_t", IntegerType(), True),
@@ -3493,20 +3519,30 @@ def create_pandas_data_frame(self):
data_dict["4_float_t"] = np.float32(data_dict["4_float_t"])
return pd.DataFrame(data=data_dict)
- def test_unsupported_datatype(self):
+ def test_toPandas_fallback_enabled(self):
+ import pandas as pd
+
+ with self.sql_conf({"spark.sql.execution.arrow.fallback.enabled": True}):
+ schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)])
+ df = self.spark.createDataFrame([({u'a': 1},)], schema=schema)
+ with QuietTest(self.sc):
+ with warnings.catch_warnings(record=True) as warns:
+ pdf = df.toPandas()
+ # Catch and check the last UserWarning.
+ user_warns = [
+ warn.message for warn in warns if isinstance(warn.message, UserWarning)]
+ self.assertTrue(len(user_warns) > 0)
+ self.assertTrue(
+ "Attempts non-optimization" in _exception_message(user_warns[-1]))
+ self.assertPandasEqual(pdf, pd.DataFrame({u'map': [{u'a': 1}]}))
+
+ def test_toPandas_fallback_disabled(self):
schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)])
df = self.spark.createDataFrame([(None,)], schema=schema)
with QuietTest(self.sc):
with self.assertRaisesRegexp(Exception, 'Unsupported type'):
df.toPandas()
- df = self.spark.createDataFrame([(None,)], schema="a binary")
- with QuietTest(self.sc):
- with self.assertRaisesRegexp(
- Exception,
- 'Unsupported type.*\nNote: toPandas attempted Arrow optimization because'):
- df.toPandas()
-
def test_null_conversion(self):
df_null = self.spark.createDataFrame([tuple([None for _ in range(len(self.data[0]))])] +
self.data)
@@ -3625,7 +3661,7 @@ def test_createDataFrame_with_incorrect_schema(self):
pdf = self.create_pandas_data_frame()
wrong_schema = StructType(list(reversed(self.schema)))
with QuietTest(self.sc):
- with self.assertRaisesRegexp(TypeError, ".*field.*can.not.accept.*type"):
+ with self.assertRaisesRegexp(RuntimeError, ".*No cast.*string.*timestamp.*"):
self.spark.createDataFrame(pdf, schema=wrong_schema)
def test_createDataFrame_with_names(self):
@@ -3650,7 +3686,7 @@ def test_createDataFrame_column_name_encoding(self):
def test_createDataFrame_with_single_data_type(self):
import pandas as pd
with QuietTest(self.sc):
- with self.assertRaisesRegexp(TypeError, ".*IntegerType.*tuple"):
+ with self.assertRaisesRegexp(RuntimeError, ".*IntegerType.*not supported.*"):
self.spark.createDataFrame(pd.DataFrame({"a": [1]}), schema="int")
def test_createDataFrame_does_not_modify_input(self):
@@ -3705,6 +3741,30 @@ def test_createDataFrame_with_int_col_names(self):
self.assertEqual(pdf_col_names, df.columns)
self.assertEqual(pdf_col_names, df_arrow.columns)
+ def test_createDataFrame_fallback_enabled(self):
+ import pandas as pd
+
+ with QuietTest(self.sc):
+ with self.sql_conf({"spark.sql.execution.arrow.fallback.enabled": True}):
+ with warnings.catch_warnings(record=True) as warns:
+ df = self.spark.createDataFrame(
+ pd.DataFrame([[{u'a': 1}]]), "a: map")
+ # Catch and check the last UserWarning.
+ user_warns = [
+ warn.message for warn in warns if isinstance(warn.message, UserWarning)]
+ self.assertTrue(len(user_warns) > 0)
+ self.assertTrue(
+ "Attempts non-optimization" in _exception_message(user_warns[-1]))
+ self.assertEqual(df.collect(), [Row(a={u'a': 1})])
+
+ def test_createDataFrame_fallback_disabled(self):
+ import pandas as pd
+
+ with QuietTest(self.sc):
+ with self.assertRaisesRegexp(Exception, 'Unsupported type'):
+ self.spark.createDataFrame(
+ pd.DataFrame([[{u'a': 1}]]), "a: map")
+
# Regression test for SPARK-23314
def test_timestamp_dst(self):
import pandas as pd
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index ce3f94618edeb..3f96112659c11 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -1058,7 +1058,7 @@ object SQLConf {
.intConf
.createWithDefault(100)
- val ARROW_EXECUTION_ENABLE =
+ val ARROW_EXECUTION_ENABLED =
buildConf("spark.sql.execution.arrow.enabled")
.doc("When true, make use of Apache Arrow for columnar data transfers. Currently available " +
"for use with pyspark.sql.DataFrame.toPandas, and " +
@@ -1068,6 +1068,13 @@ object SQLConf {
.booleanConf
.createWithDefault(false)
+ val ARROW_FALLBACK_ENABLED =
+ buildConf("spark.sql.execution.arrow.fallback.enabled")
+ .doc("When true, optimizations enabled by 'spark.sql.execution.arrow.enabled' will " +
+ "fallback automatically to non-optimized implementations if an error occurs.")
+ .booleanConf
+ .createWithDefault(true)
+
val ARROW_EXECUTION_MAX_RECORDS_PER_BATCH =
buildConf("spark.sql.execution.arrow.maxRecordsPerBatch")
.doc("When using Apache Arrow, limit the maximum number of records that can be written " +
@@ -1518,7 +1525,9 @@ class SQLConf extends Serializable with Logging {
def rangeExchangeSampleSizePerPartition: Int = getConf(RANGE_EXCHANGE_SAMPLE_SIZE_PER_PARTITION)
- def arrowEnable: Boolean = getConf(ARROW_EXECUTION_ENABLE)
+ def arrowEnabled: Boolean = getConf(ARROW_EXECUTION_ENABLED)
+
+ def arrowFallbackEnabled: Boolean = getConf(ARROW_FALLBACK_ENABLED)
def arrowMaxRecordsPerBatch: Int = getConf(ARROW_EXECUTION_MAX_RECORDS_PER_BATCH)