Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import org.apache.spark.Logging
import org.apache.spark.ml.classification.DecisionTreeClassificationModel
import org.apache.spark.ml.regression.DecisionTreeRegressionModel
import org.apache.spark.ml.tree._
import org.apache.spark.ml.util.{Stopwatch, LocalStopwatch, MultiStopwatch}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
import org.apache.spark.mllib.tree.impl.{BaggedPoint, DTStatsAggregator, DecisionTreeMetadata,
Expand All @@ -52,11 +53,13 @@ private[ml] object RandomForest extends Logging {
seed: Long,
parentUID: Option[String] = None): Array[DecisionTreeModel] = {

val timer = new TimeTracker()
val multiTimer = new MultiStopwatch(input.sparkContext)

timer.start("total")
multiTimer.addLocal("total")
multiTimer("total").start()

timer.start("init")
multiTimer.addLocal("init")
multiTimer("init").start()

val retaggedInput = input.retag(classOf[LabeledPoint])
val metadata =
Expand All @@ -71,9 +74,10 @@ private[ml] object RandomForest extends Logging {

// Find the splits and the corresponding bins (interval between the splits) using a sample
// of the input data.
timer.start("findSplitsBins")
multiTimer.addLocal("findSplitsBins")
multiTimer("findSplitsBins").start()
val splits = findSplits(retaggedInput, metadata)
timer.stop("findSplitsBins")
multiTimer("findSplitsBins").stop()
logDebug("numBins: feature: number of bins")
logDebug(Range(0, metadata.numFeatures).map { featureIndex =>
s"\t$featureIndex\t${metadata.numBins(featureIndex)}"
Expand Down Expand Up @@ -113,7 +117,7 @@ private[ml] object RandomForest extends Logging {
" which is too small for the given features." +
s" Minimum value = ${maxMemoryPerNode / (1024L * 1024L)}")

timer.stop("init")
multiTimer("init").stop()

/*
* The main idea here is to perform group-wise training of the decision tree nodes thus
Expand Down Expand Up @@ -144,6 +148,8 @@ private[ml] object RandomForest extends Logging {
val topNodes = Array.fill[LearningNode](numTrees)(LearningNode.emptyNode(nodeIndex = 1))
Range(0, numTrees).foreach(treeIndex => nodeQueue.enqueue((treeIndex, topNodes(treeIndex))))

multiTimer.addLocal("findBestSplits")
multiTimer.addLocal("chooseSplits")
while (nodeQueue.nonEmpty) {
// Collect some nodes to split, and choose features for each node (if subsampling).
// Each group of nodes may come from one or multiple trees, and at multiple levels.
Expand All @@ -154,18 +160,18 @@ private[ml] object RandomForest extends Logging {
s"RandomForest selected empty nodesForGroup. Error for unknown reason.")

// Choose node splits, and enqueue new nodes as needed.
timer.start("findBestSplits")
multiTimer("findBestSplits").start()
RandomForest.findBestSplits(baggedInput, metadata, topNodes, nodesForGroup,
treeToNodeToIndexInfo, splits, nodeQueue, timer, nodeIdCache)
timer.stop("findBestSplits")
treeToNodeToIndexInfo, splits, nodeQueue, multiTimer("chooseSplits"), nodeIdCache)
multiTimer("findBestSplits").stop()
}

baggedInput.unpersist()

timer.stop("total")
multiTimer("total").stop()

logInfo("Internal timing for DecisionTree:")
logInfo(s"$timer")
logInfo(s"$multiTimer")

// Delete any remaining checkpoints used for node Id cache.
if (nodeIdCache.nonEmpty) {
Expand Down Expand Up @@ -365,7 +371,7 @@ private[ml] object RandomForest extends Logging {
treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]],
splits: Array[Array[Split]],
nodeQueue: mutable.Queue[(Int, LearningNode)],
timer: TimeTracker = new TimeTracker,
timer: Stopwatch = new LocalStopwatch("chooseSplits"),
nodeIdCache: Option[NodeIdCache] = None): Unit = {

/*
Expand Down Expand Up @@ -497,7 +503,7 @@ private[ml] object RandomForest extends Logging {
}

// Calculate best splits for all nodes in the group
timer.start("chooseSplits")
timer.start()

// In each partition, iterate all instances and compute aggregate stats for each node,
// yield an (nodeIndex, nodeAggregateStats) pair for each node.
Expand Down Expand Up @@ -558,7 +564,7 @@ private[ml] object RandomForest extends Logging {
(nodeIndex, (split, stats))
}.collectAsMap()

timer.stop("chooseSplits")
timer.stop()

val nodeIdUpdaters = if (nodeIdCache.nonEmpty) {
Array.fill[mutable.Map[Int, NodeIndexUpdater]](
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.collection.mutable.ArrayBuilder

import org.apache.spark.ml.util.{Stopwatch, LocalStopwatch}
import org.apache.spark.Logging
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
Expand Down Expand Up @@ -447,7 +448,7 @@ object DecisionTree extends Serializable with Logging {
splits: Array[Array[Split]],
bins: Array[Array[Bin]],
nodeQueue: mutable.Queue[(Int, Node)],
timer: TimeTracker = new TimeTracker,
timer: Stopwatch = new LocalStopwatch("chooseSplits"),
nodeIdCache: Option[NodeIdCache] = None): Unit = {

/*
Expand Down Expand Up @@ -580,7 +581,7 @@ object DecisionTree extends Serializable with Logging {
}

// Calculate best splits for all nodes in the group
timer.start("chooseSplits")
timer.start()

// In each partition, iterate all instances and compute aggregate stats for each node,
// yield an (nodeIndex, nodeAggregateStats) pair for each node.
Expand Down Expand Up @@ -641,7 +642,7 @@ object DecisionTree extends Serializable with Logging {
(nodeIndex, (split, stats, predict))
}.collectAsMap()

timer.stop("chooseSplits")
timer.stop()

val nodeIdUpdaters = if (nodeIdCache.nonEmpty) {
Array.fill[mutable.Map[Int, NodeIndexUpdater]](
Expand Down
29 changes: 18 additions & 11 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import scala.collection.JavaConverters._
import org.apache.spark.Logging
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.ml.util.MultiStopwatch
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.configuration.Algo._
Expand Down Expand Up @@ -128,11 +129,13 @@ private class RandomForest (
*/
def run(input: RDD[LabeledPoint]): RandomForestModel = {

val timer = new TimeTracker()
val multiTimer = new MultiStopwatch(input.sparkContext)

timer.start("total")
multiTimer.addLocal("total")
multiTimer("total").start()

timer.start("init")
multiTimer.addLocal("init")
multiTimer("init").start()

val retaggedInput = input.retag(classOf[LabeledPoint])
val metadata =
Expand All @@ -147,9 +150,10 @@ private class RandomForest (

// Find the splits and the corresponding bins (interval between the splits) using a sample
// of the input data.
timer.start("findSplitsBins")
multiTimer.addLocal("findSplitsBins")
multiTimer("findSplitsBins").start()
val (splits, bins) = DecisionTree.findSplitsBins(retaggedInput, metadata)
timer.stop("findSplitsBins")
multiTimer("findSplitsBins").stop()
logDebug("numBins: feature: number of bins")
logDebug(Range(0, metadata.numFeatures).map { featureIndex =>
s"\t$featureIndex\t${metadata.numBins(featureIndex)}"
Expand Down Expand Up @@ -190,7 +194,7 @@ private class RandomForest (
" which is too small for the given features." +
s" Minimum value = ${maxMemoryPerNode / (1024L * 1024L)}")

timer.stop("init")
multiTimer("init").stop()

/*
* The main idea here is to perform group-wise training of the decision tree nodes thus
Expand Down Expand Up @@ -221,6 +225,8 @@ private class RandomForest (
val topNodes: Array[Node] = Array.fill[Node](numTrees)(Node.emptyNode(nodeIndex = 1))
Range(0, numTrees).foreach(treeIndex => nodeQueue.enqueue((treeIndex, topNodes(treeIndex))))

multiTimer.addLocal("findBestSplits")
multiTimer.addLocal("chooseSplits")
while (nodeQueue.nonEmpty) {
// Collect some nodes to split, and choose features for each node (if subsampling).
// Each group of nodes may come from one or multiple trees, and at multiple levels.
Expand All @@ -231,18 +237,19 @@ private class RandomForest (
s"RandomForest selected empty nodesForGroup. Error for unknown reason.")

// Choose node splits, and enqueue new nodes as needed.
timer.start("findBestSplits")
multiTimer("findBestSplits").start()
DecisionTree.findBestSplits(baggedInput, metadata, topNodes, nodesForGroup,
treeToNodeToIndexInfo, splits, bins, nodeQueue, timer, nodeIdCache = nodeIdCache)
timer.stop("findBestSplits")
treeToNodeToIndexInfo, splits, bins, nodeQueue, multiTimer("chooseSplits"),
nodeIdCache = nodeIdCache)
multiTimer("findBestSplits").stop()
}

baggedInput.unpersist()

timer.stop("total")
multiTimer("total").stop()

logInfo("Internal timing for DecisionTree:")
logInfo(s"$timer")
logInfo(s"$multiTimer")

// Delete any remaining checkpoints used for node Id cache.
if (nodeIdCache.nonEmpty) {
Expand Down