Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,28 @@ class RedshiftIntegrationSuite
TestUtils.expectedData)
}

test("count() on DataFrame created from a Redshift table") {
checkAnswer(
sqlContext.sql("select count(*) from test_table"),
Seq(Row(TestUtils.expectedData.length))
)
}

test("count() on DataFrame created from a Redshift query") {
val loadedDf = sqlContext.read
.format("com.databricks.spark.redshift")
.option("url", jdbcUrl)
// scalastyle:off
.option("query", s"select * from $test_table where teststring = 'Unicode''s樂趣'")
// scalastyle:on
.option("tempdir", tempDir)
.load()
checkAnswer(
loadedDf.selectExpr("count(*)"),
Seq(Row(1))
)
}

test("Can load output when 'dbtable' is a subquery wrapped in parentheses") {
// scalastyle:off
val query =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ private[redshift] object Conversions {
* Return a function that will convert arrays of strings conforming to
* the given schema to Row instances
*/
def rowConverter(schema: StructType): (Array[String]) => Row = {
def createRowConverter(schema: StructType): (Array[String]) => Row = {
convertRow(schema, _: Array[String])
}
}
86 changes: 57 additions & 29 deletions src/main/scala/com/databricks/spark/redshift/RedshiftRelation.scala
Original file line number Diff line number Diff line change
Expand Up @@ -49,50 +49,78 @@ private[redshift] case class RedshiftRelation(
}
}

override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
// Always quote column names:
val columns = requiredColumns.map(col => s""""$col"""").mkString(", ")
val whereClause = FilterPushdown.buildWhereClause(schema, filters)
unloadToTemp(columns, whereClause)
makeRdd(pruneSchema(schema, requiredColumns))
}

override def insert(data: DataFrame, overwrite: Boolean): Unit = {
val updatedParams =
Parameters.mergeParameters(params.parameters updated ("overwrite", overwrite.toString))
new RedshiftWriter(jdbcWrapper).saveToRedshift(sqlContext, data, updatedParams)
}

private def unloadToTemp(columnList: String = "*", whereClause: String = ""): Unit = {
val conn = jdbcWrapper.getConnector(params.jdbcDriver, params.jdbcUrl)
val unloadSql = unloadStmnt(columnList, whereClause)
val statement = conn.prepareStatement(unloadSql)

statement.execute()
conn.close()
override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
if (requiredColumns.isEmpty) {
// In the special case where no columns were requested, issue a `count(*)` against Redshift
// rather than unloading data.
val whereClause = FilterPushdown.buildWhereClause(schema, filters)
val tableNameOrSubquery = params.query.map(q => s"($q)").orElse(params.table).get
val countQuery = s"SELECT count(*) FROM $tableNameOrSubquery $whereClause"
logInfo(countQuery)
val conn = jdbcWrapper.getConnector(params.jdbcDriver, params.jdbcUrl)
try {
val results = conn.prepareStatement(countQuery).executeQuery()
if (results.next()) {
val numRows = results.getLong(1)
val parallelism = sqlContext.getConf("spark.sql.shuffle.partitions", "200").toInt

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Whoops, this is unused. I'll update to pass it into parallelize.

val emptyRow = Row.empty
sqlContext.sparkContext.parallelize(1L to numRows, parallelism).map(_ => emptyRow)
} else {
throw new IllegalStateException("Could not read count from Redshift")
}
} finally {
conn.close()
}
} else {
// Unload data from Redshift into a temporary directory in S3:
val unloadSql = buildUnloadStmt(requiredColumns, filters)
logInfo(unloadSql)
val conn = jdbcWrapper.getConnector(params.jdbcDriver, params.jdbcUrl)
try {
conn.prepareStatement(unloadSql).execute()
} finally {

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here, I just inlined the bodies of these functions after some refactoring. I think the new code is slightly easier to follow.

conn.close()
}
// Create a DataFrame to read the unloaded data:
val sc = sqlContext.sparkContext
val hadoopConf = new Configuration(sc.hadoopConfiguration)
params.setCredentials(hadoopConf)
val rdd = sc.newAPIHadoopFile(params.tempPath, classOf[RedshiftInputFormat],
classOf[java.lang.Long], classOf[Array[String]], hadoopConf)
val prunedSchema = pruneSchema(schema, requiredColumns)
rdd.values.mapPartitions { iter =>
val converter: Array[String] => Row = Conversions.createRowConverter(prunedSchema)

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here, I took the opportunity to do a bit of minor performance optimization by re-using the row converter rather than creating a new one for each row. There are many other inefficiencies / opportunities for optimization in the converter itself, but I'll address those in a separate patch.

iter.map(converter)
}
}
}

private def unloadStmnt(columnList: String, whereClause: String) : String = {
private def buildUnloadStmt(requiredColumns: Array[String], filters: Array[Filter]): String = {
assert(!requiredColumns.isEmpty)
// Always quote column names:
val columnList = requiredColumns.map(col => s""""$col"""").mkString(", ")
val whereClause = FilterPushdown.buildWhereClause(schema, filters)
val credsString = params.credentialsString(sqlContext.sparkContext.hadoopConfiguration)
val tableNameOrSubquery: String = {
val unescaped = params.query.map(q => s"($q)").orElse(params.table).get
unescaped.replace("'", "\\'")
val query = {
// Since the query passed to UNLOAD will be enclosed in single quotes, we need to escape
// any single quotes that appear in the query itself
val tableNameOrSubquery: String = {
val unescaped = params.query.map(q => s"($q)").orElse(params.table).get
unescaped.replace("'", "\\'")
}
s"SELECT $columnList FROM $tableNameOrSubquery $whereClause"
}
val query = s"SELECT $columnList FROM $tableNameOrSubquery $whereClause"
val fixedUrl = Utils.fixS3Url(params.tempPath)

s"UNLOAD ('$query') TO '$fixedUrl' WITH CREDENTIALS '$credsString' ESCAPE ALLOWOVERWRITE"
}

private def makeRdd(schema: StructType): RDD[Row] = {
val sc = sqlContext.sparkContext
val hadoopConf = new Configuration(sc.hadoopConfiguration)
params.setCredentials(hadoopConf)
val rdd = sc.newAPIHadoopFile(params.tempPath, classOf[RedshiftInputFormat],
classOf[java.lang.Long], classOf[Array[String]], hadoopConf)
rdd.values.map(Conversions.rowConverter(schema))
}

private def pruneSchema(schema: StructType, columns: Array[String]): StructType = {
val fieldMap = Map(schema.fields.map(x => x.name -> x): _*)
new StructType(columns.map(name => fieldMap(name)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import org.apache.spark.sql.types.{StructField, BooleanType, StructType}
class ConversionsSuite extends FunSuite {

test("Data should be correctly converted") {
val convertRow = Conversions.rowConverter(TestUtils.testSchema)
val convertRow = Conversions.createRowConverter(TestUtils.testSchema)
val doubleMin = Double.MinValue.toString
val longMax = Long.MaxValue.toString
// scalastyle:off
Expand All @@ -53,13 +53,13 @@ class ConversionsSuite extends FunSuite {
}

test("Row conversion handles null values") {
val convertRow = Conversions.rowConverter(TestUtils.testSchema)
val convertRow = Conversions.createRowConverter(TestUtils.testSchema)
val emptyRow = List.fill(TestUtils.testSchema.length)(null).toArray[String]
assert(convertRow(emptyRow) === Row(emptyRow: _*))
}

test("Booleans are correctly converted") {
val convertRow = Conversions.rowConverter(StructType(StructField("a", BooleanType) :: Nil))
val convertRow = Conversions.createRowConverter(StructType(Seq(StructField("a", BooleanType))))
assert(convertRow(Array("t")) === Row(true))
assert(convertRow(Array("f")) === Row(false))
assert(convertRow(Array(null)) === Row(null))
Expand Down