diff --git a/python/pyspark/sql/datasource.py b/python/pyspark/sql/datasource.py index c30a2c8689d66..b380e8b534ebd 100644 --- a/python/pyspark/sql/datasource.py +++ b/python/pyspark/sql/datasource.py @@ -15,7 +15,7 @@ # limitations under the License. # from abc import ABC, abstractmethod -from typing import final, Any, Dict, Iterator, List, Optional, Tuple, Type, Union, TYPE_CHECKING +from typing import final, Any, Dict, Iterator, List, Tuple, Type, Union, TYPE_CHECKING from pyspark import since from pyspark.sql import Row @@ -45,21 +45,12 @@ class DataSource(ABC): """ @final - def __init__( - self, - paths: List[str], - userSpecifiedSchema: Optional[StructType], - options: Dict[str, "OptionalPrimitiveType"], - ) -> None: + def __init__(self, options: Dict[str, "OptionalPrimitiveType"]) -> None: """ - Initializes the data source with user-provided information. + Initializes the data source with user-provided options. Parameters ---------- - paths : list - A list of paths to the data source. - userSpecifiedSchema : StructType, optional - The user-specified schema of the data source. options : dict A dictionary representing the options for this data source. @@ -67,8 +58,6 @@ def __init__( ----- This method should not be overridden. """ - self.paths = paths - self.userSpecifiedSchema = userSpecifiedSchema self.options = options @classmethod diff --git a/python/pyspark/sql/tests/test_python_datasource.py b/python/pyspark/sql/tests/test_python_datasource.py index fe6a841752746..46b9fa642fd0c 100644 --- a/python/pyspark/sql/tests/test_python_datasource.py +++ b/python/pyspark/sql/tests/test_python_datasource.py @@ -30,7 +30,7 @@ class MyDataSource(DataSource): ... options = dict(a=1, b=2) - ds = MyDataSource(paths=[], userSpecifiedSchema=None, options=options) + ds = MyDataSource(options=options) self.assertEqual(ds.options, options) self.assertEqual(ds.name(), "MyDataSource") with self.assertRaises(NotImplementedError): @@ -53,8 +53,7 @@ def test_in_memory_data_source(self): class InMemDataSourceReader(DataSourceReader): DEFAULT_NUM_PARTITIONS: int = 3 - def __init__(self, paths, options): - self.paths = paths + def __init__(self, options): self.options = options def partitions(self): @@ -76,7 +75,7 @@ def schema(self): return "x INT, y STRING" def reader(self, schema) -> "DataSourceReader": - return InMemDataSourceReader(self.paths, self.options) + return InMemDataSourceReader(self.options) self.spark.dataSource.register(InMemoryDataSource) df = self.spark.read.format("memory").load() @@ -91,14 +90,13 @@ def test_custom_json_data_source(self): import json class JsonDataSourceReader(DataSourceReader): - def __init__(self, paths, options): - self.paths = paths + def __init__(self, options): self.options = options - def partitions(self): - return iter(self.paths) - - def read(self, path): + def read(self, partition): + path = self.options.get("path") + if path is None: + raise Exception("path is not specified") with open(path, "r") as file: for line in file.readlines(): if line.strip(): @@ -114,28 +112,18 @@ def schema(self): return "name STRING, age INT" def reader(self, schema) -> "DataSourceReader": - return JsonDataSourceReader(self.paths, self.options) + return JsonDataSourceReader(self.options) self.spark.dataSource.register(JsonDataSource) path1 = os.path.join(SPARK_HOME, "python/test_support/sql/people.json") path2 = os.path.join(SPARK_HOME, "python/test_support/sql/people1.json") - df1 = self.spark.read.format("my-json").load(path1) - self.assertEqual(df1.rdd.getNumPartitions(), 1) assertDataFrameEqual( - df1, + self.spark.read.format("my-json").load(path1), [Row(name="Michael", age=None), Row(name="Andy", age=30), Row(name="Justin", age=19)], ) - - df2 = self.spark.read.format("my-json").load([path1, path2]) - self.assertEqual(df2.rdd.getNumPartitions(), 2) assertDataFrameEqual( - df2, - [ - Row(name="Michael", age=None), - Row(name="Andy", age=30), - Row(name="Justin", age=19), - Row(name="Jonathan", age=None), - ], + self.spark.read.format("my-json").load(path2), + [Row(name="Jonathan", age=None)], ) diff --git a/python/pyspark/sql/worker/create_data_source.py b/python/pyspark/sql/worker/create_data_source.py index 6a9ef79b7c18d..1ba4dc9e8a3cb 100644 --- a/python/pyspark/sql/worker/create_data_source.py +++ b/python/pyspark/sql/worker/create_data_source.py @@ -17,7 +17,7 @@ import inspect import os import sys -from typing import IO, List +from typing import IO from pyspark.accumulators import _accumulatorRegistry from pyspark.errors import PySparkAssertionError, PySparkRuntimeError, PySparkTypeError @@ -55,7 +55,6 @@ def main(infile: IO, outfile: IO) -> None: The JVM sends the following information to this process: - a `DataSource` class representing the data source to be created. - a provider name in string. - - a list of paths in string. - an optional user-specified schema in json string. - a dictionary of options in string. @@ -107,12 +106,6 @@ def main(infile: IO, outfile: IO) -> None: }, ) - # Receive the paths. - num_paths = read_int(infile) - paths: List[str] = [] - for _ in range(num_paths): - paths.append(utf8_deserializer.loads(infile)) - # Receive the user-specified schema user_specified_schema = None if read_bool(infile): @@ -136,11 +129,7 @@ def main(infile: IO, outfile: IO) -> None: # Instantiate a data source. try: - data_source = data_source_cls( - paths=paths, - userSpecifiedSchema=user_specified_schema, # type: ignore - options=options, - ) + data_source = data_source_cls(options=options) except Exception as e: raise PySparkRuntimeError( error_class="PYTHON_DATA_SOURCE_CREATE_ERROR", diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index ef447e8a80102..7fadbbfac687f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -244,9 +244,9 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { private def loadUserDefinedDataSource(paths: Seq[String]): DataFrame = { val builder = sparkSession.sharedState.dataSourceManager.lookupDataSource(source) - // Unless the legacy path option behavior is enabled, the extraOptions here - // should not include "path" or "paths" as keys. - val plan = builder(sparkSession, source, paths, userSpecifiedSchema, extraOptions) + // Add `path` and `paths` options to the extra options if specified. + val optionsWithPath = DataSourceV2Utils.getOptionsWithPaths(extraOptions, paths: _*) + val plan = builder(sparkSession, source, userSpecifiedSchema, optionsWithPath) Dataset.ofRows(sparkSession, plan) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala index 72a9e6497aca5..a8c9c892b8b0a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala @@ -35,7 +35,6 @@ class DataSourceManager { private type DataSourceBuilder = ( SparkSession, // Spark session String, // provider name - Seq[String], // paths Option[StructType], // user specified schema CaseInsensitiveMap[String] // options ) => LogicalPlan diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala index c4e7bf23cace7..3dde20ac44e7d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala @@ -152,7 +152,7 @@ private[sql] object DataSourceV2Utils extends Logging { } private lazy val objectMapper = new ObjectMapper() - private def getOptionsWithPaths( + def getOptionsWithPaths( extraOptions: CaseInsensitiveMap[String], paths: String*): CaseInsensitiveMap[String] = { if (paths.isEmpty) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala index 0e7eb056f434c..7044ef65c638c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala @@ -42,12 +42,11 @@ case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) { def builder( sparkSession: SparkSession, provider: String, - paths: Seq[String], userSpecifiedSchema: Option[StructType], options: CaseInsensitiveMap[String]): LogicalPlan = { val runner = new UserDefinedPythonDataSourceRunner( - dataSourceCls, provider, paths, userSpecifiedSchema, options) + dataSourceCls, provider, userSpecifiedSchema, options) val result = runner.runInPython() val pickledDataSourceInstance = result.dataSource @@ -68,10 +67,9 @@ case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) { def apply( sparkSession: SparkSession, provider: String, - paths: Seq[String] = Seq.empty, userSpecifiedSchema: Option[StructType] = None, options: CaseInsensitiveMap[String] = CaseInsensitiveMap(Map.empty)): DataFrame = { - val plan = builder(sparkSession, provider, paths, userSpecifiedSchema, options) + val plan = builder(sparkSession, provider, userSpecifiedSchema, options) Dataset.ofRows(sparkSession, plan) } } @@ -89,7 +87,6 @@ case class PythonDataSourceCreationResult( class UserDefinedPythonDataSourceRunner( dataSourceCls: PythonFunction, provider: String, - paths: Seq[String], userSpecifiedSchema: Option[StructType], options: CaseInsensitiveMap[String]) extends PythonPlannerRunner[PythonDataSourceCreationResult](dataSourceCls) { @@ -103,10 +100,6 @@ class UserDefinedPythonDataSourceRunner( // Send the provider name PythonWorkerUtils.writeUTF(provider, dataOut) - // Send the paths - dataOut.writeInt(paths.length) - paths.foreach(PythonWorkerUtils.writeUTF(_, dataOut)) - // Send the user-specified schema, if provided dataOut.writeBoolean(userSpecifiedSchema.isDefined) userSpecifiedSchema.map(_.json).foreach(PythonWorkerUtils.writeUTF(_, dataOut)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala index 22a1e5250cd95..bd0b08cbec8b5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala @@ -160,13 +160,20 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession { val dataSourceScript = s""" |from pyspark.sql.datasource import DataSource, DataSourceReader + |import json + | |class SimpleDataSourceReader(DataSourceReader): - | def __init__(self, paths, options): - | self.paths = paths + | def __init__(self, options): | self.options = options | | def partitions(self): - | return iter(self.paths) + | if "paths" in self.options: + | paths = json.loads(self.options["paths"]) + | elif "path" in self.options: + | paths = [self.options["path"]] + | else: + | paths = [] + | return paths | | def read(self, path): | yield (path, 1) @@ -180,11 +187,10 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession { | return "id STRING, value INT" | | def reader(self, schema): - | return SimpleDataSourceReader(self.paths, self.options) + | return SimpleDataSourceReader(self.options) |""".stripMargin val dataSource = createUserDefinedPythonDataSource(dataSourceName, dataSourceScript) spark.dataSource.registerPython("test", dataSource) - checkAnswer(spark.read.format("test").load(), Seq(Row(null, 1))) checkAnswer(spark.read.format("test").load("1"), Seq(Row("1", 1))) checkAnswer(spark.read.format("test").load("1", "2"), Seq(Row("1", 1), Row("2", 1)))