Skip to content
Closed
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging {
catalog.ParquetConversions ::
catalog.CreateTables ::
catalog.PreInsertionCasts ::
catalog.WriteToDirs ::
ExtractPythonUDFs ::
ResolveHiveWindowFunction ::
PreInsertCastAndRename ::
Expand Down Expand Up @@ -515,7 +516,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging {
}

@transient
private val hivePlanner = new SparkPlanner with HiveStrategies {
private[hive] val hivePlanner = new SparkPlanner with HiveStrategies {
val hiveContext = self

override def strategies: Seq[Strategy] = experimental.extraStrategies ++ Seq(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -647,6 +647,32 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive
}
}

/**
* Resolves the hive.WriteToDirectory node by setting the `columns` and
* `columns.types` properties in its tableDesc.
*/
object WriteToDirs extends Rule[LogicalPlan] with HiveInspectors {

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add scala doc about what this class does?


/**
 * Rewrites every resolved [[WriteToDirectory]] node found in `plan`.
 *
 * While the node's `tableDesc` has no "columns.types" property, the "columns" and
 * "columns.types" properties are derived from the child's output attributes and stored in
 * the `tableDesc`. Once the property is present, the node is replaced by
 * `execution.WriteToDirectory` built from the executed plan of the child.
 */
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
  // Leave a node untouched until all of its children are resolved.
  case pending: LogicalPlan if !pending.childrenResolved => pending

  case WriteToDirectory(targetPath, query, local, tableDesc) =>
    val props = tableDesc.getProperties
    if (props.containsKey("columns.types")) {
      execution.WriteToDirectory(
        targetPath, hive.executePlan(query).executedPlan, local, tableDesc)
    } else {
      // Hive style: column names are comma separated, column type names colon separated.
      val attributes = query.output
      val columnNames = attributes.map(_.name).mkString(",")
      val columnTypes = attributes.map(_.dataType.toTypeInfo.getTypeName).mkString(":")
      props.setProperty("columns", columnNames)
      props.setProperty("columns.types", columnTypes)
      WriteToDirectory(targetPath, query, local, tableDesc)
    }
}
}

/**
* Casts input data to correct data types according to table definition before inserting into
* that table.
Expand Down
106 changes: 105 additions & 1 deletion sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import org.apache.hadoop.hive.ql.{ErrorMsg, Context}
import org.apache.hadoop.hive.ql.exec.{FunctionRegistry, FunctionInfo}
import org.apache.hadoop.hive.ql.lib.Node
import org.apache.hadoop.hive.ql.parse._
import org.apache.hadoop.hive.ql.plan.PlanUtils
import org.apache.hadoop.hive.ql.plan.{TableDesc, PlanUtils}
import org.apache.hadoop.hive.ql.session.SessionState

import org.apache.spark.Logging
Expand Down Expand Up @@ -77,6 +77,22 @@ private[hive] case class CreateTableAsSelect(
childrenResolved
}

/**
 * Logical plan node for the statement
 * "INSERT OVERWRITE [LOCAL] DIRECTORY directory [ROW FORMAT row_format]
 * STORED AS file_format SELECT ... FROM ...".
 *
 * @param path    target path the data is written to.
 * @param child   logical plan of the query producing the data.
 * @param isLocal whether the data is written to the local file system.
 * @param desc    description of the write properties, such as the file format.
 */
private[hive] case class WriteToDirectory(
    path: String,
    child: LogicalPlan,
    isLocal: Boolean,
    desc: TableDesc)
  extends UnaryNode with Command {

  /** This node exposes no output attributes. */
  override def output: Seq[Attribute] = Nil
}

/** Provides a mapping from HiveQL statements to catalyst logical plans and expression trees. */
private[hive] object HiveQl extends Logging {
protected val nativeCommands = Seq(
Expand Down Expand Up @@ -1210,6 +1226,19 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
Token("TOK_TMP_FILE", Nil) :: Nil) :: Nil) =>
query

case Token(destinationToken(),
Token("TOK_DIR", path :: formats) :: Nil) =>
var isLocal = false
formats.collect {
case Token("LOCAL", others) =>
isLocal = true
}
WriteToDirectory(
BaseSemanticAnalyzer.unescapeSQLString(path.getText),
query,
isLocal,
parseTableDesc(formats))

case Token(destinationToken(),
Token("TOK_TAB",
tableArgs) :: Nil) =>
Expand Down Expand Up @@ -1678,6 +1707,81 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
}
}

