diff --git a/src/it/scala/com/databricks/spark/redshift/RedshiftIntegrationSuite.scala b/src/it/scala/com/databricks/spark/redshift/RedshiftIntegrationSuite.scala
index 61dde0bf..c4429df0 100644
--- a/src/it/scala/com/databricks/spark/redshift/RedshiftIntegrationSuite.scala
+++ b/src/it/scala/com/databricks/spark/redshift/RedshiftIntegrationSuite.scala
@@ -240,6 +240,28 @@ class RedshiftIntegrationSuite
       TestUtils.expectedData)
   }
 
+  test("count() on DataFrame created from a Redshift table") {
+    checkAnswer(
+      sqlContext.sql("select count(*) from test_table"),
+      Seq(Row(TestUtils.expectedData.length))
+    )
+  }
+
+  test("count() on DataFrame created from a Redshift query") {
+    val loadedDf = sqlContext.read
+      .format("com.databricks.spark.redshift")
+      .option("url", jdbcUrl)
+      // scalastyle:off
+      .option("query", s"select * from $test_table where teststring = 'Unicode''s樂趣'")
+      // scalastyle:on
+      .option("tempdir", tempDir)
+      .load()
+    checkAnswer(
+      loadedDf.selectExpr("count(*)"),
+      Seq(Row(1))
+    )
+  }
+
   test("Can load output when 'dbtable' is a subquery wrapped in parentheses") {
     // scalastyle:off
     val query =
diff --git a/src/main/scala/com/databricks/spark/redshift/Conversions.scala b/src/main/scala/com/databricks/spark/redshift/Conversions.scala
index e912bc15..c3190cca 100644
--- a/src/main/scala/com/databricks/spark/redshift/Conversions.scala
+++ b/src/main/scala/com/databricks/spark/redshift/Conversions.scala
@@ -121,7 +121,7 @@ private[redshift] object Conversions {
    * Return a function that will convert arrays of strings conforming to
    * the given schema to Row instances
    */
-  def rowConverter(schema: StructType): (Array[String]) => Row = {
+  def createRowConverter(schema: StructType): (Array[String]) => Row = {
     convertRow(schema, _: Array[String])
   }
 }
diff --git a/src/main/scala/com/databricks/spark/redshift/RedshiftRelation.scala b/src/main/scala/com/databricks/spark/redshift/RedshiftRelation.scala
index c961c4cf..cc632d17 100644
--- a/src/main/scala/com/databricks/spark/redshift/RedshiftRelation.scala
+++ b/src/main/scala/com/databricks/spark/redshift/RedshiftRelation.scala
@@ -49,50 +49,109 @@ private[redshift] case class RedshiftRelation(
     }
   }
 
-  override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
-    // Always quote column names:
-    val columns = requiredColumns.map(col => s""""$col"""").mkString(", ")
-    val whereClause = FilterPushdown.buildWhereClause(schema, filters)
-    unloadToTemp(columns, whereClause)
-    makeRdd(pruneSchema(schema, requiredColumns))
-  }
-
   override def insert(data: DataFrame, overwrite: Boolean): Unit = {
     val updatedParams =
       Parameters.mergeParameters(params.parameters updated ("overwrite", overwrite.toString))
     new RedshiftWriter(jdbcWrapper).saveToRedshift(sqlContext, data, updatedParams)
   }
 
-  private def unloadToTemp(columnList: String = "*", whereClause: String = ""): Unit = {
-    val conn = jdbcWrapper.getConnector(params.jdbcDriver, params.jdbcUrl)
-    val unloadSql = unloadStmnt(columnList, whereClause)
-    val statement = conn.prepareStatement(unloadSql)
-
-    statement.execute()
-    conn.close()
+  /**
+   * Build the RDD backing a scan of this relation.
+   *
+   * When `requiredColumns` is empty, a `count(*)` query is run over JDBC and an RDD holding
+   * that many empty rows is returned, so no data is unloaded. Otherwise an UNLOAD statement
+   * is executed and the files under `params.tempPath` are read back and converted to rows
+   * of the pruned schema.
+   *
+   * @param requiredColumns names of the columns to return
+   * @param filters filters passed to `FilterPushdown.buildWhereClause`
+   * @throws IllegalArgumentException if `spark.sql.shuffle.partitions` is not positive
+   * @throws IllegalStateException if the count query returns no result row
+   */
+  override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
+    if (requiredColumns.isEmpty) {
+      // In the special case where no columns were requested, issue a `count(*)` against Redshift
+      // rather than unloading data.
+      val whereClause = FilterPushdown.buildWhereClause(schema, filters)
+      val tableNameOrSubquery = params.query.map(q => s"($q)").orElse(params.table).get
+      val countQuery = s"SELECT count(*) FROM $tableNameOrSubquery $whereClause"
+      logInfo(countQuery)
+      val conn = jdbcWrapper.getConnector(params.jdbcDriver, params.jdbcUrl)
+      try {
+        val results = conn.prepareStatement(countQuery).executeQuery()
+        if (results.next()) {
+          val numRows = results.getLong(1)
+          val parallelism = sqlContext.getConf("spark.sql.shuffle.partitions", "200").toInt
+          require(parallelism > 0, "spark.sql.shuffle.partitions must be positive")
+          val emptyRow = Row.empty
+          // A Scala range cannot hold more than Int.MaxValue elements, so `1L to numRows`
+          // fails for larger counts. Parallelize the partition indices instead and let each
+          // partition lazily generate its share of the empty rows.
+          val rowsPerPartition = numRows / parallelism
+          val remainder = numRows % parallelism
+          sqlContext.sparkContext
+            .parallelize(0 until parallelism, parallelism)
+            .flatMap { index =>
+              val rowsInPartition = rowsPerPartition + (if (index < remainder) 1L else 0L)
+              Iterator.iterate(0L)(_ + 1L).takeWhile(_ < rowsInPartition).map(_ => emptyRow)
+            }
+        } else {
+          throw new IllegalStateException("Could not read count from Redshift")
+        }
+      } finally {
+        conn.close()
+      }
+    } else {
+      // Unload data from Redshift into a temporary directory in S3:
+      val unloadSql = buildUnloadStmt(requiredColumns, filters)
+      logInfo(unloadSql)
+      val conn = jdbcWrapper.getConnector(params.jdbcDriver, params.jdbcUrl)
+      try {
+        conn.prepareStatement(unloadSql).execute()
+      } finally {
+        conn.close()
+      }
+      // Create a DataFrame to read the unloaded data:
+      val sc = sqlContext.sparkContext
+      val hadoopConf = new Configuration(sc.hadoopConfiguration)
+      params.setCredentials(hadoopConf)
+      val rdd = sc.newAPIHadoopFile(params.tempPath, classOf[RedshiftInputFormat],
+        classOf[java.lang.Long], classOf[Array[String]], hadoopConf)
+      val prunedSchema = pruneSchema(schema, requiredColumns)
+      rdd.values.mapPartitions { iter =>
+        val converter: Array[String] => Row = Conversions.createRowConverter(prunedSchema)
+        iter.map(converter)
+      }
+    }
   }
 
-  private def unloadStmnt(columnList: String, whereClause: String) : String = {
+  /**
+   * Build the UNLOAD statement for the requested columns and filters. The destination URL
+   * is derived from `params.tempPath`.
+   *
+   * @param requiredColumns names of the columns to select; asserted to be non-empty
+   * @param filters filters passed to `FilterPushdown.buildWhereClause`
+   * @return the UNLOAD SQL text
+   */
+  private def buildUnloadStmt(requiredColumns: Array[String], filters: Array[Filter]): String = {
+    assert(!requiredColumns.isEmpty)
+    // Always quote column names:
+    val columnList = requiredColumns.map(col => s""""$col"""").mkString(", ")
+    val whereClause = FilterPushdown.buildWhereClause(schema, filters)
     val credsString = params.credentialsString(sqlContext.sparkContext.hadoopConfiguration)
-    val tableNameOrSubquery: String = {
-      val unescaped = params.query.map(q => s"($q)").orElse(params.table).get
-      unescaped.replace("'", "\\'")
+    val query = {
+      // Since the query passed to UNLOAD will be enclosed in single quotes, we need to escape
+      // any single quotes that appear in the query itself
+      val tableNameOrSubquery: String = {
+        val unescaped = params.query.map(q => s"($q)").orElse(params.table).get
+        unescaped.replace("'", "\\'")
+      }
+      s"SELECT $columnList FROM $tableNameOrSubquery $whereClause"
     }
-    val query = s"SELECT $columnList FROM $tableNameOrSubquery $whereClause"
     val fixedUrl = Utils.fixS3Url(params.tempPath)
 
     s"UNLOAD ('$query') TO '$fixedUrl' WITH CREDENTIALS '$credsString' ESCAPE ALLOWOVERWRITE"
   }
 
-  private def makeRdd(schema: StructType): RDD[Row] = {
-    val sc = sqlContext.sparkContext
-    val hadoopConf = new Configuration(sc.hadoopConfiguration)
-    params.setCredentials(hadoopConf)
-    val rdd = sc.newAPIHadoopFile(params.tempPath, classOf[RedshiftInputFormat],
-      classOf[java.lang.Long], classOf[Array[String]], hadoopConf)
-    rdd.values.map(Conversions.rowConverter(schema))
-  }
-
   private def pruneSchema(schema: StructType, columns: Array[String]): StructType = {
     val fieldMap = Map(schema.fields.map(x => x.name -> x): _*)
     new StructType(columns.map(name => fieldMap(name)))
diff --git a/src/test/scala/com/databricks/spark/redshift/ConversionsSuite.scala b/src/test/scala/com/databricks/spark/redshift/ConversionsSuite.scala
index c78af44d..9425fb3c 100644
--- a/src/test/scala/com/databricks/spark/redshift/ConversionsSuite.scala
+++ b/src/test/scala/com/databricks/spark/redshift/ConversionsSuite.scala
@@ -29,7 +29,7 @@ import org.apache.spark.sql.types.{StructField, BooleanType, StructType}
 class ConversionsSuite extends FunSuite {
 
   test("Data should be correctly converted") {
-    val convertRow = Conversions.rowConverter(TestUtils.testSchema)
+    val convertRow = Conversions.createRowConverter(TestUtils.testSchema)
     val doubleMin = Double.MinValue.toString
     val longMax = Long.MaxValue.toString
     // scalastyle:off
@@ -53,13 +53,13 @@ class ConversionsSuite extends FunSuite {
   }
 
   test("Row conversion handles null values") {
-    val convertRow = Conversions.rowConverter(TestUtils.testSchema)
+    val convertRow = Conversions.createRowConverter(TestUtils.testSchema)
     val emptyRow = List.fill(TestUtils.testSchema.length)(null).toArray[String]
     assert(convertRow(emptyRow) === Row(emptyRow: _*))
   }
 
   test("Booleans are correctly converted") {
-    val convertRow = Conversions.rowConverter(StructType(StructField("a", BooleanType) :: Nil))
+    val convertRow = Conversions.createRowConverter(StructType(Seq(StructField("a", BooleanType))))
     assert(convertRow(Array("t")) === Row(true))
     assert(convertRow(Array("f")) === Row(false))
     assert(convertRow(Array(null)) === Row(null))