diff --git a/README.md b/README.md index 865351bd..a173cf88 100644 --- a/README.md +++ b/README.md @@ -71,12 +71,13 @@ val df: DataFrame = sqlContext.read // Data Source API to write the data back to another table df.write - .format("com.databricks.spark.redshift") + .format("com.databricks.spark.redshift") .option("url", "jdbc:redshift://redshifthost:5439/database?user=username&password=pass") .option("dbtable", "my_table_copy") .option("tempdir", "s3n://path/for/temp/data") - .mode("error") - .save() + .option("avrocompression", "snappy") + .mode("error") + .save() ``` #### Python @@ -105,12 +106,13 @@ df = sql_context.read \ # Write back to a table df.write \ - .format("com.databricks.spark.redshift") \ - .option("url", "jdbc:redshift://redshifthost:5439/database?user=username&password=pass") \ - .option("dbtable", "my_table_copy") \ - .option("tempdir", "s3n://path/for/temp/data") \ - .mode("error") \ - .save() + .format("com.databricks.spark.redshift") \ + .option("url", "jdbc:redshift://redshifthost:5439/database?user=username&password=pass") \ + .option("dbtable", "my_table_copy") \ + .option("tempdir", "s3n://path/for/temp/data") \ + .option("avrocompression", "snappy") \ + .mode("error") \ + .save() ``` #### SQL @@ -120,6 +122,7 @@ CREATE TABLE my_table USING com.databricks.spark.redshift OPTIONS (dbtable 'my_table', tempdir 's3n://my_bucket/tmp', + avrocompression 'snappy', url 'jdbc:redshift://host:port/db?user=username&password=pass'); ``` @@ -299,6 +302,24 @@ It may be useful to have some GRANT commands or similar run here when l table, the changes will be reverted and the backup table restored if post actions fail.

+ + avrocompression + No + No compression (unless set in Hadoop config) + +

Sets the compression codec to use for the Avro data that is loaded into Redshift. This overrides the avro.output.codec +key in the Hadoop configuration with the specified value and also sets mapred.output.compress = true and +mapred.output.compression.type = BLOCK. If left unset (or set to null or an empty string), +the Hadoop configuration is left unchanged. Because these values are written to the SparkContext's shared Hadoop configuration, they stay in effect for later writes that use the same SparkContext.

+

Valid settings are the codec names accepted by Avro's avro.output.codec setting, for example snappy and deflate.