/**
 * Builds a [[TableDesc]] from the format-related AST nodes of a directory destination.
 *
 * Recognized nodes:
 *  - `TOK_FILEFORMAT_GENERIC` selects the output format (and, for orc / parquet / rcfile,
 *    the SerDe) from the `STORED AS` file format name;
 *  - `TOK_TABLEROWFORMAT` with a single `TOK_SERDEPROPS` child fills the SerDe properties
 *    (field / collection / map-key / line delimiters and the escape character).
 *
 * Every other node is ignored. Column names and column types are not known at parse time,
 * so they are left unset here.
 *
 * @param nodeList the AST nodes following the path inside `TOK_DIR`.
 * @return the table descriptor produced by `PlanUtils.getDefaultTableDesc`.
 * @throws SemanticException if the `STORED AS` file format is not recognized.
 * @throws AnalysisException if `LINES TERMINATED BY` is neither "\n" nor "10".
 */
def parseTableDesc(nodeList: Seq[ASTNode]): TableDesc = {
  import org.apache.hadoop.hive.ql.plan._

  val createTableDesc = new CreateTableDesc()

  // The nodes are visited only for their side effects on `createTableDesc`.
  nodeList.foreach {
    case Token("TOK_FILEFORMAT_GENERIC", child :: Nil) =>
      child.getText().toLowerCase(Locale.ENGLISH) match {
        case "orc" =>
          createTableDesc.setOutputFormat("org.apache.hadoop.hive.ql.io.orc.OrcOutputFormat")
          createTableDesc.setSerName("org.apache.hadoop.hive.ql.io.orc.OrcSerde")

        case "parquet" =>
          createTableDesc
            .setOutputFormat("org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat")
          createTableDesc
            .setSerName("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe")

        case "rcfile" =>
          createTableDesc.setOutputFormat("org.apache.hadoop.hive.ql.io.RCFileOutputFormat")
          createTableDesc.setSerName(hiveConf.getVar(HiveConf.ConfVars.HIVEDEFAULTRCFILESERDE))

        case "textfile" =>
          createTableDesc
            .setOutputFormat("org.apache.hadoop.hive.ql.io.IgnoreKeyTextOutputFormat")

        case "sequencefile" =>
          createTableDesc.setOutputFormat("org.apache.hadoop.mapred.SequenceFileOutputFormat")

        case _ =>
          throw new SemanticException(
            s"Unrecognized file format in STORED AS clause: ${child.getText}")
      }

    case Token("TOK_TABLEROWFORMAT", Token("TOK_SERDEPROPS", child :: Nil) :: Nil) =>
      val serdeParams = new java.util.HashMap[String, String]()
      child match {
        case Token("TOK_TABLEROWFORMATFIELD", delimNode :: escapeNodes) =>
          val fieldDelim = BaseSemanticAnalyzer.unescapeSQLString(delimNode.getText())
          serdeParams.put(serdeConstants.FIELD_DELIM, fieldDelim)
          serdeParams.put(serdeConstants.SERIALIZATION_FORMAT, fieldDelim)
          // `escapeNodes` is the tail after the delimiter, so the optional ESCAPED BY
          // literal is its first element. Requiring more than one element here would
          // silently drop the escape character.
          escapeNodes.headOption.foreach { escapeNode =>
            val fieldEscape = BaseSemanticAnalyzer.unescapeSQLString(escapeNode.getText)
            serdeParams.put(serdeConstants.ESCAPE_CHAR, fieldEscape)
          }

        case Token("TOK_TABLEROWFORMATCOLLITEMS", rowChild :: Nil) =>
          val collItemDelim = BaseSemanticAnalyzer.unescapeSQLString(rowChild.getText)
          serdeParams.put(serdeConstants.COLLECTION_DELIM, collItemDelim)

        case Token("TOK_TABLEROWFORMATMAPKEYS", rowChild :: Nil) =>
          val mapKeyDelim = BaseSemanticAnalyzer.unescapeSQLString(rowChild.getText)
          serdeParams.put(serdeConstants.MAPKEY_DELIM, mapKeyDelim)

        case Token("TOK_TABLEROWFORMATLINES", rowChild :: Nil) =>
          val lineDelim = BaseSemanticAnalyzer.unescapeSQLString(rowChild.getText)
          // Only a newline is accepted as the line terminator.
          if (lineDelim != "\n" && lineDelim != "10") {
            throw new AnalysisException(
              SemanticAnalyzer.generateErrorMessage(
                rowChild,
                ErrorMsg.LINES_TERMINATED_BY_NON_NEWLINE.getMsg))
          }
          serdeParams.put(serdeConstants.LINE_DELIM, lineDelim)

        case Token("TOK_TABLEROWFORMATNULL", _ :: Nil) =>
          // TODO support the nullFormat; the clause is currently accepted but ignored.

        case _ => assert(false)
      }
      createTableDesc.setSerdeProps(serdeParams)

    case _ => // Unsupported features are ignored.
  }
  // Note: we do not know the columns and column types when parsing, so here
  // just input `null` for column types. column types will be set in analyzer.
  PlanUtils.getDefaultTableDesc(createTableDesc, "", null)
}

