Closed
1 change: 1 addition & 0 deletions R/pkg/NAMESPACE
@@ -45,6 +45,7 @@ exportMethods("arrange",
"covar_samp",
"covar_pop",
"crosstab",
"dapply",
"describe",
"dim",
"distinct",
61 changes: 61 additions & 0 deletions R/pkg/R/DataFrame.R
@@ -21,6 +21,7 @@
NULL

setOldClass("jobj")
setOldClass("structType")

#' @title S4 class that represents a SparkDataFrame
#' @description DataFrames can be created using functions like \link{createDataFrame},
@@ -1125,6 +1126,66 @@ setMethod("summarize",
agg(x, ...)
})

#' dapply
#'
#' Apply a function to each partition of a DataFrame.
#'
#' @param x A SparkDataFrame
#' @param func A function to be applied to each partition of the SparkDataFrame.
#' func should have only one parameter, to which a data.frame corresponding
#' to each partition will be passed.
#' The output of func should be a data.frame.
#' @param schema The schema of the resulting DataFrame after the function is applied.
#' It must match the output of func.
#' @family SparkDataFrame functions
#' @rdname dapply
#' @name dapply
#' @export

Member

Please add a doc example.

Contributor Author

done

#' @examples
#' \dontrun{
#' df <- createDataFrame (sqlContext, iris)
#' df1 <- dapply(df, function(x) { x }, schema(df))

Member

Could we have a more elaborate example to explain how func should expect or handle "each partition of the DataFrame"?

Contributor Author

added
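
For readers wondering what "each partition" means for func, the behavior can be imitated in plain R without Spark. The sketch below is an illustration only: `local_dapply`, `add_d`, and the two-row partitions are hypothetical names and choices, and in real use Spark decides how rows are distributed across partitions.

```r
# Local stand-in for dapply: no Spark involved, partitions are plain data.frames.
# local_dapply is a hypothetical helper written for this sketch only.
local_dapply <- function(partitions, func) {
  # func sees one partition at a time, never the whole data set
  results <- lapply(partitions, func)
  combined <- do.call(rbind, unname(results))
  rownames(combined) <- NULL
  combined
}

df <- data.frame(a = 1:4, b = c(1, 2, 3, 4), stringsAsFactors = FALSE)
# Pretend the data is stored in two partitions of two rows each
partitions <- split(df, c(1, 1, 2, 2))

# Row-wise work: filter and add a column, in the style of the doc example
add_d <- function(x) {
  y <- x[x$a > 1, ]
  cbind(y, d = y$a + 1L)
}
filtered <- local_dapply(partitions, add_d)

# Per-partition aggregate: one output row per partition, not one for the whole data set
sums <- local_dapply(partitions, function(x) {
  data.frame(n = nrow(x), total = sum(x$b))
})
```

Row-wise functions such as `add_d` give the same result however the rows are partitioned, while per-partition aggregates such as `sums` depend on where the partition boundaries fall.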

#' collect(df1)
#'
#' # filter and add a column
#' df <- createDataFrame (
#' sqlContext,
#' list(list(1L, 1, "1"), list(2L, 2, "2"), list(3L, 3, "3")),
#' c("a", "b", "c"))
#' schema <- structType(structField("a", "integer"), structField("b", "double"),

Contributor

By the way, we already have a simpler, string-based way to define a schema in Scala and Python; we may also add that to R.

Contributor Author

OK, I will investigate it. Will submit a new PR for this or reuse https://issues.apache.org/jira/browse/SPARK-11046

Contributor

@sun-rui - Just a note that it'll be great to have the simpler schema specification for 2.0. Let me know if you have a new JIRA or we will use 11046, so we can track it for the release.

Contributor Author

yeah, let me do some investigation

#' structField("c", "string"), structField("d", "integer"))
#' df1 <- dapply(
#' df,
#' function(x) {
#' y <- x[x[1] > 1, ]
#' y <- cbind(y, y[1] + 1L)
#' },
#' schema)
#' collect(df1)
#' # the result
#' # a b c d
#' # 1 2 2 2 3
#' # 2 3 3 3 4
#' }
setMethod("dapply",
signature(x = "SparkDataFrame", func = "function", schema = "structType"),
function(x, func, schema) {
packageNamesArr <- serialize(.sparkREnv[[".packages"]],
connection = NULL)

Contributor

Should we make NULL the default value?

Contributor Author

Could you explain more? I don't understand.


broadcastArr <- lapply(ls(.broadcastNames),
function(name) { get(name, .broadcastNames) })

sdf <- callJStatic(
"org.apache.spark.sql.api.r.SQLUtils",
"dapply",
x@sdf,
serialize(cleanClosure(func), connection = NULL),
packageNamesArr,
broadcastArr,
schema$jobj)

Contributor

If schema is NULL, this might lead to an error?

Contributor Author

No. If schema is NULL, schema$jobj evaluates to NULL.
However, I agree it is confusing, and I have changed it.
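
The point about `schema$jobj` can be checked in plain R. The sketch below uses a classed list as a stand-in for a structType and a hypothetical helper, `schema_jobj`, to make the NULL case explicit; it is not the change that was actually committed, which is not shown in this thread.

```r
# Stand-in for a structType: a classed list carrying a jobj field
# (an assumption for this sketch, not the real SparkR object)
fake_schema <- structure(list(jobj = "java-object-handle"), class = "structType")

# `$` on NULL quietly returns NULL, which is why schema$jobj does not fail
null_result <- NULL$jobj

# Hypothetical helper that handles the NULL case on purpose rather than by accident
schema_jobj <- function(schema) {
  if (is.null(schema)) {
    return(NULL)
  }
  stopifnot(inherits(schema, "structType"))
  schema$jobj
}

with_schema <- schema_jobj(fake_schema)
without_schema <- schema_jobj(NULL)
```

Spelling out the NULL branch keeps the behavior the same while removing the reliance on `$` silently accepting NULL.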

dataFrame(sdf)
})

############################## RDD Map Functions ##################################
# All of the following functions mirror the existing RDD map functions, #
4 changes: 4 additions & 0 deletions R/pkg/R/generics.R
@@ -446,6 +446,10 @@ setGeneric("covar_samp", function(col1, col2) {standardGeneric("covar_samp") })
#' @export
setGeneric("covar_pop", function(col1, col2) {standardGeneric("covar_pop") })

#' @rdname dapply
#' @export
setGeneric("dapply", function(x, func, schema) { standardGeneric("dapply") })

#' @rdname summary
#' @export
setGeneric("describe", function(x, col, ...) { standardGeneric("describe") })
40 changes: 40 additions & 0 deletions R/pkg/inst/tests/testthat/test_sparkSQL.R
@@ -2017,6 +2017,46 @@ test_that("Histogram", {
df <- as.DataFrame(sqlContext, data.frame(x = c(1, 2, 3, 4, 100)))
expect_equal(histogram(df, "x")$counts, c(4, 0, 0, 0, 0, 0, 0, 0, 0, 1))
})

test_that("dapply() on a DataFrame", {
df <- createDataFrame (
sqlContext,
list(list(1L, 1, "1"), list(2L, 2, "2"), list(3L, 3, "3")),
c("a", "b", "c"))
ldf <- collect(df)
df1 <- dapply(df, function(x) { x }, schema(df))
result <- collect(df1)
expect_identical(ldf, result)


# Filter and add a column
schema <- structType(structField("a", "integer"), structField("b", "double"),
structField("c", "string"), structField("d", "integer"))
df1 <- dapply(
df,
function(x) {
y <- x[x$a > 1, ]
y <- cbind(y, y$a + 1L)
},
schema)
result <- collect(df1)
expected <- ldf[ldf$a > 1, ]
expected$d <- expected$a + 1L
rownames(expected) <- NULL
expect_identical(expected, result)

# Remove the added column
df2 <- dapply(
df1,
function(x) {
x[, c("a", "b", "c")]
},
schema(df))
result <- collect(df2)
expected <- expected[, c("a", "b", "c")]
expect_identical(expected, result)
})

Contributor

Can we have more tests for chained dapply (with and without schema)?

@NarineK Apr 20, 2016
Contributor

Also, I think it would be good to add other data types besides double in the schema.

Contributor Author

added


unlink(parquetPath)
unlink(jsonPath)
unlink(jsonPathNa)
36 changes: 35 additions & 1 deletion R/pkg/inst/worker/worker.R
@@ -84,6 +84,13 @@ broadcastElap <- elapsedSecs()
# as number of partitions to create.
numPartitions <- SparkR:::readInt(inputCon)

isDataFrame <- as.logical(SparkR:::readInt(inputCon))

@felixcheung Apr 19, 2016
Member

This might be beyond the scope of this JIRA/PR: when we have protocol changes like this, how do we make sure the peer has matching implementation, and then we are not misinterpreting the byte stream? Should there be some sort of protocol version handshake?

For example, here we are coercing an Int value into true/false - but the Int may not be 0 or 1.

Contributor

This is a good point. I think the assumption is that R worker processes are started using the same binary release as the JVM processes. But yeah having a protocol version number or something like that might be interesting to explore.

Contributor Author

Currently, SparkR is not a standalone package but an integral part of the Spark binary release, so it is assumed that the R worker script of the matching version is always invoked. The protocol between the JVM and the R worker is internal.

Member

Generally, I agree. It is possible, though, that a user could have an initialization or profile file that inadvertently loads a mismatched version of SparkR.

Contributor Author

The "--vanilla" option used when launching the R worker prevents this.
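
The coercion concern raised in this thread is easy to reproduce in plain R: `as.logical()` maps every non-zero integer to TRUE, so a misread byte stream would not be noticed at this point. The sketch below shows that behavior and one possible stricter check. `parse_flag` is a hypothetical helper, not part of SparkR, and it operates on a value that has already been read from the connection.

```r
# as.logical() treats every non-zero integer as TRUE
coerced <- as.logical(c(0L, 1L, 2L, -1L))

# Hypothetical strict parser for a flag that has already been read as an integer
parse_flag <- function(value) {
  if (length(value) != 1L || is.na(value) || !(value %in% c(0L, 1L))) {
    stop("unexpected flag value in worker protocol: ", value)
  }
  value == 1L
}

is_data_frame <- parse_flag(1L)

# A value outside {0, 1} is rejected instead of being silently read as TRUE
bad <- tryCatch(parse_flag(2L), error = function(e) conditionMessage(e))
```

A check like this only catches out-of-range flags; it is not a substitute for the protocol version handshake suggested above.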


# If isDataFrame, then read column names
if (isDataFrame) {
colNames <- SparkR:::readObject(inputCon)
}

isEmpty <- SparkR:::readInt(inputCon)

if (isEmpty != 0) {
@@ -100,7 +107,34 @@ if (isEmpty != 0) {
# Timing reading input data for execution
inputElap <- elapsedSecs()

output <- computeFunc(partition, data)
if (isDataFrame) {
if (deserializer == "row") {
# Transform the list of rows into a data.frame
# Note that the optional argument stringsAsFactors for rbind is
# available since R 3.2.4. So we set the global option here.
oldOpt <- getOption("stringsAsFactors")
options(stringsAsFactors = FALSE)
data <- do.call(rbind.data.frame, data)
options(stringsAsFactors = oldOpt)

names(data) <- colNames
} else {
# Check to see if data is a valid data.frame
stopifnot(deserializer == "byte")
stopifnot(class(data) == "data.frame")
}

Contributor

If deserializer is not row, can we add a check to see if data is a valid data.frame? (I'm guessing the UDFs assume that is the input type.)

Contributor Author

done
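
The row-based path in the worker converts a list of rows into a data.frame before calling the user function and back into a list of rows afterwards. The sketch below imitates that round trip in plain R. It is a variant, not the worker code: `rows_to_df` and `df_to_rows` are hypothetical names, each row is turned into a one-row data.frame explicitly instead of relying on `rbind.data.frame` with a global option, and `seq_len()` is used for the split index.

```r
# Rows as they might arrive from the JVM: a list of lists, one per row
rows <- list(list(1L, 1, "1"), list(2L, 2, "2"))
col_names <- c("a", "b", "c")

# Hypothetical helper: list of rows -> data.frame with the given column names
rows_to_df <- function(rows, col_names) {
  one_row <- lapply(rows, function(r) {
    as.data.frame(setNames(r, col_names), stringsAsFactors = FALSE)
  })
  out <- do.call(rbind, one_row)
  rownames(out) <- NULL
  out
}

# Hypothetical helper: data.frame -> list of one-row data.frames
df_to_rows <- function(df) {
  # seq_len() yields an empty index for zero rows, whereas seq(0) is c(1, 0)
  unname(split(df, seq_len(nrow(df))))
}

data <- rows_to_df(rows, col_names)
rows_back <- df_to_rows(data)
```

Passing `stringsAsFactors = FALSE` per call avoids changing and restoring the global option, at the cost of building one small data.frame per row.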

output <- computeFunc(data)
if (serializer == "row") {
# Transform the result data.frame back to a list of rows
output <- split(output, seq(nrow(output)))
} else {
# Serialize the output to a byte array
stopifnot(serializer == "byte")
}
} else {
output <- computeFunc(partition, data)
}

# Timing computing
computeElap <- elapsedSecs()

2 changes: 1 addition & 1 deletion core/src/main/scala/org/apache/spark/api/r/RRDD.scala
@@ -46,7 +46,7 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag](
// The parent may be also an RRDD, so we should launch it first.
val parentIterator = firstParent[T].iterator(partition, context)

- runner.compute(parentIterator, partition.index, context)
+ runner.compute(parentIterator, partition.index)
}
}

13 changes: 10 additions & 3 deletions core/src/main/scala/org/apache/spark/api/r/RRunner.scala
@@ -38,7 +38,9 @@ private[spark] class RRunner[U](
serializer: String,
packageNames: Array[Byte],
broadcastVars: Array[Broadcast[Object]],
- numPartitions: Int = -1)
+ numPartitions: Int = -1,
+ isDataFrame: Boolean = false,
+ colNames: Array[String] = null)
extends Logging {
private var bootTime: Double = _
private var dataStream: DataInputStream = _
@@ -53,8 +55,7 @@ private[spark] class RRunner[U](

def compute(
inputIterator: Iterator[_],
- partitionIndex: Int,
- context: TaskContext): Iterator[U] = {
+ partitionIndex: Int): Iterator[U] = {
// Timing start
bootTime = System.currentTimeMillis / 1000.0

@@ -148,6 +149,12 @@ private[spark] class RRunner[U](

dataOut.writeInt(numPartitions)

dataOut.writeInt(if (isDataFrame) 1 else 0)

if (isDataFrame) {
SerDe.writeObject(dataOut, colNames)
}

if (!iter.hasNext) {
dataOut.writeInt(0)
} else {
2 changes: 1 addition & 1 deletion core/src/main/scala/org/apache/spark/api/r/SerDe.scala
@@ -459,7 +459,7 @@ private[spark] object SerDe {

}

- private[r] object SerializationFormats {
+ private[spark] object SerializationFormats {
val BYTE = "byte"
val STRING = "string"
val ROW = "row"
5 changes: 5 additions & 0 deletions docs/sql-programming-guide.md
@@ -1147,6 +1147,11 @@ parquetFile <- read.parquet(sqlContext, "people.parquet")
# Parquet files can also be registered as tables and then used in SQL statements.
registerTempTable(parquetFile, "parquetFile")
teenagers <- sql(sqlContext, "SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19")
schema <- structType(structField("name", "string"))
teenNames <- dapply(df, function(p) { cbind(paste("Name:", p$name)) }, schema)
for (teenName in collect(teenNames)$name) {
cat(teenName, "\n")
}
{% endhighlight %}

</div>
@@ -158,10 +158,15 @@ object EliminateSerialization extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case d @ DeserializeToObject(_, _, s: SerializeFromObject)
if d.outputObjectType == s.inputObjectType =>
// Adds an extra Project here, to preserve the output expr id of `DeserializeToObject`.
val objAttr = Alias(s.child.output.head, "obj")(exprId = d.output.head.exprId)
Project(objAttr :: Nil, s.child)

// A workaround for SPARK-14803. Remove this after it is fixed.
if (d.outputObjectType.isInstanceOf[ObjectType] &&
d.outputObjectType.asInstanceOf[ObjectType].cls == classOf[org.apache.spark.sql.Row]) {
s.child
} else {
// Adds an extra Project here, to preserve the output expr id of `DeserializeToObject`.
val objAttr = Alias(s.child.output.head, "obj")(exprId = d.output.head.exprId)
Project(objAttr :: Nil, s.child)
}
case a @ AppendColumns(_, _, _, s: SerializeFromObject)
if a.deserializer.dataType == s.inputObjectType =>
AppendColumnsWithObject(a.func, s.serializer, a.serializer, s.child)
@@ -17,25 +17,39 @@

package org.apache.spark.sql.catalyst.plans.logical

- import org.apache.spark.sql.Encoder
+ import org.apache.spark.broadcast.Broadcast
+ import org.apache.spark.sql.{Encoder, Row}
import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer
import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.expressions._
- import org.apache.spark.sql.types.{DataType, StructType}
+ import org.apache.spark.sql.types._

object CatalystSerde {
def deserialize[T : Encoder](child: LogicalPlan): DeserializeToObject = {
val deserializer = UnresolvedDeserializer(encoderFor[T].deserializer)
DeserializeToObject(deserializer, generateObjAttr[T], child)
}

def deserialize(child: LogicalPlan, encoder: ExpressionEncoder[Row]): DeserializeToObject = {
val deserializer = UnresolvedDeserializer(encoder.deserializer)
DeserializeToObject(deserializer, generateObjAttrForRow(encoder), child)
}

def serialize[T : Encoder](child: LogicalPlan): SerializeFromObject = {
SerializeFromObject(encoderFor[T].namedExpressions, child)
}

def serialize(child: LogicalPlan, encoder: ExpressionEncoder[Row]): SerializeFromObject = {
SerializeFromObject(encoder.namedExpressions, child)
}

def generateObjAttr[T : Encoder]: Attribute = {
AttributeReference("obj", encoderFor[T].deserializer.dataType, nullable = false)()
}

def generateObjAttrForRow(encoder: ExpressionEncoder[Row]): Attribute = {
AttributeReference("obj", encoder.deserializer.dataType, nullable = false)()
}
}

/**
@@ -106,6 +120,42 @@ case class MapPartitions(
outputObjAttr: Attribute,
child: LogicalPlan) extends UnaryNode with ObjectConsumer with ObjectProducer

object MapPartitionsInR {
def apply(
func: Array[Byte],
packageNames: Array[Byte],
broadcastVars: Array[Broadcast[Object]],
schema: StructType,
encoder: ExpressionEncoder[Row],
child: LogicalPlan): LogicalPlan = {
val deserialized = CatalystSerde.deserialize(child, encoder)
val mapped = MapPartitionsInR(
func,
packageNames,
broadcastVars,
encoder.schema,
schema,
CatalystSerde.generateObjAttrForRow(RowEncoder(schema)),
deserialized)
CatalystSerde.serialize(mapped, RowEncoder(schema))
}
}

/**
* A relation produced by applying a serialized R function `func` to each partition of the `child`.
*
*/
case class MapPartitionsInR(
func: Array[Byte],
packageNames: Array[Byte],
broadcastVars: Array[Broadcast[Object]],
inputSchema: StructType,
outputSchema: StructType,
outputObjAttr: Attribute,
child: LogicalPlan) extends UnaryNode with ObjectConsumer with ObjectProducer {
override lazy val schema = outputSchema
}

object MapElements {
def apply[T : Encoder, U : Encoder](
func: AnyRef,
18 changes: 18 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -31,6 +31,7 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.api.java.function._
import org.apache.spark.api.python.PythonRDD
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst._
import org.apache.spark.sql.catalyst.analysis._
@@ -1980,6 +1981,23 @@ class Dataset[T] private[sql](
mapPartitions(func)(encoder)
}

/**
* Returns a new [[DataFrame]] that contains the result of applying a serialized R function
* `func` to each partition.
*
* @group func
*/

Contributor

Maybe we can add a @since tag in the comment?

@sun-rui Apr 29, 2016
Contributor Author

Spark 2.0 is a good chance to add "since" for SparkR API methods. But I think we can do it consistently for all methods at once. I will submit a new JIRA issue for it.

private[sql] def mapPartitionsInR(
func: Array[Byte],
packageNames: Array[Byte],
broadcastVars: Array[Broadcast[Object]],
schema: StructType): DataFrame = {
val rowEncoder = encoder.asInstanceOf[ExpressionEncoder[Row]]
Dataset.ofRows(
sparkSession,
MapPartitionsInR(func, packageNames, broadcastVars, schema, rowEncoder, logicalPlan))
}

/**
* :: Experimental ::
* (Scala-specific)