diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index aa139cb6b0c3b..078813b7d631d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution import java.util.Locale +import javax.ws.rs.core.UriBuilder import scala.collection.JavaConverters._ @@ -753,7 +754,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { * * Expected format: * {{{ - * INSERT OVERWRITE DIRECTORY + * INSERT OVERWRITE [LOCAL] DIRECTORY * [path] * [OPTIONS table_property_list] * select_statement; @@ -761,11 +762,6 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { */ override def visitInsertOverwriteDir( ctx: InsertOverwriteDirContext): InsertDirParams = withOrigin(ctx) { - if (ctx.LOCAL != null) { - throw new ParseException( - "LOCAL is not supported in INSERT OVERWRITE DIRECTORY to data source", ctx) - } - val options = Option(ctx.options).map(visitPropertyKeyValues).getOrElse(Map.empty) var storage = DataSource.buildStorageFormatFromOptions(options) @@ -781,6 +777,19 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { storage = storage.copy(locationUri = customLocation) } + if (ctx.LOCAL() != null) { + // assert if directory is local when LOCAL keyword is mentioned + val scheme = Option(storage.locationUri.get.getScheme) + scheme match { + case None => + // force scheme to be file rather than fs.default.name + val loc = Some(UriBuilder.fromUri(CatalogUtils.stringToURI(path)).scheme("file").build()) + storage = storage.copy(locationUri = loc) + case Some(pathScheme) if (!pathScheme.equals("file")) => + throw new ParseException("LOCAL is supported only with file: scheme", ctx) + } + } + val provider = ctx.tableProvider.multipartIdentifier.getText (false, storage, Some(provider)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index bcff30a51c3f5..0101803561c90 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -26,6 +26,7 @@ import org.apache.spark.SparkException import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType} +import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode import org.apache.spark.sql.test.SharedSparkSession @@ -820,6 +821,28 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { } } } + + test("SPARK-29174 Support LOCAL in INSERT OVERWRITE DIRECTORY to data source") { + withTempPath { dir => + val path = dir.toURI.getPath + sql(s"""create table tab1 ( a int) location '$path'""") + sql("insert into tab1 values(1)") + checkAnswer(sql("select * from tab1"), Seq(1).map(i => Row(i))) + sql("create table tab2 ( a int)") + sql("insert into tab2 values(2)") + checkAnswer(sql("select * from tab2"), Seq(2).map(i => Row(i))) + sql(s"""insert overwrite local directory '$path' using parquet select * from tab2""") + sql("refresh table tab1") + checkAnswer(sql("select * from tab1"), Seq(2).map(i => Row(i))) + } + } + + test("SPARK-29174 fail LOCAL in INSERT OVERWRITE DIRECT remote path") { + val message = intercept[ParseException] { + sql("insert overwrite local directory 'hdfs:/abcd' using parquet select 1") + }.getMessage + assert(message.contains("LOCAL is supported only with file: scheme")) + } } class FileExistingTestFileSystem extends RawLocalFileSystem {