From 9e65181090d3e0ed530fdb126e26e021111c99ae Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 26 Aug 2015 19:10:37 -0700 Subject: [PATCH 1/6] Add regression test for count() on Redshift table. --- .../databricks/spark/redshift/RedshiftIntegrationSuite.scala | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/it/scala/com/databricks/spark/redshift/RedshiftIntegrationSuite.scala b/src/it/scala/com/databricks/spark/redshift/RedshiftIntegrationSuite.scala index 61dde0bf..26543942 100644 --- a/src/it/scala/com/databricks/spark/redshift/RedshiftIntegrationSuite.scala +++ b/src/it/scala/com/databricks/spark/redshift/RedshiftIntegrationSuite.scala @@ -240,6 +240,10 @@ class RedshiftIntegrationSuite TestUtils.expectedData) } + test("count() on DataFrame created from a Redshift table") { + assert(sqlContext.sql("select * from test_table").count() === TestUtils.expectedData.length) + } + test("Can load output when 'dbtable' is a subquery wrapped in parentheses") { // scalastyle:off val query = From 217589ab0b59c40f5aa56300e9604ac72f358be7 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 26 Aug 2015 19:29:53 -0700 Subject: [PATCH 2/6] Fix for count bug --- .../spark/redshift/RedshiftRelation.scala | 80 ++++++++++++------- 1 file changed, 52 insertions(+), 28 deletions(-) diff --git a/src/main/scala/com/databricks/spark/redshift/RedshiftRelation.scala b/src/main/scala/com/databricks/spark/redshift/RedshiftRelation.scala index c961c4cf..3b26626f 100644 --- a/src/main/scala/com/databricks/spark/redshift/RedshiftRelation.scala +++ b/src/main/scala/com/databricks/spark/redshift/RedshiftRelation.scala @@ -49,50 +49,74 @@ private[redshift] case class RedshiftRelation( } } - override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = { - // Always quote column names: - val columns = requiredColumns.map(col => s""""$col"""").mkString(", ") - val whereClause = FilterPushdown.buildWhereClause(schema, filters) - unloadToTemp(columns, whereClause) - makeRdd(pruneSchema(schema, requiredColumns)) - } - override def insert(data: DataFrame, overwrite: Boolean): Unit = { val updatedParams = Parameters.mergeParameters(params.parameters updated ("overwrite", overwrite.toString)) new RedshiftWriter(jdbcWrapper).saveToRedshift(sqlContext, data, updatedParams) } - private def unloadToTemp(columnList: String = "*", whereClause: String = ""): Unit = { - val conn = jdbcWrapper.getConnector(params.jdbcDriver, params.jdbcUrl) - val unloadSql = unloadStmnt(columnList, whereClause) - val statement = conn.prepareStatement(unloadSql) + private val tableNameOrSubquery: String = { + val unescaped = params.query.map(q => s"($q)").orElse(params.table).get + unescaped.replace("'", "\\'") + } - statement.execute() - conn.close() + override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = { + if (requiredColumns.isEmpty) { + // In the special case where no columns were requested, issue a `count(*)` against Redshift + // rather than unloading data. + val whereClause = FilterPushdown.buildWhereClause(schema, filters) + val countQuery = s"SELECT count(*) FROM $tableNameOrSubquery $whereClause" + logInfo(countQuery) + val conn = jdbcWrapper.getConnector(params.jdbcDriver, params.jdbcUrl) + try { + val results = conn.prepareStatement(countQuery).executeQuery() + if (results.next()) { + val numRows = results.getLong(1) + val parallelism = sqlContext.getConf("spark.sql.shuffle.partitions", "200").toInt + val emptyRow = Row.empty + sqlContext.sparkContext.parallelize(1L to numRows).map(_ => emptyRow) + } else { + throw new IllegalStateException("Could not read count from Redshift") + } + } finally { + conn.close() + } + } else { + // Unload data from Redshift into a temporary directory in S3: + val unloadSql = buildUnloadStmt(requiredColumns, filters) + logInfo(unloadSql) + val conn = jdbcWrapper.getConnector(params.jdbcDriver, params.jdbcUrl) + try { + conn.prepareStatement(unloadSql).execute() + } finally { + conn.close() + } + // Create a DataFrame to read the unloaded data: + val sc = sqlContext.sparkContext + val hadoopConf = new Configuration(sc.hadoopConfiguration) + params.setCredentials(hadoopConf) + val rdd = sc.newAPIHadoopFile(params.tempPath, classOf[RedshiftInputFormat], + classOf[java.lang.Long], classOf[Array[String]], hadoopConf) + val prunedSchema = pruneSchema(schema, requiredColumns) + rdd.values.mapPartitions { iter => + val converter: Array[String] => Row = Conversions.rowConverter(prunedSchema) + iter.map(converter) + } + } } - private def unloadStmnt(columnList: String, whereClause: String) : String = { + private def buildUnloadStmt(requiredColumns: Array[String], filters: Array[Filter]): String = { + assert(!requiredColumns.isEmpty) + // Always quote column names: + val columnList = requiredColumns.map(col => s""""$col"""").mkString(", ") + val whereClause = FilterPushdown.buildWhereClause(schema, filters) val credsString = params.credentialsString(sqlContext.sparkContext.hadoopConfiguration) - val tableNameOrSubquery: String = { - val unescaped = params.query.map(q => s"($q)").orElse(params.table).get - unescaped.replace("'", "\\'") - } val query = s"SELECT $columnList FROM $tableNameOrSubquery $whereClause" val fixedUrl = Utils.fixS3Url(params.tempPath) s"UNLOAD ('$query') TO '$fixedUrl' WITH CREDENTIALS '$credsString' ESCAPE ALLOWOVERWRITE" } - private def makeRdd(schema: StructType): RDD[Row] = { - val sc = sqlContext.sparkContext - val hadoopConf = new Configuration(sc.hadoopConfiguration) - params.setCredentials(hadoopConf) - val rdd = sc.newAPIHadoopFile(params.tempPath, classOf[RedshiftInputFormat], - classOf[java.lang.Long], classOf[Array[String]], hadoopConf) - rdd.values.map(Conversions.rowConverter(schema)) - } - private def pruneSchema(schema: StructType, columns: Array[String]): StructType = { val fieldMap = Map(schema.fields.map(x => x.name -> x): _*) new StructType(columns.map(name => fieldMap(name))) From 07d00ca2b4614efbd523406a3442f0f681db3fad Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 26 Aug 2015 19:57:07 -0700 Subject: [PATCH 3/6] Rename rowConverter to createRowConverter --- .../scala/com/databricks/spark/redshift/Conversions.scala | 2 +- .../com/databricks/spark/redshift/RedshiftRelation.scala | 2 +- .../com/databricks/spark/redshift/ConversionsSuite.scala | 6 +++--- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/main/scala/com/databricks/spark/redshift/Conversions.scala b/src/main/scala/com/databricks/spark/redshift/Conversions.scala index e912bc15..c3190cca 100644 --- a/src/main/scala/com/databricks/spark/redshift/Conversions.scala +++ b/src/main/scala/com/databricks/spark/redshift/Conversions.scala @@ -121,7 +121,7 @@ private[redshift] object Conversions { * Return a function that will convert arrays of strings conforming to * the given schema to Row instances */ - def rowConverter(schema: StructType): (Array[String]) => Row = { + def createRowConverter(schema: StructType): (Array[String]) => Row = { convertRow(schema, _: Array[String]) } } diff --git a/src/main/scala/com/databricks/spark/redshift/RedshiftRelation.scala b/src/main/scala/com/databricks/spark/redshift/RedshiftRelation.scala index 3b26626f..7590596e 100644 --- a/src/main/scala/com/databricks/spark/redshift/RedshiftRelation.scala +++ b/src/main/scala/com/databricks/spark/redshift/RedshiftRelation.scala @@ -99,7 +99,7 @@ private[redshift] case class RedshiftRelation( classOf[java.lang.Long], classOf[Array[String]], hadoopConf) val prunedSchema = pruneSchema(schema, requiredColumns) rdd.values.mapPartitions { iter => - val converter: Array[String] => Row = Conversions.rowConverter(prunedSchema) + val converter: Array[String] => Row = Conversions.createRowConverter(prunedSchema) iter.map(converter) } } diff --git a/src/test/scala/com/databricks/spark/redshift/ConversionsSuite.scala b/src/test/scala/com/databricks/spark/redshift/ConversionsSuite.scala index c78af44d..9425fb3c 100644 --- a/src/test/scala/com/databricks/spark/redshift/ConversionsSuite.scala +++ b/src/test/scala/com/databricks/spark/redshift/ConversionsSuite.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.types.{StructField, BooleanType, StructType} class ConversionsSuite extends FunSuite { test("Data should be correctly converted") { - val convertRow = Conversions.rowConverter(TestUtils.testSchema) + val convertRow = Conversions.createRowConverter(TestUtils.testSchema) val doubleMin = Double.MinValue.toString val longMax = Long.MaxValue.toString // scalastyle:off @@ -53,13 +53,13 @@ class ConversionsSuite extends FunSuite { } test("Row conversion handles null values") { - val convertRow = Conversions.rowConverter(TestUtils.testSchema) + val convertRow = Conversions.createRowConverter(TestUtils.testSchema) val emptyRow = List.fill(TestUtils.testSchema.length)(null).toArray[String] assert(convertRow(emptyRow) === Row(emptyRow: _*)) } test("Booleans are correctly converted") { - val convertRow = Conversions.rowConverter(StructType(StructField("a", BooleanType) :: Nil)) + val convertRow = Conversions.createRowConverter(StructType(Seq(StructField("a", BooleanType)))) assert(convertRow(Array("t")) === Row(true)) assert(convertRow(Array("f")) === Row(false)) assert(convertRow(Array(null)) === Row(null)) From 77cf4cc006a25cc499570e0190c3327296bcece7 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 27 Aug 2015 10:18:04 -0700 Subject: [PATCH 4/6] Actually pass parallelism properly --- .../scala/com/databricks/spark/redshift/RedshiftRelation.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/scala/com/databricks/spark/redshift/RedshiftRelation.scala b/src/main/scala/com/databricks/spark/redshift/RedshiftRelation.scala index 7590596e..47555f4a 100644 --- a/src/main/scala/com/databricks/spark/redshift/RedshiftRelation.scala +++ b/src/main/scala/com/databricks/spark/redshift/RedshiftRelation.scala @@ -74,7 +74,7 @@ private[redshift] case class RedshiftRelation( val numRows = results.getLong(1) val parallelism = sqlContext.getConf("spark.sql.shuffle.partitions", "200").toInt val emptyRow = Row.empty - sqlContext.sparkContext.parallelize(1L to numRows).map(_ => emptyRow) + sqlContext.sparkContext.parallelize(1L to numRows, parallelism).map(_ => emptyRow) } else { throw new IllegalStateException("Could not read count from Redshift") } From dc8806e6e783f91f0f63f62e00bccb2fc23eb8d5 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 27 Aug 2015 10:18:48 -0700 Subject: [PATCH 5/6] Use checkAnswer instead of assert; add failing test to expose subquery bug. --- .../redshift/RedshiftIntegrationSuite.scala | 20 ++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/src/it/scala/com/databricks/spark/redshift/RedshiftIntegrationSuite.scala b/src/it/scala/com/databricks/spark/redshift/RedshiftIntegrationSuite.scala index 26543942..c4429df0 100644 --- a/src/it/scala/com/databricks/spark/redshift/RedshiftIntegrationSuite.scala +++ b/src/it/scala/com/databricks/spark/redshift/RedshiftIntegrationSuite.scala @@ -241,7 +241,25 @@ class RedshiftIntegrationSuite } test("count() on DataFrame created from a Redshift table") { - assert(sqlContext.sql("select * from test_table").count() === TestUtils.expectedData.length) + checkAnswer( + sqlContext.sql("select count(*) from test_table"), + Seq(Row(TestUtils.expectedData.length)) + ) + } + + test("count() on DataFrame created from a Redshift query") { + val loadedDf = sqlContext.read + .format("com.databricks.spark.redshift") + .option("url", jdbcUrl) + // scalastyle:off + .option("query", s"select * from $test_table where teststring = 'Unicode''s樂趣'") + // scalastyle:on + .option("tempdir", tempDir) + .load() + checkAnswer( + loadedDf.selectExpr("count(*)"), + Seq(Row(1)) + ) } test("Can load output when 'dbtable' is a subquery wrapped in parentheses") { From 180ee086cb25ae604e7a4d621e0681dc6940d440 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 27 Aug 2015 11:07:19 -0700 Subject: [PATCH 6/6] Fix escaping; add comment --- .../spark/redshift/RedshiftRelation.scala | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/src/main/scala/com/databricks/spark/redshift/RedshiftRelation.scala b/src/main/scala/com/databricks/spark/redshift/RedshiftRelation.scala index 47555f4a..cc632d17 100644 --- a/src/main/scala/com/databricks/spark/redshift/RedshiftRelation.scala +++ b/src/main/scala/com/databricks/spark/redshift/RedshiftRelation.scala @@ -55,16 +55,12 @@ private[redshift] case class RedshiftRelation( new RedshiftWriter(jdbcWrapper).saveToRedshift(sqlContext, data, updatedParams) } - private val tableNameOrSubquery: String = { - val unescaped = params.query.map(q => s"($q)").orElse(params.table).get - unescaped.replace("'", "\\'") - } - override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = { if (requiredColumns.isEmpty) { // In the special case where no columns were requested, issue a `count(*)` against Redshift // rather than unloading data. val whereClause = FilterPushdown.buildWhereClause(schema, filters) + val tableNameOrSubquery = params.query.map(q => s"($q)").orElse(params.table).get val countQuery = s"SELECT count(*) FROM $tableNameOrSubquery $whereClause" logInfo(countQuery) val conn = jdbcWrapper.getConnector(params.jdbcDriver, params.jdbcUrl) @@ -111,7 +107,15 @@ private[redshift] case class RedshiftRelation( val columnList = requiredColumns.map(col => s""""$col"""").mkString(", ") val whereClause = FilterPushdown.buildWhereClause(schema, filters) val credsString = params.credentialsString(sqlContext.sparkContext.hadoopConfiguration) - val query = s"SELECT $columnList FROM $tableNameOrSubquery $whereClause" + val query = { + // Since the query passed to UNLOAD will be enclosed in single quotes, we need to escape + // any single quotes that appear in the query itself + val tableNameOrSubquery: String = { + val unescaped = params.query.map(q => s"($q)").orElse(params.table).get + unescaped.replace("'", "\\'") + } + s"SELECT $columnList FROM $tableNameOrSubquery $whereClause" + } val fixedUrl = Utils.fixS3Url(params.tempPath) s"UNLOAD ('$query') TO '$fixedUrl' WITH CREDENTIALS '$credsString' ESCAPE ALLOWOVERWRITE"