+ + + ## Additional configuration options diff --git a/src/it/scala/com/databricks/spark/redshift/CompressionIntegrationSuite.scala b/src/it/scala/com/databricks/spark/redshift/CompressionIntegrationSuite.scala new file mode 100644 index 00000000..317a73ca --- /dev/null +++ b/src/it/scala/com/databricks/spark/redshift/CompressionIntegrationSuite.scala @@ -0,0 +1,89 @@ +/* + * Copyright 2015 TouchType Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.databricks.spark.redshift + +import org.apache.spark.sql.{Row, SaveMode} +import org.apache.spark.sql.types.{IntegerType, StructField, StructType} + +/** + * End-to-end tests for loading/unloading data from Redshift using Avro + * compression. 
+ */ +class CompressionIntegrationSuite extends IntegrationSuiteBase { + + test("roundtrip save and load with Avro snappy compression") { + val tableName = s"roundtrip_save_and_load$randomSuffix" + val df = sqlContext.createDataFrame(sc.parallelize(Seq(Row(1))), + StructType(StructField("a", IntegerType) :: Nil)) + try { + df.write + .format("com.databricks.spark.redshift") + .option("url", jdbcUrl) + .option("dbtable", tableName) + .option("tempdir", tempDir) + .option("avrocompression", "snappy") + .mode(SaveMode.ErrorIfExists) + .save() + + assert(DefaultJDBCWrapper.tableExists(conn, tableName)) + val loadedDf = sqlContext.read + .format("com.databricks.spark.redshift") + .option("url", jdbcUrl) + .option("dbtable", tableName) + .option("tempdir", tempDir) + .option("avrocompression", "snappy") + .load() + assert(loadedDf.schema.length === 1) + assert(loadedDf.columns === Seq("a")) + checkAnswer(loadedDf, Seq(Row(1))) + } finally { + conn.prepareStatement(s"drop table if exists $tableName").executeUpdate() + conn.commit() + } + } + + test("roundtrip save and load with Avro deflate compression") { + val tableName = s"roundtrip_save_and_load$randomSuffix" + val df = sqlContext.createDataFrame(sc.parallelize(Seq(Row(1))), + StructType(StructField("a", IntegerType) :: Nil)) + try { + df.write + .format("com.databricks.spark.redshift") + .option("url", jdbcUrl) + .option("dbtable", tableName) + .option("tempdir", tempDir) + .option("avrocompression", "deflate") + .mode(SaveMode.ErrorIfExists) + .save() + + assert(DefaultJDBCWrapper.tableExists(conn, tableName)) + val loadedDf = sqlContext.read + .format("com.databricks.spark.redshift") + .option("url", jdbcUrl) + .option("dbtable", tableName) + .option("tempdir", tempDir) + .option("avrocompression", "deflate") + .load() + assert(loadedDf.schema.length === 1) + assert(loadedDf.columns === Seq("a")) + checkAnswer(loadedDf, Seq(Row(1))) + } finally { + conn.prepareStatement(s"drop table if exists 
$tableName").executeUpdate() + conn.commit() + } + } +} diff --git a/src/main/scala/com/databricks/spark/redshift/Parameters.scala b/src/main/scala/com/databricks/spark/redshift/Parameters.scala index a228a901..6e633fcc 100644 --- a/src/main/scala/com/databricks/spark/redshift/Parameters.scala +++ b/src/main/scala/com/databricks/spark/redshift/Parameters.scala @@ -34,7 +34,8 @@ private[redshift] object Parameters { "overwrite" -> "false", "diststyle" -> "EVEN", "usestagingtable" -> "true", - "postactions" -> ";" + "postactions" -> ";", + "avrocompression" -> "" ) /** @@ -187,5 +188,12 @@ private[redshift] object Parameters { sessionToken <- parameters.get("temporary_aws_session_token") ) yield new BasicSessionCredentials(accessKey, secretAccessKey, sessionToken) } + + /** + * When nonempty/non-null sets the compression codec to use for writing Avro data. + * + * Defaults to disabled (i.e. whatever is set in Hadoop config). + */ + def avrocompression: String = parameters("avrocompression") } } diff --git a/src/main/scala/com/databricks/spark/redshift/RedshiftWriter.scala b/src/main/scala/com/databricks/spark/redshift/RedshiftWriter.scala index 7e8edca4..85e82c50 100644 --- a/src/main/scala/com/databricks/spark/redshift/RedshiftWriter.scala +++ b/src/main/scala/com/databricks/spark/redshift/RedshiftWriter.scala @@ -272,6 +272,14 @@ private[redshift] class RedshiftWriter( val conn = jdbcWrapper.getConnector(params.jdbcDriver, params.jdbcUrl) val tempDir = params.createPerQueryTempDir() + + if (params.avrocompression != null && params.avrocompression.nonEmpty) { + val conf = sqlContext.sparkContext.hadoopConfiguration + conf.set("mapred.output.compress", "true") + conf.set("mapred.output.compression.type", "BLOCK") + conf.set("avro.output.codec", params.avrocompression) + } + try { if (params.overwrite && params.useStagingTable) { withStagingTable(conn, params.table.get, stagingTable => { diff --git 
a/src/test/scala/com/databricks/spark/redshift/RedshiftSourceSuite.scala b/src/test/scala/com/databricks/spark/redshift/RedshiftSourceSuite.scala index 67654cc7..29dd1b28 100644 --- a/src/test/scala/com/databricks/spark/redshift/RedshiftSourceSuite.scala +++ b/src/test/scala/com/databricks/spark/redshift/RedshiftSourceSuite.scala @@ -348,6 +348,49 @@ class RedshiftSourceSuite checkAnswer(written, TestUtils.expectedDataWithConvertedTimesAndDates) } + test("DefaultSource serializes data as Avro when avrocompression is enabled") { + + val params = defaultParams ++ Map( + "postactions" -> "GRANT SELECT ON %s TO jeremy", + "diststyle" -> "KEY", + "distkey" -> "testInt", + "avrocompression" -> "snappy") + + val expectedCommands = + Seq("DROP TABLE IF EXISTS test_table_staging_.*".r, + "CREATE TABLE IF NOT EXISTS test_table_staging.* DISTSTYLE KEY DISTKEY \\(testInt\\).*".r, + "COPY test_table_staging_.*".r, + "GRANT SELECT ON test_table_staging.+ TO jeremy".r, + "ALTER TABLE test_table RENAME TO test_table_backup_.*".r, + "ALTER TABLE test_table_staging_.* RENAME TO test_table".r, + "DROP TABLE IF EXISTS test_table_backup.*".r) + + val jdbcWrapper = mockJdbcWrapper(params("url"), expectedCommands) + + (jdbcWrapper.tableExists _) + .expects(*, "test_table") + .returning(true) + .anyNumberOfTimes() + + (jdbcWrapper.schemaString _) + .expects(*) + .returning("schema") + .anyNumberOfTimes() + + val relation = RedshiftRelation( + jdbcWrapper, _ => mockS3Client, Parameters.mergeParameters(params), None)(testSqlContext) + relation.asInstanceOf[InsertableRelation].insert(expectedDataDF, true) + + // Make sure we wrote the data out ready for Redshift load, in the expected formats + // The data should have been written to a random subdirectory of `tempdir`. Since we clear + // `tempdir` between every unit test, there should only be one directory here. + // Note: this does not actually test that the written files are properly compressed. 
+ assert(s3FileSystem.listStatus(new Path(s3TempDir)).length === 1) + val dirWithAvroFiles = s3FileSystem.listStatus(new Path(s3TempDir)).head.getPath.toUri.toString + val written = testSqlContext.read.format("com.databricks.spark.avro").load(dirWithAvroFiles) + checkAnswer(written, TestUtils.expectedDataWithConvertedTimesAndDates) + } + test("Cannot write table with column names that become ambiguous under case insensitivity") { val jdbcWrapper = mock[JDBCWrapper] val mockedConnection = mock[Connection]