From a43c42c8558ee60201638b2e3adfe5f2044bb3fb Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Fri, 8 May 2026 17:44:27 +0800 Subject: [PATCH] [SPARK-34591][ML] Add decision tree pruning as a parameter This PR adds a parameter to enable/disable a featuer where LearningNodes are merged after a RF model is trained. This PR takes over https://github.com/apache/spark/pull/32813 2 Reasons: 1. In addition to basic classification, another use case for decision trees are the probabilities associated with predictions. Once pruned, these predictions are lost and it makes the trees/predictions challenging to work with if not unusable. 2. It is not in line with the default behavior in sklearn. In sklearn, the trees are left unpruned by default. Please see Jira ticket for more explanation. **New params:** adds a parameter `pruneTree` that is exposed to the Tree based classifiers. Will add tests here to ensure parameter is exposed correctly. Unit tests. Closes #55728 from WeichenXu123/SPARK-34591. Lead-authored-by: WeichenXu Co-authored-by: bribiescas-carlos Co-authored-by: Carlos Bribiescas Signed-off-by: Weichen Xu (cherry picked from commit 1f4650674a663627cdf38a6100d9fb7cf1527c47) Signed-off-by: Weichen Xu --- .../DecisionTreeClassifier.scala | 8 +- .../RandomForestClassifier.scala | 7 +- .../spark/ml/tree/impl/RandomForest.scala | 661 +++++++++++------- .../org/apache/spark/ml/tree/treeParams.scala | 19 +- .../mllib/tree/configuration/Strategy.scala | 8 +- .../ml/tree/impl/RandomForestSuite.scala | 411 +++++++---- python/pyspark/ml/classification.py | 26 +- python/pyspark/ml/tree.py | 13 + 8 files changed, 753 insertions(+), 400 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index 887d8277d3117..d5564f6a3fbda 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -74,6 +74,10 @@ class DecisionTreeClassifier @Since("1.4.0") ( @Since("1.4.0") def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) + /** @group setParam */ + @Since("4.3.0") + def setPruneTree(value: Boolean): this.type = set(pruneTree, value) + /** @group expertSetParam */ @Since("1.4.0") def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) @@ -134,9 +138,11 @@ class DecisionTreeClassifier @Since("1.4.0") ( val strategy = getOldStrategy(categoricalFeatures, numClasses) require(!strategy.bootstrap, "DecisionTreeClassifier does not need bootstrap sampling") + strategy.pruneTree = $(pruneTree) + instr.logNumClasses(numClasses) instr.logParams(this, labelCol, featuresCol, predictionCol, rawPredictionCol, - probabilityCol, leafCol, maxDepth, maxBins, minInstancesPerNode, minInfoGain, + probabilityCol, leafCol, maxDepth, maxBins, minInstancesPerNode, minInfoGain, pruneTree, maxMemoryInMB, cacheNodeIds, checkpointInterval, impurity, seed, thresholds) val trees = RandomForest.run(instances, strategy, numTrees = 1, featureSubsetStrategy = "all", diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index fb61358536d0c..2c22ca5b42302 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -76,6 +76,10 @@ class RandomForestClassifier @Since("1.4.0") ( @Since("1.4.0") def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) + /** @group setParam */ + @Since("4.3.0") + def setPruneTree(value: Boolean): this.type = set(pruneTree, value) + /** @group expertSetParam */ @Since("1.4.0") def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) @@ -159,10 +163,11 @@ class RandomForestClassifier @Since("1.4.0") ( val strategy = super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity) strategy.bootstrap = $(bootstrap) + strategy.pruneTree = $(pruneTree) instr.logParams(this, labelCol, featuresCol, weightCol, predictionCol, probabilityCol, rawPredictionCol, leafCol, impurity, numTrees, featureSubsetStrategy, maxDepth, maxBins, - maxMemoryInMB, minInfoGain, minInstancesPerNode, minWeightFractionPerNode, seed, + maxMemoryInMB, minInfoGain, pruneTree, minInstancesPerNode, minWeightFractionPerNode, seed, subsamplingRate, thresholds, cacheNodeIds, checkpointInterval, bootstrap) val trees = RandomForest diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index cabbc497571b6..c3a16ab3dddd3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -41,7 +41,6 @@ import org.apache.spark.util.SizeEstimator import org.apache.spark.util.collection.OpenHashMap import org.apache.spark.util.random.{SamplingUtils, XORShiftRandom} - /** * ALGORITHM * @@ -97,8 +96,9 @@ private[spark] object RandomForest extends Logging with Serializable { numTrees: Int, featureSubsetStrategy: String, seed: Long): Array[DecisionTreeModel] = { - val instances = input.map { case LabeledPoint(label, features) => - Instance(label, 1.0, features.asML) + val instances = input.map { + case LabeledPoint(label, features) => + Instance(label, 1.0, features.asML) } run(instances, strategy, numTrees, featureSubsetStrategy, seed, None) } @@ -124,7 +124,6 @@ private[spark] object RandomForest extends Logging with Serializable { featureSubsetStrategy: String, seed: Long, instr: Option[Instrumentation], - prune: Boolean = true, // exposed for testing only, real trees are always pruned parentUID: Option[String] = None, earlyStopModelSizeThresholdInBytes: Long = 0): Array[DecisionTreeModel] = { lastEarlyStoppedModelSize = 0 @@ -151,7 +150,8 @@ private[spark] object RandomForest extends Logging with Serializable { // depth of the decision tree val maxDepth = strategy.maxDepth - require(maxDepth <= 30, + require( + maxDepth <= 30, s"DecisionTree currently only supports maxDepth <= 30, but was given maxDepth = $maxDepth.") // Max memory usage for aggregates @@ -203,9 +203,10 @@ private[spark] object RandomForest extends Logging with Serializable { // Collect some nodes to split, and choose features for each node (if subsampling). // Each group of nodes may come from one or multiple trees, and at multiple levels. val (nodesForGroup, treeToNodeToIndexInfo) = - RandomForest.selectNodesToSplit(nodeStack, maxMemoryUsage, metadata, rng) + RandomForest.selectNodesToSplit(nodeStack, maxMemoryUsage, metadata, rng) // Sanity check (should never occur): - assert(nodesForGroup.nonEmpty, + assert( + nodesForGroup.nonEmpty, s"RandomForest selected empty nodesForGroup. Error for unknown reason.") // Only send trees to worker if they contain nodes being split this iteration. @@ -214,8 +215,16 @@ private[spark] object RandomForest extends Logging with Serializable { // Choose node splits, and enqueue new nodes as needed. timer.start("findBestSplits") - val bestSplit = RandomForest.findBestSplits(baggedInput, metadata, topNodesForGroup, - nodesForGroup, treeToNodeToIndexInfo, bcSplits, nodeStack, timer, nodeIds, + val bestSplit = RandomForest.findBestSplits( + baggedInput, + metadata, + topNodesForGroup, + nodesForGroup, + treeToNodeToIndexInfo, + bcSplits, + nodeStack, + timer, + nodeIds, outputBestSplits = strategy.useNodeIdCache) if (strategy.useNodeIdCache) { nodeIds = updateNodeIds(baggedInput, nodeIds, bcSplits, bestSplit) @@ -225,7 +234,7 @@ private[spark] object RandomForest extends Logging with Serializable { timer.stop("findBestSplits") if (earlyStopModelSizeThresholdInBytes > 0) { - val nodes = topNodes.map(_.toNode(prune)) + val nodes = topNodes.map(_.toNode(strategy.pruneTree)) val estimatedSize = SizeEstimator.estimate(nodes) if (estimatedSize > earlyStopModelSizeThresholdInBytes){ earlyStop = true @@ -258,23 +267,28 @@ private[spark] object RandomForest extends Logging with Serializable { case Some(uid) => if (strategy.algo == OldAlgo.Classification) { topNodes.map { rootNode => - new DecisionTreeClassificationModel(uid, rootNode.toNode(prune), numFeatures, - strategy.getNumClasses()) + new DecisionTreeClassificationModel( + uid, + rootNode.toNode(strategy.pruneTree), + numFeatures, + strategy.getNumClasses) } } else { topNodes.map { rootNode => - new DecisionTreeRegressionModel(uid, rootNode.toNode(prune), numFeatures) + new DecisionTreeRegressionModel(uid, rootNode.toNode(strategy.pruneTree), numFeatures) } } case None => if (strategy.algo == OldAlgo.Classification) { topNodes.map { rootNode => - new DecisionTreeClassificationModel(rootNode.toNode(prune), numFeatures, - strategy.getNumClasses()) + new DecisionTreeClassificationModel( + rootNode.toNode(strategy.pruneTree), + numFeatures, + strategy.getNumClasses) } } else { topNodes.map(rootNode => - new DecisionTreeRegressionModel(rootNode.toNode(prune), numFeatures)) + new DecisionTreeRegressionModel(rootNode.toNode(strategy.pruneTree), numFeatures)) } } } @@ -293,7 +307,6 @@ private[spark] object RandomForest extends Logging with Serializable { featureSubsetStrategy: String, seed: Long, instr: Option[Instrumentation], - prune: Boolean = true, // exposed for testing only, real trees are always pruned parentUID: Option[String] = None): Array[DecisionTreeModel] = { val earlyStopModelSizeThresholdInBytes = TreeConfig.trainingEarlyStopModelSizeThresholdInBytes val timer = new TimeTracker() @@ -311,9 +324,12 @@ private[spark] object RandomForest extends Logging with Serializable { val splits = findSplits(retaggedInput, metadata, seed) timer.stop("findSplits") logDebug("numBins: feature: number of bins") - logDebug(Range(0, metadata.numFeatures).map { featureIndex => - s"\t$featureIndex\t${metadata.numBins(featureIndex)}" - }.mkString("\n")) + logDebug( + Range(0, metadata.numFeatures) + .map { featureIndex => + s"\t$featureIndex\t${metadata.numBins(featureIndex)}" + } + .mkString("\n")) // Bin feature values (TreePoint representation). // Cache input RDD for speedup during multiple passes. @@ -321,14 +337,26 @@ private[spark] object RandomForest extends Logging with Serializable { val bcSplits = input.sparkContext.broadcast(splits) val baggedInput = BaggedPoint - .convertToBaggedRDD(treeInput, strategy.subsamplingRate, numTrees, strategy.bootstrap, - (tp: TreePoint) => tp.weight, seed = seed) + .convertToBaggedRDD( + treeInput, + strategy.subsamplingRate, + numTrees, + strategy.bootstrap, + (tp: TreePoint) => tp.weight, + seed = seed) .persist(StorageLevel.MEMORY_AND_DISK) .setName("bagged tree points") - val trees = runBagged(baggedInput = baggedInput, metadata = metadata, bcSplits = bcSplits, - strategy = strategy, numTrees = numTrees, featureSubsetStrategy = featureSubsetStrategy, - seed = seed, instr = instr, prune = prune, parentUID = parentUID, + val trees = runBagged( + baggedInput = baggedInput, + metadata = metadata, + bcSplits = bcSplits, + strategy = strategy, + numTrees = numTrees, + featureSubsetStrategy = featureSubsetStrategy, + seed = seed, + instr = instr, + parentUID = parentUID, earlyStopModelSizeThresholdInBytes = earlyStopModelSizeThresholdInBytes) baggedInput.unpersist() @@ -346,26 +374,27 @@ private[spark] object RandomForest extends Logging with Serializable { bcSplits: Broadcast[Array[Array[Split]]], bestSplits: Array[Map[Int, Split]]): RDD[Array[Int]] = { require(nodeIds != null && bestSplits != null) - input.zip(nodeIds).map { case (point, ids) => - var treeId = 0 - while (treeId < bestSplits.length) { - val bestSplitsInTree = bestSplits(treeId) - if (bestSplitsInTree != null) { - val nodeId = ids(treeId) - bestSplitsInTree.get(nodeId).foreach { bestSplit => - val featureId = bestSplit.featureIndex - val bin = point.datum.binnedFeatures(featureId) - val newNodeId = if (bestSplit.shouldGoLeft(bin, bcSplits.value(featureId))) { - LearningNode.leftChildIndex(nodeId) - } else { - LearningNode.rightChildIndex(nodeId) + input.zip(nodeIds).map { + case (point, ids) => + var treeId = 0 + while (treeId < bestSplits.length) { + val bestSplitsInTree = bestSplits(treeId) + if (bestSplitsInTree != null) { + val nodeId = ids(treeId) + bestSplitsInTree.get(nodeId).foreach { bestSplit => + val featureId = bestSplit.featureIndex + val bin = point.datum.binnedFeatures(featureId) + val newNodeId = if (bestSplit.shouldGoLeft(bin, bcSplits.value(featureId))) { + LearningNode.leftChildIndex(nodeId) + } else { + LearningNode.rightChildIndex(nodeId) + } + ids(treeId) = newNodeId } - ids(treeId) = newNodeId } + treeId += 1 } - treeId += 1 - } - ids + ids } } @@ -417,7 +446,11 @@ private[spark] object RandomForest extends Logging with Serializable { var splitIndex = 0 while (splitIndex < numSplits) { if (featureSplits(splitIndex).shouldGoLeft(featureValue, featureSplits)) { - agg.featureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label, numSamples, + agg.featureUpdate( + leftNodeFeatureOffset, + splitIndex, + treePoint.label, + numSamples, sampleWeight) } splitIndex += 1 @@ -532,8 +565,9 @@ private[spark] object RandomForest extends Logging with Serializable { logDebug(s"numFeatures = ${metadata.numFeatures}") logDebug(s"numClasses = ${metadata.numClasses}") logDebug(s"isMulticlass = ${metadata.isMulticlass}") - logDebug(s"isMulticlassWithCategoricalFeatures = " + - s"${metadata.isMulticlassWithCategoricalFeatures}") + logDebug( + s"isMulticlassWithCategoricalFeatures = " + + s"${metadata.isMulticlassWithCategoricalFeatures}") logDebug(s"using nodeIdCache = $useNodeIdCache") /* @@ -560,11 +594,21 @@ private[spark] object RandomForest extends Logging with Serializable { val numSamples = baggedPoint.subsampleCounts(treeIndex) val sampleWeight = baggedPoint.sampleWeight if (metadata.unorderedFeatures.isEmpty) { - orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, numSamples, sampleWeight, + orderedBinSeqOp( + agg(aggNodeIndex), + baggedPoint.datum, + numSamples, + sampleWeight, featuresForNode) } else { - mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, splits, - metadata.unorderedFeatures, numSamples, sampleWeight, featuresForNode) + mixedBinSeqOp( + agg(aggNodeIndex), + baggedPoint.datum, + splits, + metadata.unorderedFeatures, + numSamples, + sampleWeight, + featuresForNode) } agg(aggNodeIndex).updateParent(baggedPoint.datum.label, numSamples, sampleWeight) } @@ -585,11 +629,16 @@ private[spark] object RandomForest extends Logging with Serializable { agg: Array[DTStatsAggregator], baggedPoint: BaggedPoint[TreePoint], splits: Array[Array[Split]]): Array[DTStatsAggregator] = { - treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) => - val nodeIndex = - topNodesForGroup(treeIndex).predictImpl(baggedPoint.datum.binnedFeatures, splits) - nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null), - agg, baggedPoint, splits) + treeToNodeToIndexInfo.foreach { + case (treeIndex, nodeIndexToInfo) => + val nodeIndex = + topNodesForGroup(treeIndex).predictImpl(baggedPoint.datum.binnedFeatures, splits) + nodeBinSeqOp( + treeIndex, + nodeIndexToInfo.getOrElse(nodeIndex, null), + agg, + baggedPoint, + splits) } agg } @@ -601,12 +650,17 @@ private[spark] object RandomForest extends Logging with Serializable { agg: Array[DTStatsAggregator], dataPoint: (BaggedPoint[TreePoint], Array[Int]), splits: Array[Array[Split]]): Array[DTStatsAggregator] = { - treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) => - val baggedPoint = dataPoint._1 - val nodeIdCache = dataPoint._2 - val nodeIndex = nodeIdCache(treeIndex) - nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null), - agg, baggedPoint, splits) + treeToNodeToIndexInfo.foreach { + case (treeIndex, nodeIndexToInfo) => + val baggedPoint = dataPoint._1 + val nodeIdCache = dataPoint._2 + val nodeIndex = nodeIdCache(treeIndex) + nodeBinSeqOp( + treeIndex, + nodeIndexToInfo.getOrElse(nodeIndex, null), + agg, + baggedPoint, + splits) } agg } @@ -615,8 +669,8 @@ private[spark] object RandomForest extends Logging with Serializable { * Get node index in group --> features indices map, * which is a short cut to find feature indices for a node given node index in group. */ - def getNodeToFeatures( - treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]]): Option[Map[Int, Array[Int]]] = { + def getNodeToFeatures(treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]]) + : Option[Map[Int, Array[Int]]] = { if (!metadata.subsamplingFeatures) { None } else { @@ -624,7 +678,8 @@ private[spark] object RandomForest extends Logging with Serializable { treeToNodeToIndexInfo.values.foreach { nodeIdToNodeInfo => nodeIdToNodeInfo.values.foreach { nodeIndexInfo => assert(nodeIndexInfo.featureSubset.isDefined) - mutableNodeToFeatures(nodeIndexInfo.nodeIndexInGroup) = nodeIndexInfo.featureSubset.get + mutableNodeToFeatures(nodeIndexInfo.nodeIndexInGroup) = + nodeIndexInfo.featureSubset.get } } Some(mutableNodeToFeatures.toMap) @@ -633,10 +688,11 @@ private[spark] object RandomForest extends Logging with Serializable { // array of nodes to train indexed by node index in group val nodes = new Array[LearningNode](numNodes) - nodesForGroup.foreach { case (treeIndex, nodesForTree) => - nodesForTree.foreach { node => - nodes(treeToNodeToIndexInfo(treeIndex)(node.id).nodeIndexInGroup) = node - } + nodesForGroup.foreach { + case (treeIndex, nodesForTree) => + nodesForTree.foreach { node => + nodes(treeToNodeToIndexInfo(treeIndex)(node.id).nodeIndexInGroup) = node + } } // Calculate best splits for all nodes in the group @@ -690,17 +746,20 @@ private[spark] object RandomForest extends Logging with Serializable { } } - val nodeToBestSplits = partitionAggregates.reduceByKey((a, b) => a.merge(b)).map { - case (nodeIndex, aggStats) => - val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures => - Some(nodeToFeatures(nodeIndex)) - } + val nodeToBestSplits = partitionAggregates + .reduceByKey((a, b) => a.merge(b)) + .map { + case (nodeIndex, aggStats) => + val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures => + Some(nodeToFeatures(nodeIndex)) + } - // find best split for each node - val (split: Split, stats: ImpurityStats) = - binsToBestSplit(aggStats, bcSplits.value, featuresForNode, nodes(nodeIndex)) - (nodeIndex, (split, stats)) - }.collectAsMap() + // find best split for each node + val (split: Split, stats: ImpurityStats) = + binsToBestSplit(aggStats, bcSplits.value, featuresForNode, nodes(nodeIndex)) + (nodeIndex, (split, stats)) + } + .collectAsMap() nodeToFeaturesBc.destroy() timer.stop("chooseSplits") @@ -712,55 +771,64 @@ private[spark] object RandomForest extends Logging with Serializable { } // Iterate over all nodes in this group. - nodesForGroup.foreach { case (treeIndex, nodesForTree) => - nodesForTree.foreach { node => - val nodeIndex = node.id - val nodeInfo = treeToNodeToIndexInfo(treeIndex)(nodeIndex) - val aggNodeIndex = nodeInfo.nodeIndexInGroup - val (split: Split, stats: ImpurityStats) = - nodeToBestSplits(aggNodeIndex) - logDebug(s"best split = $split") - - // Extract info for this node. Create children if not leaf. - val isLeaf = - (stats.gain <= 0) || (LearningNode.indexToLevel(nodeIndex) == metadata.maxDepth) - node.isLeaf = isLeaf - node.stats = stats - logDebug(s"Node = $node") - - if (!isLeaf) { - node.split = Some(split) - val childIsLeaf = (LearningNode.indexToLevel(nodeIndex) + 1) == metadata.maxDepth - val leftChildIsLeaf = childIsLeaf || (math.abs(stats.leftImpurity) < Utils.EPSILON) - val rightChildIsLeaf = childIsLeaf || (math.abs(stats.rightImpurity) < Utils.EPSILON) - node.leftChild = Some(LearningNode(LearningNode.leftChildIndex(nodeIndex), - leftChildIsLeaf, ImpurityStats.getEmptyImpurityStats(stats.leftImpurityCalculator))) - node.rightChild = Some(LearningNode(LearningNode.rightChildIndex(nodeIndex), - rightChildIsLeaf, ImpurityStats.getEmptyImpurityStats(stats.rightImpurityCalculator))) - - if (outputBestSplits) { - val bestSplitsInTree = bestSplits(treeIndex) - if (bestSplitsInTree == null) { - bestSplits(treeIndex) = mutable.Map[Int, Split](nodeIndex -> split) - } else { - bestSplitsInTree.update(nodeIndex, split) + nodesForGroup.foreach { + case (treeIndex, nodesForTree) => + nodesForTree.foreach { node => + val nodeIndex = node.id + val nodeInfo = treeToNodeToIndexInfo(treeIndex)(nodeIndex) + val aggNodeIndex = nodeInfo.nodeIndexInGroup + val (split: Split, stats: ImpurityStats) = + nodeToBestSplits(aggNodeIndex) + logDebug(s"best split = $split") + + // Extract info for this node. Create children if not leaf. + val isLeaf = + (stats.gain <= 0) || (LearningNode.indexToLevel(nodeIndex) == metadata.maxDepth) + node.isLeaf = isLeaf + node.stats = stats + logDebug(s"Node = $node") + + if (!isLeaf) { + node.split = Some(split) + val childIsLeaf = (LearningNode.indexToLevel(nodeIndex) + 1) == metadata.maxDepth + val leftChildIsLeaf = childIsLeaf || (math.abs(stats.leftImpurity) < Utils.EPSILON) + val rightChildIsLeaf = childIsLeaf || (math.abs(stats.rightImpurity) < Utils.EPSILON) + node.leftChild = Some( + LearningNode( + LearningNode.leftChildIndex(nodeIndex), + leftChildIsLeaf, + ImpurityStats.getEmptyImpurityStats(stats.leftImpurityCalculator))) + node.rightChild = Some( + LearningNode( + LearningNode.rightChildIndex(nodeIndex), + rightChildIsLeaf, + ImpurityStats.getEmptyImpurityStats(stats.rightImpurityCalculator))) + + if (outputBestSplits) { + val bestSplitsInTree = bestSplits(treeIndex) + if (bestSplitsInTree == null) { + bestSplits(treeIndex) = mutable.Map[Int, Split](nodeIndex -> split) + } else { + bestSplitsInTree.update(nodeIndex, split) + } } - } - // enqueue left child and right child if they are not leaves - if (!leftChildIsLeaf) { - nodeStack.prepend((treeIndex, node.leftChild.get)) - } - if (!rightChildIsLeaf) { - nodeStack.prepend((treeIndex, node.rightChild.get)) - } + // enqueue left child and right child if they are not leaves + if (!leftChildIsLeaf) { + nodeStack.prepend((treeIndex, node.leftChild.get)) + } + if (!rightChildIsLeaf) { + nodeStack.prepend((treeIndex, node.rightChild.get)) + } - logDebug(s"leftChildIndex = ${node.leftChild.get.id}" + - s", impurity = ${stats.leftImpurity}") - logDebug(s"rightChildIndex = ${node.rightChild.get.id}" + - s", impurity = ${stats.rightImpurity}") + logDebug( + s"leftChildIndex = ${node.leftChild.get.id}" + + s", impurity = ${stats.leftImpurity}") + logDebug( + s"rightChildIndex = ${node.rightChild.get.id}" + + s", impurity = ${stats.rightImpurity}") + } } - } } if (outputBestSplits) { @@ -830,8 +898,12 @@ private[spark] object RandomForest extends Logging with Serializable { return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator) } - new ImpurityStats(gain, impurity, parentImpurityCalculator, - leftImpurityCalculator, rightImpurityCalculator) + new ImpurityStats( + gain, + impurity, + parentImpurityCalculator, + leftImpurityCalculator, + rightImpurityCalculator) } /** @@ -855,130 +927,156 @@ private[spark] object RandomForest extends Logging with Serializable { } val validFeatureSplits = - Iterator.range(0, binAggregates.metadata.numFeaturesPerNode).map { featureIndexIdx => - featuresForNode.map(features => (featureIndexIdx, features(featureIndexIdx))) - .getOrElse((featureIndexIdx, featureIndexIdx)) - }.withFilter { case (_, featureIndex) => - binAggregates.metadata.numSplits(featureIndex) != 0 - } + Iterator + .range(0, binAggregates.metadata.numFeaturesPerNode) + .map { featureIndexIdx => + featuresForNode + .map(features => (featureIndexIdx, features(featureIndexIdx))) + .getOrElse((featureIndexIdx, featureIndexIdx)) + } + .withFilter { + case (_, featureIndex) => + binAggregates.metadata.numSplits(featureIndex) != 0 + } // For each (feature, split), calculate the gain, and select the best (feature, split). val splitsAndImpurityInfo = - validFeatureSplits.map { case (featureIndexIdx, featureIndex) => - val numSplits = binAggregates.metadata.numSplits(featureIndex) - if (binAggregates.metadata.isContinuous(featureIndex)) { - // Cumulative sum (scanLeft) of bin statistics. - // Afterwards, binAggregates for a bin is the sum of aggregates for - // that bin + all preceding bins. - val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx) - var splitIndex = 0 - while (splitIndex < numSplits) { - binAggregates.mergeForFeature(nodeFeatureOffset, splitIndex + 1, splitIndex) - splitIndex += 1 - } - // Find best split. - val (bestFeatureSplitIndex, bestFeatureGainStats) = - Range(0, numSplits).map { splitIdx => - val leftChildStats = - binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx) - val rightChildStats = - binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits) - rightChildStats.subtract(leftChildStats) - gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats, - leftChildStats, rightChildStats, binAggregates.metadata) - (splitIdx, gainAndImpurityStats) - }.maxBy(_._2.gain) - (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) - } else if (binAggregates.metadata.isUnordered(featureIndex)) { - // Unordered categorical feature - val leftChildOffset = binAggregates.getFeatureOffset(featureIndexIdx) - val (bestFeatureSplitIndex, bestFeatureGainStats) = - Range(0, numSplits).map { splitIndex => - val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex) - val rightChildStats = binAggregates.getParentImpurityCalculator() - .subtract(leftChildStats) - gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats, - leftChildStats, rightChildStats, binAggregates.metadata) - (splitIndex, gainAndImpurityStats) - }.maxBy(_._2.gain) - (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) - } else { - // Ordered categorical feature - val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx) - val numCategories = binAggregates.metadata.numBins(featureIndex) - - /* Each bin is one category (feature value). - * The bins are ordered based on centroidForCategories, and this ordering determines which - * splits are considered. (With K categories, we consider K - 1 possible splits.) - * + validFeatureSplits.map { + case (featureIndexIdx, featureIndex) => + val numSplits = binAggregates.metadata.numSplits(featureIndex) + if (binAggregates.metadata.isContinuous(featureIndex)) { + // Cumulative sum (scanLeft) of bin statistics. + // Afterwards, binAggregates for a bin is the sum of aggregates for + // that bin + all preceding bins. + val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx) + var splitIndex = 0 + while (splitIndex < numSplits) { + binAggregates.mergeForFeature(nodeFeatureOffset, splitIndex + 1, splitIndex) + splitIndex += 1 + } + // Find best split. + val (bestFeatureSplitIndex, bestFeatureGainStats) = + Range(0, numSplits) + .map { splitIdx => + val leftChildStats = + binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx) + val rightChildStats = + binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits) + rightChildStats.subtract(leftChildStats) + gainAndImpurityStats = calculateImpurityStats( + gainAndImpurityStats, + leftChildStats, + rightChildStats, + binAggregates.metadata) + (splitIdx, gainAndImpurityStats) + } + .maxBy(_._2.gain) + (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) + } else if (binAggregates.metadata.isUnordered(featureIndex)) { + // Unordered categorical feature + val leftChildOffset = binAggregates.getFeatureOffset(featureIndexIdx) + val (bestFeatureSplitIndex, bestFeatureGainStats) = + Range(0, numSplits) + .map { splitIndex => + val leftChildStats = + binAggregates.getImpurityCalculator(leftChildOffset, splitIndex) + val rightChildStats = binAggregates + .getParentImpurityCalculator() + .subtract(leftChildStats) + gainAndImpurityStats = calculateImpurityStats( + gainAndImpurityStats, + leftChildStats, + rightChildStats, + binAggregates.metadata) + (splitIndex, gainAndImpurityStats) + } + .maxBy(_._2.gain) + (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) + } else { + // Ordered categorical feature + val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx) + val numCategories = binAggregates.metadata.numBins(featureIndex) + + /* Each bin is one category (feature value). + * The bins are ordered based on centroidForCategories, and this ordering determines + * which splits are considered. (With K categories, we + * consider K - 1 possible splits.) + * * centroidForCategories is a list: (category, centroid) - */ - val centroidForCategories = Range(0, numCategories).map { featureValue => - val categoryStats = - binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue) - val centroid = if (categoryStats.count != 0) { - if (binAggregates.metadata.isMulticlass) { - // multiclass classification - // For categorical variables in multiclass classification, - // the bins are ordered by the impurity of their corresponding labels. - categoryStats.calculate() - } else if (binAggregates.metadata.isClassification) { - // binary classification - // For categorical variables in binary classification, - // the bins are ordered by the count of class 1. - categoryStats.stats(1) + */ + val centroidForCategories = Range(0, numCategories).map { featureValue => + val categoryStats = + binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue) + val centroid = if (categoryStats.count != 0) { + if (binAggregates.metadata.isMulticlass) { + // multiclass classification + // For categorical variables in multiclass classification, + // the bins are ordered by the impurity of their corresponding labels. + categoryStats.calculate() + } else if (binAggregates.metadata.isClassification) { + // binary classification + // For categorical variables in binary classification, + // the bins are ordered by the count of class 1. + categoryStats.stats(1) + } else { + // regression + // For categorical variables in regression and binary classification, + // the bins are ordered by the prediction. + categoryStats.predict + } } else { - // regression - // For categorical variables in regression and binary classification, - // the bins are ordered by the prediction. - categoryStats.predict + Double.MaxValue } - } else { - Double.MaxValue + (featureValue, centroid) } - (featureValue, centroid) - } - logDebug(s"Centroids for categorical variable: " + - s"${centroidForCategories.mkString(",")}") - - // bins sorted by centroids - val categoriesSortedByCentroid = centroidForCategories.toList.sortBy(_._2) - - logDebug(s"Sorted centroids for categorical variable = " + - s"${categoriesSortedByCentroid.mkString(",")}") - - // Cumulative sum (scanLeft) of bin statistics. - // Afterwards, binAggregates for a bin is the sum of aggregates for - // that bin + all preceding bins. - var splitIndex = 0 - while (splitIndex < numSplits) { - val currentCategory = categoriesSortedByCentroid(splitIndex)._1 - val nextCategory = categoriesSortedByCentroid(splitIndex + 1)._1 - binAggregates.mergeForFeature(nodeFeatureOffset, nextCategory, currentCategory) - splitIndex += 1 + logDebug( + s"Centroids for categorical variable: " + + s"${centroidForCategories.mkString(",")}") + + // bins sorted by centroids + val categoriesSortedByCentroid = centroidForCategories.toList.sortBy(_._2) + + logDebug( + s"Sorted centroids for categorical variable = " + + s"${categoriesSortedByCentroid.mkString(",")}") + + // Cumulative sum (scanLeft) of bin statistics. + // Afterwards, binAggregates for a bin is the sum of aggregates for + // that bin + all preceding bins. + var splitIndex = 0 + while (splitIndex < numSplits) { + val currentCategory = categoriesSortedByCentroid(splitIndex)._1 + val nextCategory = categoriesSortedByCentroid(splitIndex + 1)._1 + binAggregates.mergeForFeature(nodeFeatureOffset, nextCategory, currentCategory) + splitIndex += 1 + } + // lastCategory = index of bin with total aggregates for this (node, feature) + val lastCategory = categoriesSortedByCentroid.last._1 + // Find best split. + val (bestFeatureSplitIndex, bestFeatureGainStats) = + Range(0, numSplits) + .map { splitIndex => + val featureValue = categoriesSortedByCentroid(splitIndex)._1 + val leftChildStats = + binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue) + val rightChildStats = + binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory) + rightChildStats.subtract(leftChildStats) + gainAndImpurityStats = calculateImpurityStats( + gainAndImpurityStats, + leftChildStats, + rightChildStats, + binAggregates.metadata) + (splitIndex, gainAndImpurityStats) + } + .maxBy(_._2.gain) + val categoriesForSplit = + categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1) + val bestFeatureSplit = + new CategoricalSplit(featureIndex, categoriesForSplit.toArray, numCategories) + (bestFeatureSplit, bestFeatureGainStats) } - // lastCategory = index of bin with total aggregates for this (node, feature) - val lastCategory = categoriesSortedByCentroid.last._1 - // Find best split. - val (bestFeatureSplitIndex, bestFeatureGainStats) = - Range(0, numSplits).map { splitIndex => - val featureValue = categoriesSortedByCentroid(splitIndex)._1 - val leftChildStats = - binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue) - val rightChildStats = - binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory) - rightChildStats.subtract(leftChildStats) - gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats, - leftChildStats, rightChildStats, binAggregates.metadata) - (splitIndex, gainAndImpurityStats) - }.maxBy(_._2.gain) - val categoriesForSplit = - categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1) - val bestFeatureSplit = - new CategoricalSplit(featureIndex, categoriesForSplit.toArray, numCategories) - (bestFeatureSplit, bestFeatureGainStats) - } } val (bestSplit, bestSplitStats) = @@ -989,11 +1087,13 @@ private[spark] object RandomForest extends Logging with Serializable { val dummyFeatureIndex = featuresForNode.map(_.head).getOrElse(0) val parentImpurityCalculator = binAggregates.getParentImpurityCalculator() if (binAggregates.metadata.isContinuous(dummyFeatureIndex)) { - (new ContinuousSplit(dummyFeatureIndex, 0), + ( + new ContinuousSplit(dummyFeatureIndex, 0), ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator)) } else { val numCategories = binAggregates.metadata.featureArity(dummyFeatureIndex) - (new CategoricalSplit(dummyFeatureIndex, Array(), numCategories), + ( + new CategoricalSplit(dummyFeatureIndex, Array(), numCategories), ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator)) } } else { @@ -1066,27 +1166,34 @@ private[spark] object RandomForest extends Logging with Serializable { // being spun up that will definitely do no work. val numPartitions = math.min(continuousFeatures.length, input.partitions.length) - input.flatMap { point => - continuousFeatures.iterator - .map(idx => (idx, (point.features(idx), point.weight))) - .filter(_._2._1 != 0.0) - }.aggregateByKey((new OpenHashMap[Double, Double], 0L), numPartitions)( - seqOp = { case ((map, c), (v, w)) => - map.changeValue(v, w, _ + w) - (map, c + 1L) - }, - combOp = { case ((map1, c1), (map2, c2)) => - map2.foreach { case (v, w) => - map1.changeValue(v, w, _ + w) - } - (map1, c1 + c2) + input + .flatMap { point => + continuousFeatures.iterator + .map(idx => (idx, (point.features(idx), point.weight))) + .filter(_._2._1 != 0.0) } - ).map { case (idx, (map, c)) => - val thresholds = findSplitsForContinuousFeature(map.toMap, c, metadata, idx) - val splits: Array[Split] = thresholds.map(thresh => new ContinuousSplit(idx, thresh)) - logDebug(s"featureIndex = $idx, numSplits = ${splits.length}") - (idx, splits) - }.collectAsMap() + .aggregateByKey((new OpenHashMap[Double, Double], 0L), numPartitions)( + seqOp = { + case ((map, c), (v, w)) => + map.changeValue(v, w, _ + w) + (map, c + 1L) + }, + combOp = { + case ((map1, c1), (map2, c2)) => + map2.foreach { + case (v, w) => + map1.changeValue(v, w, _ + w) + } + (map1, c1 + c2) + }) + .map { + case (idx, (map, c)) => + val thresholds = findSplitsForContinuousFeature(map.toMap, c, metadata, idx) + val splits: Array[Split] = thresholds.map(thresh => new ContinuousSplit(idx, thresh)) + logDebug(s"featureIndex = $idx, numSplits = ${splits.length}") + (idx, splits) + } + .collectAsMap() } else Map.empty[Int, Array[Split]] val numFeatures = metadata.numFeatures @@ -1157,9 +1264,10 @@ private[spark] object RandomForest extends Logging with Serializable { featureIndex: Int): Array[Double] = { val valueWeights = new OpenHashMap[Double, Double] var count = 0L - featureSamples.foreach { case (weight, value) => - valueWeights.changeValue(value, weight, _ + weight) - count += 1L + featureSamples.foreach { + case (weight, value) => + valueWeights.changeValue(value, weight, _ + weight) + count += 1L } findSplitsForContinuousFeature(valueWeights.toMap, count, metadata, featureIndex) } @@ -1182,7 +1290,8 @@ private[spark] object RandomForest extends Logging with Serializable { count: Long, metadata: DecisionTreeMetadata, featureIndex: Int): Array[Double] = { - require(metadata.isContinuous(featureIndex), + require( + metadata.isContinuous(featureIndex), "findSplitsForContinuousFeature can only be used to find splits for a continuous feature.") val splits = if (partValueWeights.isEmpty) { @@ -1256,7 +1365,8 @@ private[spark] object RandomForest extends Logging with Serializable { private[tree] class NodeIndexInfo( val nodeIndexInGroup: Int, - val featureSubset: Option[Array[Int]]) extends Serializable + val featureSubset: Option[Array[Int]]) + extends Serializable /** * Pull nodes off of the queue, and collect a group of nodes to be split on this iteration. @@ -1294,8 +1404,13 @@ private[spark] object RandomForest extends Logging with Serializable { val (treeIndex, node) = nodeStack.head // Choose subset of features for node (if subsampling). val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) { - Some(SamplingUtils.reservoirSampleAndCount(Range(0, - metadata.numFeatures).iterator, metadata.numFeaturesPerNode, rng.nextLong())._1) + Some( + SamplingUtils + .reservoirSampleAndCount( + Range(0, metadata.numFeatures).iterator, + metadata.numFeaturesPerNode, + rng.nextLong()) + ._1) } else { None } @@ -1303,11 +1418,13 @@ private[spark] object RandomForest extends Logging with Serializable { val nodeMemUsage = RandomForest.aggregateSizeForNode(metadata, featureSubset) * 8L if (memUsage + nodeMemUsage <= maxMemoryUsage || memUsage == 0) { nodeStack.remove(0) - mutableNodesForGroup.getOrElseUpdate(treeIndex, new mutable.ArrayBuffer[LearningNode]()) += + mutableNodesForGroup.getOrElseUpdate( + treeIndex, + new mutable.ArrayBuffer[LearningNode]()) += node mutableTreeToNodeToIndexInfo - .getOrElseUpdate(treeIndex, new mutable.HashMap[Int, NodeIndexInfo]())(node.id) - = new NodeIndexInfo(numNodesInGroup, featureSubset) + .getOrElseUpdate(treeIndex, new mutable.HashMap[Int, NodeIndexInfo]())(node.id) = + new NodeIndexInfo(numNodesInGroup, featureSubset) numNodesInGroup += 1 memUsage += nodeMemUsage } else { @@ -1355,8 +1472,7 @@ private[spark] object RandomForest extends Logging with Serializable { * @param metadata decision tree metadata * @return subsample fraction */ - private def samplesFractionForFindSplits( - metadata: DecisionTreeMetadata): Double = { + private def samplesFractionForFindSplits(metadata: DecisionTreeMetadata): Double = { // Calculate the number of samples for approximate quantile calculation. val requiredSamples = math.max(metadata.maxBins * metadata.maxBins, 10000) if (requiredSamples < metadata.numExamples) { @@ -1365,4 +1481,5 @@ private[spark] object RandomForest extends Logging with Serializable { 1.0 } } + } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index 768e14f4b74e4..e5f542366be75 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -211,10 +211,27 @@ private[ml] trait TreeClassifierParams extends Params { (value: String) => TreeClassifierParams.supportedImpurities.contains(value.toLowerCase(Locale.ROOT))) - setDefault(impurity -> "gini") + /** + * If true, the trained tree will undergo a pruning process after training, in which nodes + * with the same class predictions are merged. The resulting tree will be smaller and have + * faster predictions, but class probabilities will be lost. + * If false, no pruning is applied after training, and class probabilities are preserved. + * (default = true) + * @group param + */ + final val pruneTree: BooleanParam = new BooleanParam(this, "pruneTree", "" + + "If true, the trained tree will undergo a pruning process after training, in which nodes" + + " with the same class predictions are merged. The resulting tree will be smaller and have" + + " faster predictions, but class probabilities will be lost." + + " If false, no pruning is applied after training, and class probabilities are preserved." + ) + + setDefault(impurity -> "gini", pruneTree -> true) /** @group getParam */ final def getImpurity: String = $(impurity).toLowerCase(Locale.ROOT) + /** @group getParam */ + final def getPruneTree: Boolean = $(pruneTree) /** Convert new impurity to old impurity. */ private[ml] def getOldImpurity: OldImpurity = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index 200d10130eed7..85f4bcc642677 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -55,6 +55,8 @@ import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance} * @param minInfoGain Minimum information gain a split must get. Default value is 0.0. * If a split has less information gain than minInfoGain, * this split will not be considered as a valid split. + * @param pruneTree If this is true, the final training tree will undergo a pruning in which + * nodes with the same classifications are merged. * @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. Default value is * 256 MB. If too small, then 1 node will be split per iteration, and * its aggregates may exceed this size. @@ -77,6 +79,7 @@ class Strategy @Since("1.3.0") ( @Since("1.0.0") @BeanProperty var categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](), @Since("1.2.0") @BeanProperty var minInstancesPerNode: Int = 1, @Since("1.2.0") @BeanProperty var minInfoGain: Double = 0.0, + @Since("4.3.0") @BeanProperty var pruneTree: Boolean = true, @Since("1.0.0") @BeanProperty var maxMemoryInMB: Int = 256, @Since("1.2.0") @BeanProperty var subsamplingRate: Double = 1, @Since("1.2.0") @BeanProperty var useNodeIdCache: Boolean = false, @@ -113,12 +116,13 @@ class Strategy @Since("1.3.0") ( categoricalFeaturesInfo: Map[Int, Int], minInstancesPerNode: Int, minInfoGain: Double, + pruneTree: Boolean, maxMemoryInMB: Int, subsamplingRate: Double, useNodeIdCache: Boolean, checkpointInterval: Int) = { this(algo, impurity, maxDepth, numClasses, maxBins, quantileCalculationStrategy, - categoricalFeaturesInfo, minInstancesPerNode, minInfoGain, maxMemoryInMB, + categoricalFeaturesInfo, minInstancesPerNode, minInfoGain, pruneTree, maxMemoryInMB, subsamplingRate, useNodeIdCache, checkpointInterval, 0.0) } // scalastyle:on argcount @@ -200,7 +204,7 @@ class Strategy @Since("1.3.0") ( def copy: Strategy = { new Strategy(algo, impurity, maxDepth, numClasses, maxBins, quantileCalculationStrategy, categoricalFeaturesInfo, minInstancesPerNode, - minInfoGain, maxMemoryInMB, subsamplingRate, useNodeIdCache, + minInfoGain, pruneTree, maxMemoryInMB, subsamplingRate, useNodeIdCache, checkpointInterval, minWeightFractionPerNode) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala index 62f25474e9476..0c60441813159 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala @@ -72,8 +72,9 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { assert(splits(0).length === 0) } - test("Binary classification with 3-ary (ordered) categorical features," + - " with no samples for one category: split calculation") { + test( + "Binary classification with 3-ary (ordered) categorical features," + + " with no samples for one category: split calculation") { val arr = OldDTSuite.generateCategoricalDataPoints().map(_.asML.toInstance) assert(arr.length === 1000) val rdd = sc.parallelize(arr.toImmutableArraySeq) @@ -108,16 +109,29 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { // SPARK-16957: Use midpoints for split values. { - val fakeMetadata = new DecisionTreeMetadata(1, 8, 8.0, 0, 0, - Map(), Set(), - Array(3), Gini, QuantileStrategy.Sort, - 0, 0, 0.0, 0.0, 0, 0 - ) + val fakeMetadata = new DecisionTreeMetadata( + 1, + 8, + 8.0, + 0, + 0, + Map(), + Set(), + Array(3), + Gini, + QuantileStrategy.Sort, + 0, + 0, + 0.0, + 0.0, + 0, + 0) // possibleSplits <= numSplits { val featureSamples = Array(0, 1, 0, 0, 1, 0, 1, 1) - .map(x => (1.0, x.toDouble)).filter(_._2 != 0.0) + .map(x => (1.0, x.toDouble)) + .filter(_._2 != 0.0) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) val expectedSplits = Array((0.0 + 1.0) / 2) assert(splits === expectedSplits) @@ -126,7 +140,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { // possibleSplits > numSplits { val featureSamples = Array(0, 0, 1, 1, 2, 2, 3, 3) - .map(x => (1.0, x.toDouble)).filter(_._2 != 0.0) + .map(x => (1.0, x.toDouble)) + .filter(_._2 != 0.0) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) val expectedSplits = Array((0.0 + 1.0) / 2, (2.0 + 3.0) / 2) assert(splits === expectedSplits) @@ -136,11 +151,23 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { // find splits should not return identical splits // when there are not enough split candidates, reduce the number of splits in metadata { - val fakeMetadata = new DecisionTreeMetadata(1, 12, 12.0, 0, 0, - Map(), Set(), - Array(5), Gini, QuantileStrategy.Sort, - 0, 0, 0.0, 0.0, 0, 0 - ) + val fakeMetadata = new DecisionTreeMetadata( + 1, + 12, + 12.0, + 0, + 0, + Map(), + Set(), + Array(5), + Gini, + QuantileStrategy.Sort, + 0, + 0, + 0.0, + 0.0, + 0, + 0) val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3).map(x => (1.0, x.toDouble)) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) val expectedSplits = Array((1.0 + 2.0) / 2, (2.0 + 3.0) / 2) @@ -151,11 +178,23 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { // find splits when most samples close to the minimum { - val fakeMetadata = new DecisionTreeMetadata(1, 18, 18.0, 0, 0, - Map(), Set(), - Array(3), Gini, QuantileStrategy.Sort, - 0, 0, 0.0, 0.0, 0, 0 - ) + val fakeMetadata = new DecisionTreeMetadata( + 1, + 18, + 18.0, + 0, + 0, + Map(), + Set(), + Array(3), + Gini, + QuantileStrategy.Sort, + 0, + 0, + 0.0, + 0.0, + 0, + 0) val featureSamples = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5).map(x => (1.0, x.toDouble)) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) @@ -165,11 +204,23 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { // find splits when most samples close to the maximum { - val fakeMetadata = new DecisionTreeMetadata(1, 17, 17.0, 0, 0, - Map(), Set(), - Array(2), Gini, QuantileStrategy.Sort, - 0, 0, 0.0, 0.0, 0, 0 - ) + val fakeMetadata = new DecisionTreeMetadata( + 1, + 17, + 17.0, + 0, + 0, + Map(), + Set(), + Array(2), + Gini, + QuantileStrategy.Sort, + 0, + 0, + 0.0, + 0.0, + 0, + 0) val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(x => (1.0, x.toDouble)) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) @@ -199,11 +250,23 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { // find splits when most weight is close to the minimum { - val fakeMetadata = new DecisionTreeMetadata(1, 0, 0.0, 0, 0, - Map(), Set(), - Array(3), Gini, QuantileStrategy.Sort, - 0, 0, 0.0, 0.0, 0, 0 - ) + val fakeMetadata = new DecisionTreeMetadata( + 1, + 0, + 0.0, + 0, + 0, + Map(), + Set(), + Array(3), + Gini, + QuantileStrategy.Sort, + 0, + 0, + 0.0, + 0.0, + 0, + 0) val featureSamples = Array((10, 1), (1, 2), (1, 3), (1, 4), (1, 5), (1, 6)).map { case (w, x) => (w.toDouble, x.toDouble) } @@ -217,10 +280,10 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { val data = Array.fill(5)(lp) val rdd = sc.parallelize(data.toImmutableArraySeq) - val strategy = new OldStrategy(OldAlgo.Regression, Gini, maxDepth = 2, - maxBins = 5) - withClue("DecisionTree requires number of features > 0," + - " but was given an empty features vector") { + val strategy = new OldStrategy(OldAlgo.Regression, Gini, maxDepth = 2, maxBins = 5) + withClue( + "DecisionTree requires number of features > 0," + + " but was given an empty features vector") { intercept[IllegalArgumentException] { RandomForest.run(rdd, strategy, 1, "all", 42L, instr = None) } @@ -232,23 +295,19 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { val data = Array.fill(5)(instance) val rdd = sc.parallelize(data.toImmutableArraySeq) val strategy = new OldStrategy( - OldAlgo.Classification, - Gini, - maxDepth = 2, - numClasses = 2, - maxBins = 5, - categoricalFeaturesInfo = Map(0 -> 1, 1 -> 5)) + OldAlgo.Classification, + Gini, + maxDepth = 2, + numClasses = 2, + maxBins = 5, + categoricalFeaturesInfo = Map(0 -> 1, 1 -> 5)) val Array(tree) = RandomForest.run(rdd, strategy, 1, "all", 42L, instr = None) assert(tree.rootNode.impurity === -1.0) assert(tree.depth === 0) assert(tree.rootNode.prediction === instance.label) // Test with no categorical features - val strategy2 = new OldStrategy( - OldAlgo.Regression, - Variance, - maxDepth = 2, - maxBins = 5) + val strategy2 = new OldStrategy(OldAlgo.Regression, Variance, maxDepth = 2, maxBins = 5) val Array(tree2) = RandomForest.run(rdd, strategy2, 1, "all", 42L, instr = None) assert(tree2.rootNode.impurity === -1.0) assert(tree2.depth === 0) @@ -279,12 +338,15 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { assert(metadata.numBins(1) === 3) // Expecting 2^2 - 1 = 3 splits per feature - def checkCategoricalSplit(s: Split, featureIndex: Int, leftCategories: Array[Double]): Unit = { + def checkCategoricalSplit( + s: Split, + featureIndex: Int, + leftCategories: Array[Double]): Unit = { assert(s.featureIndex === featureIndex) assert(s.isInstanceOf[CategoricalSplit]) val s0 = s.asInstanceOf[CategoricalSplit] assert(s0.leftCategories === leftCategories) - assert(s0.numCategories === 3) // for this unit test + assert(s0.numCategories === 3) // for this unit test } // Feature 0 checkCategoricalSplit(splits(0)(0), 0, Array(0.0)) @@ -297,7 +359,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { } test("Multiclass classification with ordered categorical features: split calculations") { - val arr = OldDTSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures() + val arr = OldDTSuite + .generateCategoricalDataPointsForMulticlassForOrderedFeatures() .map(_.asML.toInstance) assert(arr.length === 3000) val rdd = sc.parallelize(arr.toImmutableArraySeq) @@ -332,8 +395,12 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0))) val input = sc.parallelize(arr.map(_.toInstance).toImmutableArraySeq) - val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 1, - numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3)) + val strategy = new OldStrategy( + algo = OldAlgo.Classification, + impurity = Gini, + maxDepth = 1, + numClasses = 2, + categoricalFeaturesInfo = Map(0 -> 3)) val metadata = DecisionTreeMetadata.buildMetadata(input, strategy) val splits = RandomForest.findSplits(input, metadata, seed = 42) val bcSplits = input.sparkContext.broadcast(splits) @@ -346,12 +413,17 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { assert(topNode.stats === null) val nodesForGroup = Map(0 -> Array(topNode)) - val treeToNodeToIndexInfo = Map(0 -> Map( - topNode.id -> new RandomForest.NodeIndexInfo(0, None) - )) + val treeToNodeToIndexInfo = + Map(0 -> Map(topNode.id -> new RandomForest.NodeIndexInfo(0, None))) val nodeStack = new mutable.ListBuffer[(Int, LearningNode)] - RandomForest.findBestSplits(baggedInput, metadata, Map(0 -> topNode), - nodesForGroup, treeToNodeToIndexInfo, bcSplits, nodeStack) + RandomForest.findBestSplits( + baggedInput, + metadata, + Map(0 -> topNode), + nodesForGroup, + treeToNodeToIndexInfo, + bcSplits, + nodeStack) bcSplits.destroy() // don't enqueue leaf nodes into node queue @@ -376,8 +448,12 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0))) val input = sc.parallelize(arr.map(_.toInstance).toImmutableArraySeq) - val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 5, - numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3)) + val strategy = new OldStrategy( + algo = OldAlgo.Classification, + impurity = Gini, + maxDepth = 5, + numClasses = 2, + categoricalFeaturesInfo = Map(0 -> 3)) val metadata = DecisionTreeMetadata.buildMetadata(input, strategy) val splits = RandomForest.findSplits(input, metadata, seed = 42) val bcSplits = input.sparkContext.broadcast(splits) @@ -390,12 +466,17 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { assert(topNode.stats === null) val nodesForGroup = Map(0 -> Array(topNode)) - val treeToNodeToIndexInfo = Map(0 -> Map( - topNode.id -> new RandomForest.NodeIndexInfo(0, None) - )) + val treeToNodeToIndexInfo = + Map(0 -> Map(topNode.id -> new RandomForest.NodeIndexInfo(0, None))) val nodeStack = new mutable.ListBuffer[(Int, LearningNode)] - RandomForest.findBestSplits(baggedInput, metadata, Map(0 -> topNode), - nodesForGroup, treeToNodeToIndexInfo, bcSplits, nodeStack) + RandomForest.findBestSplits( + baggedInput, + metadata, + Map(0 -> topNode), + nodesForGroup, + treeToNodeToIndexInfo, + bcSplits, + nodeStack) bcSplits.destroy() // don't enqueue a node into node queue if its impurity is 0.0 @@ -431,18 +512,32 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { val input = sc.parallelize(arr.map(_.toInstance).toImmutableArraySeq) // Must set maxBins s.t. the feature will be treated as an ordered categorical feature. - val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 1, - numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3), maxBins = 3) - - val model = RandomForest.run(input, strategy, numTrees = 1, featureSubsetStrategy = "all", - seed = 42, instr = None, prune = false).head + val strategy = new OldStrategy( + algo = OldAlgo.Classification, + impurity = Gini, + maxDepth = 1, + numClasses = 2, + categoricalFeaturesInfo = Map(0 -> 3), + maxBins = 3) + + strategy.pruneTree = false + val model = RandomForest + .run( + input, + strategy, + numTrees = 1, + featureSubsetStrategy = "all", + seed = 42, + instr = None) + .head model.rootNode match { - case n: InternalNode => n.split match { - case s: CategoricalSplit => - assert(s.leftCategories === Array(1.0)) - case _ => fail("model.rootNode.split was not a CategoricalSplit") - } + case n: InternalNode => + n.split match { + case s: CategoricalSplit => + assert(s.leftCategories === Array(1.0)) + case _ => fail("model.rootNode.split was not a CategoricalSplit") + } case _ => fail("model.rootNode was not an InternalNode") } } @@ -458,18 +553,21 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { val strategy2 = new OldStrategy(OldAlgo.Classification, Entropy, 3, 2, 100, maxMemoryInMB = 0) - val tree1 = RandomForest.run(rdd, strategy1, numTrees = 1, featureSubsetStrategy = "all", - seed = 42, instr = None).head - val tree2 = RandomForest.run(rdd, strategy2, numTrees = 1, featureSubsetStrategy = "all", - seed = 42, instr = None).head - - def getChildren(rootNode: Node): Array[InternalNode] = rootNode match { - case n: InternalNode => - assert(n.leftChild.isInstanceOf[InternalNode]) - assert(n.rightChild.isInstanceOf[InternalNode]) - Array(n.leftChild.asInstanceOf[InternalNode], n.rightChild.asInstanceOf[InternalNode]) - case _ => fail("rootNode was not an InternalNode") - } + val tree1 = RandomForest + .run(rdd, strategy1, numTrees = 1, featureSubsetStrategy = "all", seed = 42, instr = None) + .head + val tree2 = RandomForest + .run(rdd, strategy2, numTrees = 1, featureSubsetStrategy = "all", seed = 42, instr = None) + .head + + def getChildren(rootNode: Node): Array[InternalNode] = + rootNode match { + case n: InternalNode => + assert(n.leftChild.isInstanceOf[InternalNode]) + assert(n.rightChild.isInstanceOf[InternalNode]) + Array(n.leftChild.asInstanceOf[InternalNode], n.rightChild.asInstanceOf[InternalNode]) + case _ => fail("rootNode was not an InternalNode") + } // Single group second level tree construction. val children1 = getChildren(tree1.rootNode) @@ -515,8 +613,9 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { nodeStack.prepend((treeIndex, topNodes(treeIndex))) } val rng = new scala.util.Random(seed = seed) - val (nodesForGroup: Map[Int, Array[LearningNode]], - treeToNodeToIndexInfo: Map[Int, Map[Int, RandomForest.NodeIndexInfo]]) = + val ( + nodesForGroup: Map[Int, Array[LearningNode]], + treeToNodeToIndexInfo: Map[Int, Map[Int, RandomForest.NodeIndexInfo]]) = RandomForest.selectNodesToSplit(nodeStack, maxMemoryUsage, metadata, rng) assert(nodesForGroup.size === numTrees, failString) @@ -524,12 +623,15 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { if (numFeaturesPerNode == numFeatures) { // featureSubset values should all be None - assert(treeToNodeToIndexInfo.values.forall(_.values.forall(_.featureSubset.isEmpty)), + assert( + treeToNodeToIndexInfo.values.forall(_.values.forall(_.featureSubset.isEmpty)), failString) } else { // Check number of features. - assert(treeToNodeToIndexInfo.values.forall(_.values.forall( - _.featureSubset.get.length === numFeaturesPerNode)), failString) + assert( + treeToNodeToIndexInfo.values.forall( + _.values.forall(_.featureSubset.get.length === numFeaturesPerNode)), + failString) } } } @@ -537,7 +639,9 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { checkFeatureSubsetStrategy(numTrees = 1, "auto", numFeatures) checkFeatureSubsetStrategy(numTrees = 1, "all", numFeatures) checkFeatureSubsetStrategy(numTrees = 1, "sqrt", math.sqrt(numFeatures).ceil.toInt) - checkFeatureSubsetStrategy(numTrees = 1, "log2", + checkFeatureSubsetStrategy( + numTrees = 1, + "log2", (math.log(numFeatures) / math.log(2)).ceil.toInt) checkFeatureSubsetStrategy(numTrees = 1, "onethird", (numFeatures / 3.0).ceil.toInt) @@ -555,7 +659,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { val invalidStrategies = Array("-.1", "-.10", "-0.10", ".0", "0.0", "1.1", "0") for (invalidStrategy <- invalidStrategies) { - intercept[IllegalArgumentException]{ + intercept[IllegalArgumentException] { val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy, numTrees = 1, invalidStrategy) } @@ -564,7 +668,9 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { checkFeatureSubsetStrategy(numTrees = 2, "all", numFeatures) checkFeatureSubsetStrategy(numTrees = 2, "auto", math.sqrt(numFeatures).ceil.toInt) checkFeatureSubsetStrategy(numTrees = 2, "sqrt", math.sqrt(numFeatures).ceil.toInt) - checkFeatureSubsetStrategy(numTrees = 2, "log2", + checkFeatureSubsetStrategy( + numTrees = 2, + "log2", (math.log(numFeatures) / math.log(2)).ceil.toInt) checkFeatureSubsetStrategy(numTrees = 2, "onethird", (numFeatures / 3.0).ceil.toInt) @@ -578,7 +684,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { checkFeatureSubsetStrategy(numTrees = 2, strategy, expected) } for (invalidStrategy <- invalidStrategies) { - intercept[IllegalArgumentException]{ + intercept[IllegalArgumentException] { val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy, numTrees = 2, invalidStrategy) } @@ -587,15 +693,23 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { test("Binary classification with continuous features: subsampling features") { val categoricalFeaturesInfo = Map.empty[Int, Int] - val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 2, - numClasses = 2, categoricalFeaturesInfo = categoricalFeaturesInfo) + val strategy = new OldStrategy( + algo = OldAlgo.Classification, + impurity = Gini, + maxDepth = 2, + numClasses = 2, + categoricalFeaturesInfo = categoricalFeaturesInfo) binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy) } test("Binary classification with continuous features and node Id cache: subsampling features") { val categoricalFeaturesInfo = Map.empty[Int, Int] - val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 2, - numClasses = 2, categoricalFeaturesInfo = categoricalFeaturesInfo, + val strategy = new OldStrategy( + algo = OldAlgo.Classification, + impurity = Gini, + maxDepth = 2, + numClasses = 2, + categoricalFeaturesInfo = categoricalFeaturesInfo, useNodeIdCache = true) binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy) } @@ -648,7 +762,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { } val importances: Vector = TreeEnsembleModel.featureImportances(trees, 2) val tree2norm = feature0importance + feature1importance - val expected = Vectors.dense((1.0 + feature0importance / tree2norm) / 2.0, + val expected = Vectors.dense( + (1.0 + feature0importance / tree2norm) / 2.0, (feature1importance / tree2norm) / 2.0) assert(importances ~== expected relTol 0.01) } @@ -682,18 +797,45 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { val rdd = sc.parallelize(arr.toImmutableArraySeq) val numClasses = 2 - val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 4, - numClasses = numClasses, maxBins = 32) - - val prunedTree = RandomForest.run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto", - seed = 42, instr = None).head - - val unprunedTree = RandomForest.run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto", - seed = 42, instr = None, prune = false).head + val strategy = new OldStrategy( + algo = OldAlgo.Classification, + impurity = Gini, + maxDepth = 4, + numClasses = numClasses, + maxBins = 32) + + strategy.pruneTree = true + val prunedTree = RandomForest + .run( + rdd, + strategy, + numTrees = 1, + featureSubsetStrategy = "auto", + seed = 42, + instr = None) + .head + + strategy.pruneTree = false + val unprunedTree = RandomForest + .run( + rdd, + strategy, + numTrees = 1, + featureSubsetStrategy = "auto", + seed = 42, + instr = None) + .head + + strategy.pruneTree = true + val defaultBehaviorTree = RandomForest + .run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto", seed = 42, instr = None) + .head assert(prunedTree.numNodes === 5) assert(unprunedTree.numNodes === 7) + assert(defaultBehaviorTree.numNodes == prunedTree.numNodes) + assert(RandomForestSuite.getSumLeafCounters(List(prunedTree.rootNode)) === arr.length) } @@ -712,17 +854,45 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { ) val rdd = sc.parallelize(arr.toImmutableArraySeq) - val strategy = new OldStrategy(algo = OldAlgo.Regression, impurity = Variance, maxDepth = 4, - numClasses = 0, maxBins = 32) - - val prunedTree = RandomForest.run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto", - seed = 42, instr = None).head - - val unprunedTree = RandomForest.run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto", - seed = 42, instr = None, prune = false).head + val strategy = new OldStrategy( + algo = OldAlgo.Regression, + impurity = Variance, + maxDepth = 4, + numClasses = 0, + maxBins = 32) + + strategy.pruneTree = true + val prunedTree = RandomForest + .run( + rdd, + strategy, + numTrees = 1, + featureSubsetStrategy = "auto", + seed = 42, + instr = None) + .head + + strategy.pruneTree = false + val unprunedTree = RandomForest + .run( + rdd, + strategy, + numTrees = 1, + featureSubsetStrategy = "auto", + seed = 42, + instr = None) + .head + + strategy.pruneTree = true + val defaultBehaviorTree = RandomForest + .run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto", seed = 42, instr = None) + .head assert(prunedTree.numNodes === 3) assert(unprunedTree.numNodes === 5) + + assert(defaultBehaviorTree.numNodes == prunedTree.numNodes) + assert(RandomForestSuite.getSumLeafCounters(List(prunedTree.rootNode)) === arr.length) } @@ -739,13 +909,15 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { val unitWeightTrees = RandomForest.run(rddWithUnitWeights, strategy, 3, "all", 42L, None) val smallWeightTrees = RandomForest.run(rddWithSmallWeights, strategy, 3, "all", 42L, None) - unitWeightTrees.zip(smallWeightTrees).foreach { case (unitTree, smallWeightTree) => - TreeTests.checkEqual(unitTree, smallWeightTree) + unitWeightTrees.zip(smallWeightTrees).foreach { + case (unitTree, smallWeightTree) => + TreeTests.checkEqual(unitTree, smallWeightTree) } val bigWeightTrees = RandomForest.run(rddWithBigWeights, strategy, 3, "all", 42L, None) - unitWeightTrees.zip(bigWeightTrees).foreach { case (unitTree, bigWeightTree) => - TreeTests.checkEqual(unitTree, bigWeightTree) + unitWeightTrees.zip(bigWeightTrees).foreach { + case (unitTree, bigWeightTree) => + TreeTests.checkEqual(unitTree, bigWeightTree) } } @@ -778,6 +950,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { } private object RandomForestSuite { + def mapToVec(map: Map[Int, Double]): Vector = { val size = (map.keys.toSeq :+ 0).max + 1 val (indices, values) = map.toSeq.sortBy(_._1).unzip @@ -788,12 +961,12 @@ private object RandomForestSuite { private def getSumLeafCounters(nodes: List[Node], acc: Long = 0): Long = { if (nodes.isEmpty) { acc - } - else { + } else { nodes.head match { case i: InternalNode => getSumLeafCounters(i.leftChild :: i.rightChild :: nodes.tail, acc) case l: LeafNode => getSumLeafCounters(nodes.tail, acc + l.impurityStats.rawCount) } } } + } diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index f69ecf115f5ab..8f0646e2b24d0 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -1678,6 +1678,7 @@ def __init__(self, *args: Any): maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, + pruneTree=True, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, @@ -1789,6 +1790,7 @@ def __init__( maxBins: int = 32, minInstancesPerNode: int = 1, minInfoGain: float = 0.0, + pruneTree: bool = True, maxMemoryInMB: int = 256, cacheNodeIds: bool = False, checkpointInterval: int = 10, @@ -1801,7 +1803,7 @@ def __init__( """ __init__(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \ probabilityCol="probability", rawPredictionCol="rawPrediction", \ - maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ + maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, pruneTree=True, \ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", \ seed=None, weightCol=None, leafCol="", minWeightFractionPerNode=0.0) """ @@ -1826,6 +1828,7 @@ def setParams( maxBins: int = 32, minInstancesPerNode: int = 1, minInfoGain: float = 0.0, + pruneTree: bool = True, maxMemoryInMB: int = 256, cacheNodeIds: bool = False, checkpointInterval: int = 10, @@ -1838,7 +1841,7 @@ def setParams( """ setParams(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \ probabilityCol="probability", rawPredictionCol="rawPrediction", \ - maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ + maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, pruneTree=True, \ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", \ seed=None, weightCol=None, leafCol="", minWeightFractionPerNode=0.0) Sets params for the DecisionTreeClassifier. @@ -1861,6 +1864,12 @@ def setMaxBins(self, value: int) -> "DecisionTreeClassifier": """ return self._set(maxBins=value) + def setPruneTree(self, value: bool) -> "DecisionTreeClassifier": + """ + Sets the value of :py:attr:`pruneTree`. + """ + return self._set(pruneTree=value) + def setMinInstancesPerNode(self, value: int) -> "DecisionTreeClassifier": """ Sets the value of :py:attr:`minInstancesPerNode`. @@ -1972,6 +1981,7 @@ def __init__(self, *args: Any): maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, + pruneTree=True, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, @@ -2081,6 +2091,7 @@ def __init__( maxBins: int = 32, minInstancesPerNode: int = 1, minInfoGain: float = 0.0, + pruneTree: bool = True, maxMemoryInMB: int = 256, cacheNodeIds: bool = False, checkpointInterval: int = 10, @@ -2097,7 +2108,7 @@ def __init__( """ __init__(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \ probabilityCol="probability", rawPredictionCol="rawPrediction", \ - maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ + maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, pruneTree=True, \ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", \ numTrees=20, featureSubsetStrategy="auto", seed=None, subsamplingRate=1.0, \ leafCol="", minWeightFractionPerNode=0.0, weightCol=None, bootstrap=True) @@ -2123,6 +2134,7 @@ def setParams( maxBins: int = 32, minInstancesPerNode: int = 1, minInfoGain: float = 0.0, + pruneTree: bool = True, maxMemoryInMB: int = 256, cacheNodeIds: bool = False, checkpointInterval: int = 10, @@ -2139,7 +2151,7 @@ def setParams( """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ probabilityCol="probability", rawPredictionCol="rawPrediction", \ - maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ + maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, pruneTree=True, \ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None, \ impurity="gini", numTrees=20, featureSubsetStrategy="auto", subsamplingRate=1.0, \ leafCol="", minWeightFractionPerNode=0.0, weightCol=None, bootstrap=True) @@ -2163,6 +2175,12 @@ def setMaxBins(self, value: int) -> "RandomForestClassifier": """ return self._set(maxBins=value) + def setPruneTree(self, value: bool) -> "RandomForestClassifier": + """ + Sets the value of :py:attr:`pruneTree`. + """ + return self._set(pruneTree=value) + def setMinInstancesPerNode(self, value: int) -> "RandomForestClassifier": """ Sets the value of :py:attr:`minInstancesPerNode`. diff --git a/python/pyspark/ml/tree.py b/python/pyspark/ml/tree.py index 63f58272aeefb..41b8bdc600c56 100644 --- a/python/pyspark/ml/tree.py +++ b/python/pyspark/ml/tree.py @@ -415,6 +415,13 @@ class _TreeClassifierParams(Params): typeConverter=TypeConverters.toString, ) + pruneTree = Param(Params._dummy(), "pruneTree", "" + + "If true, the trained tree will undergo a pruning process after training, in which nodes" + + " with the same class predictions are merged. The resulting tree will be smaller and have" + + " faster predictions, but class probabilities will be lost." + + " If false, no pruning is applied after training, and class probabilities are preserved.", + typeConverter=TypeConverters.toBoolean) + def __init__(self) -> None: super().__init__() @@ -424,6 +431,12 @@ def getImpurity(self) -> str: Gets the value of impurity or its default value. """ return self.getOrDefault(self.impurity) + @since("4.3.0") + def getPruneTree(self) -> bool: + """ + Gets the value of pruneTree or its default value. + """ + return self.getOrDefault(self.pruneTree) class _TreeRegressorParams(_HasVarianceImpurity):