def dumpTree(node: Node, builder: StringBuilder = new StringBuilder, indent: Int = 0)
: StringBuilder = {
node match {
Expand Down
101 changes: 101 additions & 0 deletions sql/hive/src/main/scala/org/apache/spark/sql/hive/SaveAsHiveFile.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.hive

import scala.collection.JavaConversions._

import org.apache.hadoop.hive.ql.plan.TableDesc
import org.apache.hadoop.hive.serde2.Serializer
import org.apache.hadoop.mapred.{FileOutputFormat, JobConf}
import org.apache.hadoop.hive.serde2.objectinspector.{StructObjectInspector, ObjectInspectorUtils}
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.util.SerializableJobConf
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc}
import org.apache.spark.{Logging, SparkContext, TaskContext}

/**
 * Mixin for operators that write rows to a file system through a Hive SerDe and a Hadoop
 * output format.
 */
private[hive] trait SaveAsHiveFile extends HiveInspectors with Logging {

  /**
   * Instantiates the class returned by `tableDesc.getDeserializerClass`, casts it to a
   * [[Serializer]] and initializes it with the table properties.
   */
  def newSerializer(tableDesc: TableDesc): Serializer = {
    val instance = tableDesc.getDeserializerClass.newInstance()
    val hiveSerializer = instance.asInstanceOf[Serializer]
    hiveSerializer.initialize(null, tableDesc.getProperties)
    hiveSerializer
  }

  /**
   * Configures the job held by `conf` (output value class, output format class and output
   * path), then runs a job over `rdd` that serializes every row and writes it through
   * `writerContainer`, and finally commits the job.
   *
   * @param sparkContext context used to run the write job.
   * @param rdd rows to write.
   * @param schema schema passed to `writerContainer.getLocalFileWriter` for every row.
   * @param dataTypes data type of each field, by position.
   * @param valueClass output value class; must not be null.
   * @param fileSinkConf sink description providing the table info and the target directory.
   * @param conf job configuration that is updated in place.
   * @param writerContainer container driving the driver-side and executor-side writers.
   */
  def saveAsHiveFile(
      sparkContext: SparkContext,
      rdd: RDD[InternalRow],
      schema: StructType,
      dataTypes: Array[DataType],
      valueClass: Class[_],
      fileSinkConf: FileSinkDesc,
      conf: SerializableJobConf,
      writerContainer: SparkHiveWriterContainer): Unit = {

    // Note that this function is executed on executor side
    def writeToFile(context: TaskContext, iterator: Iterator[InternalRow]): Unit = {
      val tableInfo = fileSinkConf.getTableInfo
      val serializer = newSerializer(tableInfo)
      val rawInspector = tableInfo.getDeserializer.getObjectInspector
      val standardOI = ObjectInspectorUtils
        .getStandardObjectInspector(rawInspector, ObjectInspectorCopyOption.JAVA)
        .asInstanceOf[StructObjectInspector]

      val fieldOIs = standardOI.getAllStructFieldRefs.map(_.getFieldObjectInspector).toArray
      val wrappers = fieldOIs.zip(dataTypes).map { case (inspector, dataType) =>
        wrapperFor(inspector, dataType)
      }
      val fieldCount = fieldOIs.length
      // Reused for every row; each slot is overwritten before serialization.
      val outputData = new Array[Any](fieldCount)

      writerContainer.executorSideSetup(context.stageId, context.partitionId, context.attemptNumber)

      for (row <- iterator) {
        var ordinal = 0
        while (ordinal < fieldCount) {
          outputData(ordinal) =
            if (row.isNullAt(ordinal)) {
              null
            } else {
              wrappers(ordinal)(row.get(ordinal, dataTypes(ordinal)))
            }
          ordinal += 1
        }

        val writer = writerContainer.getLocalFileWriter(row, schema)
        writer.write(serializer.serialize(outputData, standardOI))
      }

      writerContainer.close()
    }

    val jobConf = conf.value

    assert(valueClass != null, "Output value class not set")
    jobConf.setOutputValueClass(valueClass)

    val outputFileFormatClassName = fileSinkConf.getTableInfo.getOutputFileFormatClassName
    assert(outputFileFormatClassName != null, "Output format class not set")
    jobConf.set("mapred.output.format.class", outputFileFormatClassName)

    val outputPath =
      SparkHiveWriterContainer.createPathFromString(fileSinkConf.getDirName, jobConf)
    FileOutputFormat.setOutputPath(jobConf, outputPath)
    log.debug(s"Saving as hadoop file of type ${valueClass.getSimpleName}")

    writerContainer.driverSideSetup()
    sparkContext.runJob(rdd, writeToFile _)
    writerContainer.commitJob()
  }
}
Loading