diff --git a/README.md b/README.md
index 865351bd..a173cf88 100644
--- a/README.md
+++ b/README.md
@@ -71,12 +71,13 @@ val df: DataFrame = sqlContext.read
// Data Source API to write the data back to another table
df.write
- .format("com.databricks.spark.redshift")
+ .format("com.databricks.spark.redshift")
.option("url", "jdbc:redshift://redshifthost:5439/database?user=username&password=pass")
.option("dbtable", "my_table_copy")
.option("tempdir", "s3n://path/for/temp/data")
- .mode("error")
- .save()
+ .option("avrocompression", "snappy")
+ .mode("error")
+ .save()
```
#### Python
@@ -105,12 +106,13 @@ df = sql_context.read \
# Write back to a table
df.write \
- .format("com.databricks.spark.redshift") \
- .option("url", "jdbc:redshift://redshifthost:5439/database?user=username&password=pass") \
- .option("dbtable", "my_table_copy") \
- .option("tempdir", "s3n://path/for/temp/data") \
- .mode("error") \
- .save()
+ .format("com.databricks.spark.redshift") \
+ .option("url", "jdbc:redshift://redshifthost:5439/database?user=username&password=pass") \
+ .option("dbtable", "my_table_copy") \
+ .option("tempdir", "s3n://path/for/temp/data") \
+ .option("avrocompression", "snappy") \
+ .mode("error") \
+ .save()
```
#### SQL
@@ -120,6 +122,7 @@ CREATE TABLE my_table
USING com.databricks.spark.redshift
OPTIONS (dbtable 'my_table',
tempdir 's3n://my_bucket/tmp',
+ avrocompression 'snappy',
url 'jdbc:redshift://host:port/db?user=username&password=pass');
```
@@ -299,6 +302,24 @@ It may be useful to have some GRANT commands or similar run here when l
table, the changes will be reverted and the backup table restored if post actions fail.
+
+ | avrocompression |
+ No |
+ No compression (unless set in Hadoop config) |
+
+ Sets the compression codec to use on the Avro data to be loaded into Redshift. This overwrites the avro.output.codec
+key in the Hadoop configuration with the specified value and also sets mapred.output.compress = true and
+mapred.output.compression.type = BLOCK. If left unset (or set to an empty string), the Hadoop configuration is left
+unchanged. Note that the string "null" is itself a codec name (it disables compression), not an unset value.
+Valid settings are:
+
+ - "" (default): use compression settings from Hadoop config (usually none unless explicitly set).
+ - "snappy": use snappy compression.
+ - "deflate": use deflate (zlib) compression (better ratio but more CPU intensive than snappy).
+ - "null": disable compression.
+
+ |
+
## Additional configuration options
diff --git a/src/it/scala/com/databricks/spark/redshift/CompressionIntegrationSuite.scala b/src/it/scala/com/databricks/spark/redshift/CompressionIntegrationSuite.scala
new file mode 100644
index 00000000..317a73ca
--- /dev/null
+++ b/src/it/scala/com/databricks/spark/redshift/CompressionIntegrationSuite.scala
@@ -0,0 +1,89 @@
+/*
+ * Copyright 2015 TouchType Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.databricks.spark.redshift
+
+import org.apache.spark.sql.{Row, SaveMode}
+import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
+
+/**
+ * End-to-end tests for loading/unloading data from Redshift using Avro
+ * compression.
+ */
+class CompressionIntegrationSuite extends IntegrationSuiteBase {
+
+  /** Codecs accepted by the `avrocompression` option that are exercised end-to-end. */
+  private val codecsUnderTest = Seq("snappy", "deflate")
+
+  // One test per codec; the generated names equal the two hand-written names they replace.
+  codecsUnderTest.foreach { codec =>
+    test(s"roundtrip save and load with Avro $codec compression") {
+      checkCompressedRoundtrip(codec)
+    }
+  }
+
+  /**
+   * Writes a single-row DataFrame to a new Redshift table using the given Avro codec, then
+   * loads the table back and checks that the schema and the row are unchanged.
+   *
+   * The table is dropped in a `finally` block, so it is removed even if an assertion fails.
+   *
+   * @param codec value passed as the `avrocompression` option on both the write and the read
+   */
+  private def checkCompressedRoundtrip(codec: String): Unit = {
+    val tableName = s"roundtrip_save_and_load$randomSuffix"
+    val df = sqlContext.createDataFrame(sc.parallelize(Seq(Row(1))),
+      StructType(StructField("a", IntegerType) :: Nil))
+    try {
+      df.write
+        .format("com.databricks.spark.redshift")
+        .option("url", jdbcUrl)
+        .option("dbtable", tableName)
+        .option("tempdir", tempDir)
+        .option("avrocompression", codec)
+        .mode(SaveMode.ErrorIfExists)
+        .save()
+
+      assert(DefaultJDBCWrapper.tableExists(conn, tableName))
+      val loadedDf = sqlContext.read
+        .format("com.databricks.spark.redshift")
+        .option("url", jdbcUrl)
+        .option("dbtable", tableName)
+        .option("tempdir", tempDir)
+        .option("avrocompression", codec)
+        .load()
+      assert(loadedDf.schema.length === 1)
+      assert(loadedDf.columns === Seq("a"))
+      checkAnswer(loadedDf, Seq(Row(1)))
+    } finally {
+      dropTestTable(tableName)
+    }
+  }
+
+  /**
+   * Drops `tableName` if it exists and commits, closing the JDBC statement afterwards.
+   * @param tableName name of the table to remove
+   */
+  private def dropTestTable(tableName: String): Unit = {
+    val statement = conn.prepareStatement(s"drop table if exists $tableName")
+    try {
+      statement.executeUpdate()
+      conn.commit()
+    } finally {
+      statement.close()
+    }
+  }
+}
diff --git a/src/main/scala/com/databricks/spark/redshift/Parameters.scala b/src/main/scala/com/databricks/spark/redshift/Parameters.scala
index a228a901..6e633fcc 100644
--- a/src/main/scala/com/databricks/spark/redshift/Parameters.scala
+++ b/src/main/scala/com/databricks/spark/redshift/Parameters.scala
@@ -34,7 +34,8 @@ private[redshift] object Parameters {
"overwrite" -> "false",
"diststyle" -> "EVEN",
"usestagingtable" -> "true",
- "postactions" -> ";"
+ "postactions" -> ";",
+ "avrocompression" -> ""
)
/**
@@ -187,5 +188,12 @@ private[redshift] object Parameters {
sessionToken <- parameters.get("temporary_aws_session_token")
) yield new BasicSessionCredentials(accessKey, secretAccessKey, sessionToken)
}
+
+ /**
+ * Codec name from the `avrocompression` option; RedshiftWriter applies it only when non-empty.
+ *
+ * The default entry for this key is the empty string (Hadoop configuration left unchanged).
+ */
+ def avrocompression: String = parameters("avrocompression")
}
}
diff --git a/src/main/scala/com/databricks/spark/redshift/RedshiftWriter.scala b/src/main/scala/com/databricks/spark/redshift/RedshiftWriter.scala
index 7e8edca4..85e82c50 100644
--- a/src/main/scala/com/databricks/spark/redshift/RedshiftWriter.scala
+++ b/src/main/scala/com/databricks/spark/redshift/RedshiftWriter.scala
@@ -272,6 +272,14 @@ private[redshift] class RedshiftWriter(
val conn = jdbcWrapper.getConnector(params.jdbcDriver, params.jdbcUrl)
val tempDir = params.createPerQueryTempDir()
+ // NOTE(review): this looks like the SparkContext-wide Hadoop conf, so keys may outlive this write - confirm.
+ if (params.avrocompression != null && params.avrocompression.nonEmpty) {
+ val conf = sqlContext.sparkContext.hadoopConfiguration
+ conf.set("mapred.output.compress", "true")
+ conf.set("mapred.output.compression.type", "BLOCK")
+ conf.set("avro.output.codec", params.avrocompression)
+ }
+
try {
if (params.overwrite && params.useStagingTable) {
withStagingTable(conn, params.table.get, stagingTable => {
diff --git a/src/test/scala/com/databricks/spark/redshift/RedshiftSourceSuite.scala b/src/test/scala/com/databricks/spark/redshift/RedshiftSourceSuite.scala
index 67654cc7..29dd1b28 100644
--- a/src/test/scala/com/databricks/spark/redshift/RedshiftSourceSuite.scala
+++ b/src/test/scala/com/databricks/spark/redshift/RedshiftSourceSuite.scala
@@ -348,6 +348,49 @@ class RedshiftSourceSuite
checkAnswer(written, TestUtils.expectedDataWithConvertedTimesAndDates)
}
+  test("DefaultSource serializes data as Avro when avrocompression is enabled") {
+    // Options for this write: the suite defaults plus a post-action, KEY distribution on
+    // `testInt`, and the snappy Avro codec.
+    val writeOptions = defaultParams ++ Map(
+      "postactions" -> "GRANT SELECT ON %s TO jeremy",
+      "diststyle" -> "KEY",
+      "distkey" -> "testInt",
+      "avrocompression" -> "snappy")
+
+    // Patterns for the SQL statements passed to `mockJdbcWrapper` as the expected commands.
+    val expectedSql = Seq(
+      "DROP TABLE IF EXISTS test_table_staging_.*",
+      "CREATE TABLE IF NOT EXISTS test_table_staging.* DISTSTYLE KEY DISTKEY \\(testInt\\).*",
+      "COPY test_table_staging_.*",
+      "GRANT SELECT ON test_table_staging.+ TO jeremy",
+      "ALTER TABLE test_table RENAME TO test_table_backup_.*",
+      "ALTER TABLE test_table_staging_.* RENAME TO test_table",
+      "DROP TABLE IF EXISTS test_table_backup.*").map(_.r)
+
+    val wrapper = mockJdbcWrapper(writeOptions("url"), expectedSql)
+    // The target table is always reported as existing.
+    (wrapper.tableExists _).expects(*, "test_table").returning(true).anyNumberOfTimes()
+    // Any schema lookup yields a fixed placeholder string.
+    (wrapper.schemaString _).expects(*).returning("schema").anyNumberOfTimes()
+
+    val mergedParams = Parameters.mergeParameters(writeOptions)
+    val relation =
+      RedshiftRelation(wrapper, _ => mockS3Client, mergedParams, None)(testSqlContext)
+    // The second argument `true` is passed as the overwrite flag of `insert`.
+    relation.asInstanceOf[InsertableRelation].insert(expectedDataDF, true)
+
+    // The data is written to a random subdirectory of `tempdir`. Since `tempdir` is cleared
+    // between unit tests, exactly one directory is expected here.
+    // Note: this does not actually test that the written files are properly compressed.
+    val tempDirPath = new Path(s3TempDir)
+    assert(s3FileSystem.listStatus(tempDirPath).length === 1)
+    val avroDir = s3FileSystem.listStatus(tempDirPath).head.getPath.toUri.toString
+
+    // Read the output back as Avro and compare it with the expected rows.
+    val written = testSqlContext.read.format("com.databricks.spark.avro").load(avroDir)
+    checkAnswer(written, TestUtils.expectedDataWithConvertedTimesAndDates)
+  }
+
test("Cannot write table with column names that become ambiguous under case insensitivity") {
val jdbcWrapper = mock[JDBCWrapper]
val mockedConnection = mock[Connection]