diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index a8b90d9d266a1..f85c57bb35c1d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -26,6 +26,7 @@ import org.apache.spark.Logging import org.apache.spark.ml.classification.DecisionTreeClassificationModel import org.apache.spark.ml.regression.DecisionTreeRegressionModel import org.apache.spark.ml.tree._ +import org.apache.spark.ml.util.{Stopwatch, LocalStopwatch, MultiStopwatch} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} import org.apache.spark.mllib.tree.impl.{BaggedPoint, DTStatsAggregator, DecisionTreeMetadata, @@ -52,11 +53,13 @@ private[ml] object RandomForest extends Logging { seed: Long, parentUID: Option[String] = None): Array[DecisionTreeModel] = { - val timer = new TimeTracker() + val multiTimer = new MultiStopwatch(input.sparkContext) - timer.start("total") + multiTimer.addLocal("total") + multiTimer("total").start() - timer.start("init") + multiTimer.addLocal("init") + multiTimer("init").start() val retaggedInput = input.retag(classOf[LabeledPoint]) val metadata = @@ -71,9 +74,10 @@ private[ml] object RandomForest extends Logging { // Find the splits and the corresponding bins (interval between the splits) using a sample // of the input data. - timer.start("findSplitsBins") + multiTimer.addLocal("findSplitsBins") + multiTimer("findSplitsBins").start() val splits = findSplits(retaggedInput, metadata) - timer.stop("findSplitsBins") + multiTimer("findSplitsBins").stop() logDebug("numBins: feature: number of bins") logDebug(Range(0, metadata.numFeatures).map { featureIndex => s"\t$featureIndex\t${metadata.numBins(featureIndex)}" @@ -113,7 +117,7 @@ private[ml] object RandomForest extends Logging { " which is too small for the given features." + s" Minimum value = ${maxMemoryPerNode / (1024L * 1024L)}") - timer.stop("init") + multiTimer("init").stop() /* * The main idea here is to perform group-wise training of the decision tree nodes thus @@ -144,6 +148,8 @@ private[ml] object RandomForest extends Logging { val topNodes = Array.fill[LearningNode](numTrees)(LearningNode.emptyNode(nodeIndex = 1)) Range(0, numTrees).foreach(treeIndex => nodeQueue.enqueue((treeIndex, topNodes(treeIndex)))) + multiTimer.addLocal("findBestSplits") + multiTimer.addLocal("chooseSplits") while (nodeQueue.nonEmpty) { // Collect some nodes to split, and choose features for each node (if subsampling). // Each group of nodes may come from one or multiple trees, and at multiple levels. @@ -154,18 +160,18 @@ private[ml] object RandomForest extends Logging { s"RandomForest selected empty nodesForGroup. Error for unknown reason.") // Choose node splits, and enqueue new nodes as needed. - timer.start("findBestSplits") + multiTimer("findBestSplits").start() RandomForest.findBestSplits(baggedInput, metadata, topNodes, nodesForGroup, - treeToNodeToIndexInfo, splits, nodeQueue, timer, nodeIdCache) - timer.stop("findBestSplits") + treeToNodeToIndexInfo, splits, nodeQueue, multiTimer("chooseSplits"), nodeIdCache) + multiTimer("findBestSplits").stop() } baggedInput.unpersist() - timer.stop("total") + multiTimer("total").stop() logInfo("Internal timing for DecisionTree:") - logInfo(s"$timer") + logInfo(s"$multiTimer") // Delete any remaining checkpoints used for node Id cache. if (nodeIdCache.nonEmpty) { @@ -365,7 +371,7 @@ private[ml] object RandomForest extends Logging { treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]], splits: Array[Array[Split]], nodeQueue: mutable.Queue[(Int, LearningNode)], - timer: TimeTracker = new TimeTracker, + timer: Stopwatch = new LocalStopwatch("chooseSplits"), nodeIdCache: Option[NodeIdCache] = None): Unit = { /* @@ -497,7 +503,7 @@ private[ml] object RandomForest extends Logging { } // Calculate best splits for all nodes in the group - timer.start("chooseSplits") + timer.start() // In each partition, iterate all instances and compute aggregate stats for each node, // yield an (nodeIndex, nodeAggregateStats) pair for each node. @@ -558,7 +564,7 @@ private[ml] object RandomForest extends Logging { (nodeIndex, (split, stats)) }.collectAsMap() - timer.stop("chooseSplits") + timer.stop() val nodeIdUpdaters = if (nodeIdCache.nonEmpty) { Array.fill[mutable.Map[Int, NodeIndexUpdater]]( diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index cecd1fed896d5..2ea748649afd4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -21,6 +21,7 @@ import scala.collection.JavaConverters._ import scala.collection.mutable import scala.collection.mutable.ArrayBuilder +import org.apache.spark.ml.util.{Stopwatch, LocalStopwatch} import org.apache.spark.Logging import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD @@ -447,7 +448,7 @@ object DecisionTree extends Serializable with Logging { splits: Array[Array[Split]], bins: Array[Array[Bin]], nodeQueue: mutable.Queue[(Int, Node)], - timer: TimeTracker = new TimeTracker, + timer: Stopwatch = new LocalStopwatch("chooseSplits"), nodeIdCache: Option[NodeIdCache] = None): Unit = { /* @@ -580,7 +581,7 @@ object DecisionTree extends Serializable with Logging { } // Calculate best splits for all nodes in the group - timer.start("chooseSplits") + timer.start() // In each partition, iterate all instances and compute aggregate stats for each node, // yield an (nodeIndex, nodeAggregateStats) pair for each node. @@ -641,7 +642,7 @@ object DecisionTree extends Serializable with Logging { (nodeIndex, (split, stats, predict)) }.collectAsMap() - timer.stop("chooseSplits") + timer.stop() val nodeIdUpdaters = if (nodeIdCache.nonEmpty) { Array.fill[mutable.Map[Int, NodeIndexUpdater]]( diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala index 069959976a188..447f8abc34a64 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala @@ -25,6 +25,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.Logging import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD +import org.apache.spark.ml.util.MultiStopwatch import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.Strategy import org.apache.spark.mllib.tree.configuration.Algo._ @@ -128,11 +129,13 @@ private class RandomForest ( */ def run(input: RDD[LabeledPoint]): RandomForestModel = { - val timer = new TimeTracker() + val multiTimer = new MultiStopwatch(input.sparkContext) - timer.start("total") + multiTimer.addLocal("total") + multiTimer("total").start() - timer.start("init") + multiTimer.addLocal("init") + multiTimer("init").start() val retaggedInput = input.retag(classOf[LabeledPoint]) val metadata = @@ -147,9 +150,10 @@ private class RandomForest ( // Find the splits and the corresponding bins (interval between the splits) using a sample // of the input data. - timer.start("findSplitsBins") + multiTimer.addLocal("findSplitsBins") + multiTimer("findSplitsBins").start() val (splits, bins) = DecisionTree.findSplitsBins(retaggedInput, metadata) - timer.stop("findSplitsBins") + multiTimer("findSplitsBins").stop() logDebug("numBins: feature: number of bins") logDebug(Range(0, metadata.numFeatures).map { featureIndex => s"\t$featureIndex\t${metadata.numBins(featureIndex)}" @@ -190,7 +194,7 @@ private class RandomForest ( " which is too small for the given features." + s" Minimum value = ${maxMemoryPerNode / (1024L * 1024L)}") - timer.stop("init") + multiTimer("init").stop() /* * The main idea here is to perform group-wise training of the decision tree nodes thus @@ -221,6 +225,8 @@ private class RandomForest ( val topNodes: Array[Node] = Array.fill[Node](numTrees)(Node.emptyNode(nodeIndex = 1)) Range(0, numTrees).foreach(treeIndex => nodeQueue.enqueue((treeIndex, topNodes(treeIndex)))) + multiTimer.addLocal("findBestSplits") + multiTimer.addLocal("chooseSplits") while (nodeQueue.nonEmpty) { // Collect some nodes to split, and choose features for each node (if subsampling). // Each group of nodes may come from one or multiple trees, and at multiple levels. @@ -231,18 +237,19 @@ private class RandomForest ( s"RandomForest selected empty nodesForGroup. Error for unknown reason.") // Choose node splits, and enqueue new nodes as needed. - timer.start("findBestSplits") + multiTimer("findBestSplits").start() DecisionTree.findBestSplits(baggedInput, metadata, topNodes, nodesForGroup, - treeToNodeToIndexInfo, splits, bins, nodeQueue, timer, nodeIdCache = nodeIdCache) - timer.stop("findBestSplits") + treeToNodeToIndexInfo, splits, bins, nodeQueue, multiTimer("chooseSplits"), + nodeIdCache = nodeIdCache) + multiTimer("findBestSplits").stop() } baggedInput.unpersist() - timer.stop("total") + multiTimer("total").stop() logInfo("Internal timing for DecisionTree:") - logInfo(s"$timer") + logInfo(s"$multiTimer") // Delete any remaining checkpoints used for node Id cache. if (nodeIdCache.nonEmpty) {