Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.spark.sql.execution

import java.util.Locale
import javax.ws.rs.core.UriBuilder

import scala.collection.JavaConverters._

Expand Down Expand Up @@ -753,19 +754,14 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) {
*
* Expected format:
* {{{
* INSERT OVERWRITE DIRECTORY
* INSERT OVERWRITE [LOCAL] DIRECTORY
* [path]
* [OPTIONS table_property_list]
* select_statement;
* }}}
*/
Comment thread
ajithme marked this conversation as resolved.
Outdated
override def visitInsertOverwriteDir(
ctx: InsertOverwriteDirContext): InsertDirParams = withOrigin(ctx) {
if (ctx.LOCAL != null) {
Comment thread
ajithme marked this conversation as resolved.
Outdated
throw new ParseException(
"LOCAL is not supported in INSERT OVERWRITE DIRECTORY to data source", ctx)
}

val options = Option(ctx.options).map(visitPropertyKeyValues).getOrElse(Map.empty)
var storage = DataSource.buildStorageFormatFromOptions(options)

Expand All @@ -781,6 +777,19 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) {
storage = storage.copy(locationUri = customLocation)
}

if (ctx.LOCAL() != null) {
// assert if directory is local when LOCAL keyword is mentioned
val scheme = Option(storage.locationUri.get.getScheme)
scheme match {
case None =>
// force scheme to be file rather than fs.default.name
val loc = Some(UriBuilder.fromUri(CatalogUtils.stringToURI(path)).scheme("file").build())
storage = storage.copy(locationUri = loc)
case Some(pathScheme) if (!pathScheme.equals("file")) =>
throw new ParseException("LOCAL is supported only with file: scheme", ctx)
}
}

val provider = ctx.tableProvider.multipartIdentifier.getText

(false, storage, Some(provider))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import org.apache.spark.SparkException
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType}
import org.apache.spark.sql.catalyst.parser.ParseException
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode
import org.apache.spark.sql.test.SharedSparkSession
Expand Down Expand Up @@ -820,6 +821,28 @@ class InsertSuite extends DataSourceTest with SharedSparkSession {
}
}
}

test("SPARK-29174 Support LOCAL in INSERT OVERWRITE DIRECTORY to data source") {
withTempPath { dir =>
val path = dir.toURI.getPath
sql(s"""create table tab1 ( a int) location '$path'""")
sql("insert into tab1 values(1)")
checkAnswer(sql("select * from tab1"), Seq(1).map(i => Row(i)))
sql("create table tab2 ( a int)")
sql("insert into tab2 values(2)")
checkAnswer(sql("select * from tab2"), Seq(2).map(i => Row(i)))
sql(s"""insert overwrite local directory '$path' using parquet select * from tab2""")
sql("refresh table tab1")
checkAnswer(sql("select * from tab1"), Seq(2).map(i => Row(i)))
}
}

test("SPARK-29174 fail LOCAL in INSERT OVERWRITE DIRECT remote path") {
val message = intercept[ParseException] {
sql("insert overwrite local directory 'hdfs:/abcd' using parquet select 1")
}.getMessage
assert(message.contains("LOCAL is supported only with file: scheme"))
}
}

class FileExistingTestFileSystem extends RawLocalFileSystem {
Expand Down