From a058cd8107666cb8bc5dd090fd1c52aadd896304 Mon Sep 17 00:00:00 2001 From: sethah Date: Fri, 7 Aug 2015 17:04:13 -0700 Subject: [PATCH 1/4] Adding stratified sampling to cross validation and train validation split in ml/tuning --- .../apache/spark/rdd/PairRDDFunctions.scala | 76 ++++++++++++++++++- .../util/random/StratifiedSamplingUtils.scala | 27 +++++++ .../spark/ml/tuning/CrossValidator.scala | 31 +++++++- .../ml/tuning/TrainValidationSplit.scala | 32 +++++++- .../org/apache/spark/mllib/util/MLUtils.scala | 25 +++++- 5 files changed, 184 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 104e0cb37155f..24fc5cd8e8c98 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -288,7 +288,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) def sampleByKeyExact( withReplacement: Boolean, fractions: Map[K, Double], - seed: Long = Utils.random.nextLong): RDD[(K, V)] = self.withScope { + seed: Long = Utils.random.nextLong, + complement: Boolean = false): RDD[(K, V)] = self.withScope { require(fractions.values.forall(v => v >= 0.0), "Negative sampling rates.") @@ -301,9 +302,76 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) } /** - * Merge the values for each key using an associative and commutative reduce function. This will - * also perform the merging locally on each mapper before sending results to a reducer, similarly - * to a "combiner" in MapReduce. + * ::Experimental:: + * Return random, non-overlapping splits of this RDD sampled by key (via stratified sampling) + * with each split containing exactly math.ceil(numItems * samplingRate) for each stratum. + * + * This method differs from [[sampleByKey]] and [[sampleByKeyExact]] in that it provides random + * splits (and their complements) instead of just a subsample of the data. This requires segmenting + * random keys into ranges with upper and lower bounds instead of segmenting the keys into a high/low + * bisection of the entire dataset. + * + * @param weights array of maps of specific keys to sampling rates for each split, normalized by key to sum to 1 + * @param exact boolean specifying whether to use exact subsampling + * @param seed seed for the random number generator + * @return Array of tuples containing the subsample and complement RDDs for each split + */ + @Experimental + def randomSplitByKey( + weights: Array[Map[K, Double]], + exact: Boolean = false, + seed: Long = Utils.random.nextLong): Array[(RDD[(K, V)], RDD[(K, V)])] = self.withScope { + + require(weights.flatMap(_.values).forall(v => v >= 0.0), "Negative sampling rates.") + + // normalize and cumulative sum + val baseFold = weights(0).map(x => (x._1, 0.0)) + val cumWeightsByKey = weights.scanLeft(baseFold){ case (accMap, iterMap) => + accMap.map { case (k, v) => (k, v + iterMap(k)) } + }.drop(1) + + val weightSumsByKey = cumWeightsByKey.last + val normalizedCumWeightsByKey = cumWeightsByKey.dropRight(1).map(_.map { case (key, threshold) => + (key, threshold / weightSumsByKey(key)) + }) + + // compute exact thresholds for each stratum if required + val splitArray = if (exact) { + normalizedCumWeightsByKey.map { fractions => + val finalResult = StratifiedSamplingUtils.getAcceptanceResults(self, false, fractions, None, seed) + StratifiedSamplingUtils.computeThresholdByKey(finalResult, fractions) + } + } else normalizedCumWeightsByKey + + // get the exact threshold for each segment + val totalSplitArray = weights(0).map(x => (x._1, 0.0)) +: splitArray :+ weights(0).map(x => (x._1, 1.0)) + totalSplitArray.sliding(2).map { x => + (randomSampleByKeyWithRange(x(0), x(1), seed), randomSampleByKeyWithRange(x(0), x(1), seed, complement = true)) + }.toArray + } + + /** + * Internal method exposed for Stratified Random Splits in DataFrames. Samples an RDD given probability + * bounds for each stratum. + * + * @param lb map of lower bound for each key to use for the Bernoulli cell sampler + * @param ub map of upper bound for each key to use for the Bernoulli cell sampler + * @param seed the seed for the Random number generator + * @param complement boolean specifying whether to return subsample or its complement + * @return A random, stratified sub-sample of the RDD without replacement. + */ + private[spark] def randomSampleByKeyWithRange(lb: Map[K, Double], + ub: Map[K, Double], + seed: Long, + complement: Boolean = false): RDD[(K, V)] = { + val samplingFunc = StratifiedSamplingUtils.getBernoulliCellSamplingFunction(self, lb, ub, seed, complement) + self.mapPartitionsWithIndex(samplingFunc, preservesPartitioning = true) + } + + /** + * Merge the values for each key using an associative reduce function. This will also perform + * the merging locally on each mapper before sending results to a reducer, similarly to a + * "combiner" in MapReduce. */ def reduceByKey(partitioner: Partitioner, func: (V, V) => V): RDD[(K, V)] = self.withScope { combineByKeyWithClassTag[V]((v: V) => v, func, func, partitioner) diff --git a/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala b/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala index 67822749112c6..6ab6a4cd2075b 100644 --- a/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala @@ -215,6 +215,33 @@ private[spark] object StratifiedSamplingUtils extends Logging { } } + /** + * WIP sample with range + */ + def getBernoulliCellSamplingFunction[K, V](rdd: RDD[(K, V)], + lb: Map[K, Double], + ub: Map[K, Double], + seed: Long, + complement: Boolean = false): (Int, Iterator[(K, V)]) => Iterator[(K, V)] = { + (idx: Int, iter: Iterator[(K, V)]) => { + val rng = new RandomDataGenerator() + rng.reSeed(seed + idx) + // Must use the same invoke pattern on the rng as in getSeqOp for without replacement + // in order to generate the same sequence of random numbers when creating the sample + if (complement) { + iter.filter { t => + val x = rng.nextUniform() + (x < lb(t._1)) || (x >= ub(t._1)) + } + } else { + iter.filter { t => + val x = rng.nextUniform() + (x >= lb(t._1)) && (x < ub(t._1)) + } + } + } + } + /** * Return the per partition sampling function used for sampling with replacement. * diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 520557849b9e2..6fe27a645529c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -48,10 +48,22 @@ private[ml] trait CrossValidatorParams extends ValidatorParams { val numFolds: IntParam = new IntParam(this, "numFolds", "number of folds for cross validation (>= 2)", ParamValidators.gtEq(2)) + /** + * Param for stratified sampling column name + * Default: "None" + * @group param + */ + val stratifiedCol: Param[String] = new Param[String](this, "stratifiedCol", "stratified column name") + + /** @group getParam */ + def getStratifiedCol: String = $(stratifiedCol) + /** @group getParam */ def getNumFolds: Int = $(numFolds) setDefault(numFolds -> 3) + setDefault(stratifiedCol -> "None") + } /** @@ -87,6 +99,10 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) @Since("2.0.0") def setSeed(value: Long): this.type = set(seed, value) + /** @group setParam */ + @Since("2.0.0") + def setStratifiedCol(value: String): this.type = set(stratifiedCol, value) + @Since("2.0.0") override def fit(dataset: Dataset[_]): CrossValidatorModel = { val schema = dataset.schema @@ -97,7 +113,20 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) val epm = $(estimatorParamMaps) val numModels = epm.length val metrics = new Array[Double](epm.length) - val splits = MLUtils.kFold(dataset.toDF.rdd, $(numFolds), $(seed)) + + val splits = if (dataset.columns.contains($(stratifiedCol))) { + // stratified kFold + val stratifiedColIndex = dataset.columns.indexOf($(stratifiedCol)) + val splitsWithKeys = + MLUtils.kFoldStrat(dataset.toDF.rdd.map(row => (row(stratifiedColIndex), row)), $(numFolds), 0) + splitsWithKeys.map { case (training, validation) => (training.values, validation.values)} + } else { + if (isSet(stratifiedCol)) + logWarning(s"Stratified column does not exist. Performing regular k-fold subsampling.") + // regular kFold + MLUtils.kFold(dataset.toDF.rdd, $(numFolds), $(seed)) + } + splits.zipWithIndex.foreach { case ((training, validation), splitIndex) => val trainingDataset = sparkSession.createDataFrame(training, schema).cache() val validationDataset = sparkSession.createDataFrame(validation, schema).cache() diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala index 0fdba1cb8814a..338d59d179337 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala @@ -27,6 +27,8 @@ import org.json4s.DefaultFormats import org.apache.spark.annotation.Since import org.apache.spark.internal.Logging +import org.apache.spark.ml.evaluation.Evaluator +import org.apache.spark.ml.param._ import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.evaluation.Evaluator import org.apache.spark.ml.param.{DoubleParam, ParamMap, ParamValidators} @@ -47,10 +49,21 @@ private[ml] trait TrainValidationSplitParams extends ValidatorParams { val trainRatio: DoubleParam = new DoubleParam(this, "trainRatio", "ratio between training set and validation set (>= 0 && <= 1)", ParamValidators.inRange(0, 1)) + /** + * Param for stratified sampling column name + * Default: "None" + * @group param + */ + val stratifiedCol: Param[String] = new Param[String](this, "stratifiedCol", "stratified column name") + /** @group getParam */ def getTrainRatio: Double = $(trainRatio) + /** @group getParam */ + def getStratifiedCol: String = $(stratifiedCol) + setDefault(trainRatio -> 0.75) + setDefault(stratifiedCol -> "None") } /** @@ -87,6 +100,8 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St @Since("2.0.0") def setSeed(value: Long): this.type = set(seed, value) + def setStratifiedCol(value: String): this.type = set(stratifiedCol, value) + @Since("2.0.0") override def fit(dataset: Dataset[_]): TrainValidationSplitModel = { val schema = dataset.schema @@ -98,7 +113,22 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St val metrics = new Array[Double](epm.length) val Array(trainingDataset, validationDataset) = - dataset.randomSplit(Array($(trainRatio), 1 - $(trainRatio)), $(seed)) + if (dataset.columns.contains($(stratifiedCol))) { + val stratifiedColIndex = dataset.columns.indexOf($(stratifiedCol)) + val keyedRDD = dataset.toDF.rdd.map(row => (row(stratifiedColIndex), row)) + val keys = keyedRDD.keys.distinct.collect() + val weights: Array[scala.collection.Map[Any, Double]] = + Array(keys.map(k => (k, $(trainRatio))).toMap, + keys.map(k => (k, 1 - $(trainRatio))).toMap) + val splitsWithKeys = keyedRDD.randomSplitByKey(weights, exact = true, 0) + val Array(training, validation) = + splitsWithKeys.map { case (subsample, complement) => subsample.values } + (sqlCtx.createDataFrame(training, schema).cache(), + sqlCtx.createDataFrame(validation, schema).cache()) + } else { + dataset.randomSplit(Array($(trainRatio), 1 - $(trainRatio)), $(seed)) + } + trainingDataset.cache() validationDataset.cache() diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala index e96c2bc6edfc3..43850d0b6bf04 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -26,13 +26,22 @@ import org.apache.spark.internal.Logging import org.apache.spark.ml.linalg.{MatrixUDT => MLMatrixUDT, VectorUDT => MLVectorUDT} import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.linalg.BLAS.dot + import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD} import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions.{col, udf} +import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD, PairRDDFunctions} import org.apache.spark.storage.StorageLevel import org.apache.spark.util.random.BernoulliCellSampler +// My changes +import scala.util.Random +import org.apache.spark.util.random.StratifiedSamplingUtils +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer +import org.apache.spark.util.random.XORShiftRandom + + /** * Helper methods to load, save and pre-process data used in ML Lib. */ @@ -227,6 +236,20 @@ object MLUtils extends Logging { }.toArray } + /** + * Seth Code + */ + @Experimental + def kFoldStrat[K: ClassTag, V: ClassTag](rdd: RDD[(K, V)], + numFolds: Int, + seed: Int): Array[(RDD[(K, V)], RDD[(K, V)])] = { + val keys: Array[K] = rdd.keys.collect().distinct + val weights: Array[scala.collection.Map[K, Double]] = (1 to numFolds).map { + n => keys.map(k => (k, 1 / numFolds.toDouble)).toMap + }.toArray + rdd.randomSplitByKey(weights, exact = true, seed) + } + /** * Returns a new vector with `1.0` (bias) appended to the input vector. */ From 5f244d1cb5bd747e7383a85b54394a2fa9efa32e Mon Sep 17 00:00:00 2001 From: sethah Date: Mon, 10 Aug 2015 15:26:38 -0700 Subject: [PATCH 2/4] Adding some tests and style fixes --- .../apache/spark/rdd/PairRDDFunctions.scala | 57 +++--- .../util/random/StratifiedSamplingUtils.scala | 16 +- .../spark/rdd/PairRDDFunctionsSuite.scala | 183 ++++++++++++++++++ .../spark/ml/tuning/CrossValidator.scala | 12 +- .../ml/tuning/TrainValidationSplit.scala | 12 +- .../org/apache/spark/mllib/util/MLUtils.scala | 21 +- .../spark/ml/tuning/CrossValidatorSuite.scala | 3 + .../ml/tuning/TrainValidationSplitSuite.scala | 3 + .../spark/mllib/util/MLUtilsSuite.scala | 31 +++ 9 files changed, 284 insertions(+), 54 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 24fc5cd8e8c98..b3c75c1660e04 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -288,8 +288,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) def sampleByKeyExact( withReplacement: Boolean, fractions: Map[K, Double], - seed: Long = Utils.random.nextLong, - complement: Boolean = false): RDD[(K, V)] = self.withScope { + seed: Long = Utils.random.nextLong): RDD[(K, V)] = self.withScope { require(fractions.values.forall(v => v >= 0.0), "Negative sampling rates.") @@ -307,14 +306,14 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * with each split containing exactly math.ceil(numItems * samplingRate) for each stratum. * * This method differs from [[sampleByKey]] and [[sampleByKeyExact]] in that it provides random - * splits (and their complements) instead of just a subsample of the data. This requires segmenting - * random keys into ranges with upper and lower bounds instead of segmenting the keys into a high/low - * bisection of the entire dataset. + * splits (and their complements) instead of just a subsample of the data. This requires + * segmenting random keys into ranges with upper and lower bounds instead of segmenting the keys + * into a high/low bisection of the entire dataset. * - * @param weights array of maps of specific keys to sampling rates for each split, normalized by key to sum to 1 + * @param weights array of maps of (key -> samplingRate) pairs for each split, normed by key * @param exact boolean specifying whether to use exact subsampling * @param seed seed for the random number generator - * @return Array of tuples containing the subsample and complement RDDs for each split + * @return array of tuples containing the subsample and complement RDDs for each split */ @Experimental def randomSplitByKey( @@ -323,36 +322,47 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) seed: Long = Utils.random.nextLong): Array[(RDD[(K, V)], RDD[(K, V)])] = self.withScope { require(weights.flatMap(_.values).forall(v => v >= 0.0), "Negative sampling rates.") + if (weights.length > 1) { + require(weights.map(m => m.keys.toSet).sliding(2).forall(t => t(0) == t(1)), + "Inconsistent keys between splits.") + } + + // maps of sampling threshold boundaries at 0.0 and 1.0 + val leftBoundary = weights(0).map(x => (x._1, 0.0)) + val rightBoundary = weights(0).map(x => (x._1, 1.0)) // normalize and cumulative sum - val baseFold = weights(0).map(x => (x._1, 0.0)) - val cumWeightsByKey = weights.scanLeft(baseFold){ case (accMap, iterMap) => + val cumWeightsByKey = weights.scanLeft(leftBoundary) { case (accMap, iterMap) => accMap.map { case (k, v) => (k, v + iterMap(k)) } }.drop(1) val weightSumsByKey = cumWeightsByKey.last - val normalizedCumWeightsByKey = cumWeightsByKey.dropRight(1).map(_.map { case (key, threshold) => - (key, threshold / weightSumsByKey(key)) + val normedCumWeightsByKey = cumWeightsByKey.dropRight(1).map(_.map { case (key, threshold) => + val keyWeightSum = weightSumsByKey(key) + val norm = if (keyWeightSum > 0.0) keyWeightSum else 1.0 + (key, threshold / norm) }) // compute exact thresholds for each stratum if required - val splitArray = if (exact) { - normalizedCumWeightsByKey.map { fractions => - val finalResult = StratifiedSamplingUtils.getAcceptanceResults(self, false, fractions, None, seed) - StratifiedSamplingUtils.computeThresholdByKey(finalResult, fractions) + val splitPoints = if (exact) { + normedCumWeightsByKey.map { w => + val finalResult = StratifiedSamplingUtils.getAcceptanceResults(self, false, w, None, seed) + StratifiedSamplingUtils.computeThresholdByKey(finalResult, w) } - } else normalizedCumWeightsByKey + } else { + normedCumWeightsByKey + } - // get the exact threshold for each segment - val totalSplitArray = weights(0).map(x => (x._1, 0.0)) +: splitArray :+ weights(0).map(x => (x._1, 1.0)) - totalSplitArray.sliding(2).map { x => - (randomSampleByKeyWithRange(x(0), x(1), seed), randomSampleByKeyWithRange(x(0), x(1), seed, complement = true)) + val splitsPointsAndBounds = leftBoundary +: splitPoints :+ rightBoundary + splitsPointsAndBounds.sliding(2).map { x => + (randomSampleByKeyWithRange(x(0), x(1), seed), + randomSampleByKeyWithRange(x(0), x(1), seed, complement = true)) }.toArray } /** - * Internal method exposed for Stratified Random Splits in DataFrames. Samples an RDD given probability - * bounds for each stratum. + * Internal method exposed for Stratified Random Splits in DataFrames. Samples an RDD given + * probability bounds for each stratum. * * @param lb map of lower bound for each key to use for the Bernoulli cell sampler * @param ub map of upper bound for each key to use for the Bernoulli cell sampler @@ -364,7 +374,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) ub: Map[K, Double], seed: Long, complement: Boolean = false): RDD[(K, V)] = { - val samplingFunc = StratifiedSamplingUtils.getBernoulliCellSamplingFunction(self, lb, ub, seed, complement) + val samplingFunc = StratifiedSamplingUtils.getBernoulliCellSamplingFunction(self, + lb, ub, seed, complement) self.mapPartitionsWithIndex(samplingFunc, preservesPartitioning = true) } diff --git a/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala b/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala index 6ab6a4cd2075b..dfba6b2b1481d 100644 --- a/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala @@ -216,7 +216,10 @@ private[spark] object StratifiedSamplingUtils extends Logging { } /** - * WIP sample with range + * Return the per partition sampling function used for partitioning a dataset without + * replacement. + * + * The sampling function has a unique seed per partition. */ def getBernoulliCellSamplingFunction[K, V](rdd: RDD[(K, V)], lb: Map[K, Double], @@ -226,17 +229,16 @@ private[spark] object StratifiedSamplingUtils extends Logging { (idx: Int, iter: Iterator[(K, V)]) => { val rng = new RandomDataGenerator() rng.reSeed(seed + idx) - // Must use the same invoke pattern on the rng as in getSeqOp for without replacement - // in order to generate the same sequence of random numbers when creating the sample + if (complement) { - iter.filter { t => + iter.filter { case(k, _) => val x = rng.nextUniform() - (x < lb(t._1)) || (x >= ub(t._1)) + (x < lb(k)) || (x >= ub(k)) } } else { - iter.filter { t => + iter.filter { case(k, _) => val x = rng.nextUniform() - (x >= lb(t._1)) && (x < ub(t._1)) + (x >= lb(k)) && (x < ub(k)) } } } diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala index b0d69de6e2ef4..4de7168b13adc 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala @@ -168,6 +168,118 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { } } + test("randomSplitByKey exact") { + val defaultSeed = 1L + + // vary RDD size + for (n <- List(100, 1000, 10000)) { + val data = sc.parallelize(1 to n, 2) + val fractionPositive = 0.3 + val stratifiedData = data.keyBy(StratifiedAuxiliary.stratifier(fractionPositive)) + val keys = stratifiedData.keys.distinct().collect() + val splitWeights = Array(0.3, 0.2, 0.5) + val weights: Array[scala.collection.Map[String, Double]] = + splitWeights.map(w => keys.map(k => (k, w)).toMap) + StratifiedAuxiliary.testSplits(stratifiedData, weights, defaultSeed, n, true) + } + + // vary fractionPositive + for (fractionPositive <- List(0.1, 0.3, 0.5, 0.7, 0.9)) { + val n = 100 + val data = sc.parallelize(1 to n, 2) + val stratifiedData = data.keyBy(StratifiedAuxiliary.stratifier(fractionPositive)) + val keys = stratifiedData.keys.distinct().collect() + val splitWeights = Array(0.3, 0.2, 0.5) + val weights: Array[scala.collection.Map[String, Double]] = + splitWeights.map(w => keys.map(k => (k, w)).toMap) + StratifiedAuxiliary.testSplits(stratifiedData, weights, defaultSeed, n, true) + } + + // use same data for remaining tests + val n = 100 + val fractionPositive = 0.3 + val data = sc.parallelize(1 to n, 2) + val stratifiedData = data.keyBy(StratifiedAuxiliary.stratifier(fractionPositive)) + val keys = stratifiedData.keys.distinct().collect() + + // use different weights for each key in the split + val unevenWeights: Array[scala.collection.Map[String, Double]] = + Array(Map("0" -> 0.2, "1" -> 0.3), Map("0" -> 0.1, "1" -> 0.4), Map("0" -> 0.7, "1" -> 0.3)) + StratifiedAuxiliary.testSplits(stratifiedData, unevenWeights, defaultSeed, n, true) + + // vary the seed + val splitWeights = Array(0.3, 0.2, 0.5) + val weights: Array[scala.collection.Map[String, Double]] = + splitWeights.map(w => keys.map(k => (k, w)).toMap) + for (seed <- defaultSeed to defaultSeed + 3L) { + StratifiedAuxiliary.testSplits(stratifiedData, weights, seed, n, true) + } + + // vary the number of splits + for (numSplits <- 1 to 3) { + val splitWeights = (1 to numSplits).map(n => 1.toDouble).toArray // check normalization too + val weights: Array[scala.collection.Map[String, Double]] = + splitWeights.map(w => keys.map(k => (k, w)).toMap) + StratifiedAuxiliary.testSplits(stratifiedData, weights, defaultSeed, n, true) + } + } + + test("randomSplitByKey") { + val defaultSeed = 1L + + // vary RDD size + for (n <- List(100, 1000, 10000)) { + val data = sc.parallelize(1 to n, 2) + val fractionPositive = 0.3 + val stratifiedData = data.keyBy(StratifiedAuxiliary.stratifier(fractionPositive)) + val keys = stratifiedData.keys.distinct().collect() + val splitWeights = Array(0.3, 0.2, 0.5) + val weights: Array[scala.collection.Map[String, Double]] = + splitWeights.map(w => keys.map(k => (k, w)).toMap) + StratifiedAuxiliary.testSplits(stratifiedData, weights, defaultSeed, n, false) + } + + // vary fractionPositive + for (fractionPositive <- List(0.1, 0.3, 0.5, 0.7, 0.9)) { + val n = 100 + val data = sc.parallelize(1 to n, 2) + val stratifiedData = data.keyBy(StratifiedAuxiliary.stratifier(fractionPositive)) + val keys = stratifiedData.keys.distinct().collect() + val splitWeights = Array(0.3, 0.2, 0.5) + val weights: Array[scala.collection.Map[String, Double]] = + splitWeights.map(w => keys.map(k => (k, w)).toMap) + StratifiedAuxiliary.testSplits(stratifiedData, weights, defaultSeed, n, false) + } + + // use same data for remaining tests + val n = 100 + val fractionPositive = 0.3 + val data = sc.parallelize(1 to n, 2) + val stratifiedData = data.keyBy(StratifiedAuxiliary.stratifier(fractionPositive)) + val keys = stratifiedData.keys.distinct().collect() + + // use different weights for each key in the split + val unevenWeights: Array[scala.collection.Map[String, Double]] = + Array(Map("0" -> 0.2, "1" -> 0.3), Map("0" -> 0.1, "1" -> 0.4), Map("0" -> 0.7, "1" -> 0.3)) + StratifiedAuxiliary.testSplits(stratifiedData, unevenWeights, defaultSeed, n, false) + + // vary the seed + val splitWeights = Array(0.3, 0.2, 0.5) + val weights: Array[scala.collection.Map[String, Double]] = + splitWeights.map(w => keys.map(k => (k, w)).toMap) + for (seed <- defaultSeed to defaultSeed + 5L) { + StratifiedAuxiliary.testSplits(stratifiedData, weights, seed, n, false) + } + + // vary the number of splits + for (numSplits <- 1 to 5) { + val splitWeights = (1 to numSplits).map(n => 1.toDouble).toArray // check normalization too + val weights: Array[scala.collection.Map[String, Double]] = + splitWeights.map(w => keys.map(k => (k, w)).toMap) + StratifiedAuxiliary.testSplits(stratifiedData, weights, defaultSeed, n, false) + } + } + test("reduceByKey") { val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) val sums = pairs.reduceByKey(_ + _).collect() @@ -646,6 +758,19 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { } } + def checkSplitSize(exact: Boolean, + expected: Long, + actual: Long, + p: Double): Boolean = { + if (exact) { + // all splits will not be exact, but must be within 1 of expected size + return math.abs(expected - actual) <= 1 + } + val stdev = math.sqrt(expected * p * (1 - p)) + // Very forgiving margin since we're dealing with very small sample sizes most of the time + math.abs(actual - expected) <= 6 * stdev + } + def testSampleExact(stratifiedData: RDD[(String, Int)], samplingRate: Double, seed: Long, @@ -662,6 +787,64 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { testPoisson(stratifiedData, false, samplingRate, seed, n) } + def testSplits(stratifiedData: RDD[(String, Int)], + weights: Array[scala.collection.Map[String, Double]], + seed: Long, + n: Int, + exact: Boolean): Unit = { + val baseFold = weights(0).map(x => (x._1, 0.0)) + val totalWeightByKey = weights.foldLeft(baseFold) { case (accMap, iterMap) => + accMap.map { case (k, v) => (k, v + iterMap(k)) } + } + val normedWeights = weights.map(m => m.map { case(k, v) => (k, v / totalWeightByKey(k))}) + + val splits = stratifiedData.randomSplitByKey(weights, exact, seed) + val stratCounts = stratifiedData.countByKey() + + + val expectedSampleSizes = normedWeights.map { m => + stratCounts.map { case (key, count) => + (key, math.ceil(count * m(key)).toLong) + }.toMap + } + val expectedComplementSizes = normedWeights.map { m => + stratCounts.map { case (key, count) => + (key, math.ceil(count * (1 - m(key))).toLong) + }.toMap + } + + val samples = splits.map{ case(subsample, complement) => subsample.collect()} + val complements = splits.map{ case(subsample, complement) => complement.collect()} + + // check for the correct sample size for each split by key + (samples.map(_.groupBy(_._1).map(x => (x._1, x._2.length))) zip expectedSampleSizes) + .zipWithIndex.foreach { case ((actual, expected), idx) => + actual.foreach { case (k, v) => + checkSplitSize(exact, expected(k), v, normedWeights(idx)(k)) + } + } + (complements.map(_.groupBy(_._1).map(x => (x._1, x._2.length))) zip expectedComplementSizes) + .zipWithIndex.foreach { case ((actual, expected), idx) => + actual.foreach{ case (k, v) => + checkSplitSize(exact, expected(k), v, normedWeights(idx)(k)) + } + } + + // make sure samples ++ complements equals the original set + (samples zip complements).foreach { case (sample, complement) => + assert((sample ++ complement).sortBy(_._2).toList == stratifiedData.collect().toList) + } + + // make sure the elements are members of the original set + samples.map(sample => sample.map(x => assert(x._2 >= 1 && x._2 <= n))) + + // make sure no duplicates in each sample + samples.map(sample => assert(sample.length == sample.toSet.size)) + + // make sure that union of all samples equals the original set + assert(samples.flatMap(x => x).sortBy(_._2).toList == stratifiedData.collect().toList) + } + // Without replacement validation def testBernoulli(stratifiedData: RDD[(String, Int)], exact: Boolean, diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 6fe27a645529c..efd59da41fc43 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -51,9 +51,11 @@ private[ml] trait CrossValidatorParams extends ValidatorParams { /** * Param for stratified sampling column name * Default: "None" + * * @group param */ - val stratifiedCol: Param[String] = new Param[String](this, "stratifiedCol", "stratified column name") + val stratifiedCol: Param[String] = new Param[String](this, "stratifiedCol", + "stratified column name") /** @group getParam */ def getStratifiedCol: String = $(stratifiedCol) @@ -115,14 +117,14 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) val metrics = new Array[Double](epm.length) val splits = if (dataset.columns.contains($(stratifiedCol))) { - // stratified kFold val stratifiedColIndex = dataset.columns.indexOf($(stratifiedCol)) - val splitsWithKeys = - MLUtils.kFoldStrat(dataset.toDF.rdd.map(row => (row(stratifiedColIndex), row)), $(numFolds), 0) + val pairData = dataset.toDF.rdd.map(row => (row(stratifiedColIndex), row)) + val splitsWithKeys = MLUtils.kFoldStratified(pairData, $(numFolds), 0) splitsWithKeys.map { case (training, validation) => (training.values, validation.values)} } else { - if (isSet(stratifiedCol)) + if (isSet(stratifiedCol)) { logWarning(s"Stratified column does not exist. Performing regular k-fold subsampling.") + } // regular kFold MLUtils.kFold(dataset.toDF.rdd, $(numFolds), $(seed)) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala index 338d59d179337..87f73644b6aa5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala @@ -27,10 +27,9 @@ import org.json4s.DefaultFormats import org.apache.spark.annotation.Since import org.apache.spark.internal.Logging -import org.apache.spark.ml.evaluation.Evaluator -import org.apache.spark.ml.param._ import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.evaluation.Evaluator +import org.apache.spark.ml.param._ import org.apache.spark.ml.param.{DoubleParam, ParamMap, ParamValidators} import org.apache.spark.ml.util._ import org.apache.spark.sql.{DataFrame, Dataset} @@ -54,7 +53,8 @@ private[ml] trait TrainValidationSplitParams extends ValidatorParams { * Default: "None" * @group param */ - val stratifiedCol: Param[String] = new Param[String](this, "stratifiedCol", "stratified column name") + val stratifiedCol: Param[String] = new Param[String](this, "stratifiedCol", + "stratified column name") /** @group getParam */ def getTrainRatio: Double = $(trainRatio) @@ -115,12 +115,12 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St val Array(trainingDataset, validationDataset) = if (dataset.columns.contains($(stratifiedCol))) { val stratifiedColIndex = dataset.columns.indexOf($(stratifiedCol)) - val keyedRDD = dataset.toDF.rdd.map(row => (row(stratifiedColIndex), row)) - val keys = keyedRDD.keys.distinct.collect() + val pairData = dataset.rdd.map(row => (row(stratifiedColIndex), row)) + val keys = pairData.keys.distinct.collect() val weights: Array[scala.collection.Map[Any, Double]] = Array(keys.map(k => (k, $(trainRatio))).toMap, keys.map(k => (k, 1 - $(trainRatio))).toMap) - val splitsWithKeys = keyedRDD.randomSplitByKey(weights, exact = true, 0) + val splitsWithKeys = pairData.randomSplitByKey(weights, exact = true, 0) val Array(training, validation) = splitsWithKeys.map { case (subsample, complement) => subsample.values } (sqlCtx.createDataFrame(training, schema).cache(), diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala index 43850d0b6bf04..315ecf3a4e3d0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -26,7 +26,6 @@ import org.apache.spark.internal.Logging import org.apache.spark.ml.linalg.{MatrixUDT => MLMatrixUDT, VectorUDT => MLVectorUDT} import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.linalg.BLAS.dot - import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions.{col, udf} @@ -34,14 +33,6 @@ import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD, PairRDDFunctions} import org.apache.spark.storage.StorageLevel import org.apache.spark.util.random.BernoulliCellSampler -// My changes -import scala.util.Random -import org.apache.spark.util.random.StratifiedSamplingUtils -import scala.collection.mutable -import scala.collection.mutable.ArrayBuffer -import org.apache.spark.util.random.XORShiftRandom - - /** * Helper methods to load, save and pre-process data used in ML Lib. */ @@ -237,13 +228,17 @@ object MLUtils extends Logging { } /** - * Seth Code + * Return a k element array of pairs of RDDs with the first element of each pair + * containing the training data, a complement of the validation data and the second + * element, the validation data, containing a unique 1/kth of the data. Where k=numFolds. + * The training and validation data are stratified by the key of the rdd, and the key + * ratios in the original data are maintained in each stratum of the train and validation + * data. */ - @Experimental - def kFoldStrat[K: ClassTag, V: ClassTag](rdd: RDD[(K, V)], + def kFoldStratified[K: ClassTag, V: ClassTag](rdd: RDD[(K, V)], numFolds: Int, seed: Int): Array[(RDD[(K, V)], RDD[(K, V)])] = { - val keys: Array[K] = rdd.keys.collect().distinct + val keys = rdd.keys.distinct().collect() val weights: Array[scala.collection.Map[K, Double]] = (1 to numFolds).map { n => keys.map(k => (k, 1 / numFolds.toDouble)).toMap }.toArray diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index 30bd390381e97..d407165a71ec6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -55,6 +55,7 @@ class CrossValidatorSuite .setEstimatorParamMaps(lrParamMaps) .setEvaluator(eval) .setNumFolds(3) + .setStratifiedCol("label") val cvModel = cv.fit(dataset) // copied model must have the same paren. @@ -109,6 +110,8 @@ class CrossValidatorSuite .setEstimator(est) .setEstimatorParamMaps(paramMaps) .setEvaluator(eval) + .setNumFolds(3) + .setStratifiedCol("label") cv.transformSchema(new StructType()) // This should pass. diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala index c1e9c2fc1dc11..31b3ceaf04918 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala @@ -49,6 +49,8 @@ class TrainValidationSplitSuite .setEvaluator(eval) .setTrainRatio(0.5) .setSeed(42L) + .setStratifiedCol("label") + val cvModel = cv.fit(dataset) val parent = cvModel.bestModel.parent.asInstanceOf[LogisticRegression] assert(cv.getTrainRatio === 0.5) @@ -102,6 +104,7 @@ class TrainValidationSplitSuite .setEstimatorParamMaps(paramMaps) .setEvaluator(eval) .setTrainRatio(0.5) + .setStratifiedCol("label") cv.transformSchema(new StructType()) // This should pass. val invalidParamMaps = paramMaps :+ ParamMap(est.inputCol -> "") diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala index 6aa93c9076007..c8993b7aa1193 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala @@ -210,6 +210,37 @@ class MLUtilsSuite extends SparkFunSuite with MLlibTestSparkContext { } } + test("kFoldStratified") { + /** + * Most of the functionality of [[kFoldStratified]] is tested in the PairRDD function + * `randomSplitByKey`. All that needs to be checked here is that the folds are even + * splits for each key. + */ + val defaultSeed = 1 + val n = 100 + val data = sc.parallelize(1 to n, 2) + val fractionPositive = 0.3 + val keys = Array("0", "1") + val stratifiedData = data.map { x => + if (x > n*fractionPositive) ("0", x) else ("1", x) + } + val counts = stratifiedData.countByKey() + for (numFolds <- 1 to 3) { + val folds = kFoldStratified(stratifiedData, numFolds, defaultSeed) + val expectedSize = keys.map(k => (k, counts(k) / numFolds.toDouble)).toMap + for ((sample, complement) <- folds) { + val sampleCounts = sample.countByKey() + val complementCounts = complement.countByKey() + sampleCounts.foreach { case(key, count) => + assert(math.abs(count - expectedSize(key)) <= 1) + } + complementCounts.foreach { case(key, count) => + assert(math.abs(count - (counts(key) - expectedSize(key))) <= 1) + } + } + } + } + test("loadVectors") { val vectors = sc.parallelize(Seq( Vectors.dense(1.0, 2.0), From 67f60027158fae37d3f3973fd22217298097ebd7 Mon Sep 17 00:00:00 2001 From: sethah Date: Fri, 22 Apr 2016 07:51:13 -0700 Subject: [PATCH 3/4] Refactor for efficiency when computing multiple waitlists. --- .../apache/spark/rdd/PairRDDFunctions.scala | 49 +++--- .../util/random/StratifiedSamplingUtils.scala | 155 ++++++++++-------- .../spark/rdd/PairRDDFunctionsSuite.scala | 111 +++++++------ .../spark/ml/tuning/CrossValidator.scala | 22 +-- .../ml/tuning/TrainValidationSplit.scala | 29 +--- .../spark/ml/tuning/ValidatorParams.scala | 11 ++ .../org/apache/spark/mllib/util/MLUtils.scala | 3 +- .../spark/ml/tuning/CrossValidatorSuite.scala | 31 ++++ 8 files changed, 232 insertions(+), 179 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index b3c75c1660e04..bbcbcc614341b 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -327,36 +327,40 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) "Inconsistent keys between splits.") } - // maps of sampling threshold boundaries at 0.0 and 1.0 - val leftBoundary = weights(0).map(x => (x._1, 0.0)) - val rightBoundary = weights(0).map(x => (x._1, 1.0)) + val zeroBoundary = weights.head.map { case (k, v) => (k, 0.0)} // normalize and cumulative sum - val cumWeightsByKey = weights.scanLeft(leftBoundary) { case (accMap, iterMap) => - accMap.map { case (k, v) => (k, v + iterMap(k)) } - }.drop(1) + val sumWeightsByKey = weights.foldLeft(zeroBoundary) { case (acc, splitWeights) => + acc.map { case (k, v) => (k, v + splitWeights(k)) } + } - val weightSumsByKey = cumWeightsByKey.last - val normedCumWeightsByKey = cumWeightsByKey.dropRight(1).map(_.map { case (key, threshold) => - val keyWeightSum = weightSumsByKey(key) - val norm = if (keyWeightSum > 0.0) keyWeightSum else 1.0 - (key, threshold / norm) - }) + val normedCumWeightsByKey = weights.scanLeft(zeroBoundary) { case (acc, splitWeights) => + splitWeights.map { case (key, fraction) => + val keySum = sumWeightsByKey(key) + val norm = if (keySum > 0.0) keySum else 1.0 + key -> (acc(key) + (fraction / norm)) + } + } // compute exact thresholds for each stratum if required - val splitPoints = if (exact) { - normedCumWeightsByKey.map { w => - val finalResult = StratifiedSamplingUtils.getAcceptanceResults(self, false, w, None, seed) - StratifiedSamplingUtils.computeThresholdByKey(finalResult, w) + val splitThresholds = if (exact) { + val left = normedCumWeightsByKey.head + val right = normedCumWeightsByKey.last + val innerSplits = normedCumWeightsByKey.drop(1).dropRight(1) + val finalResults = + StratifiedSamplingUtils.getAcceptanceResults(self, withReplacement = false, innerSplits, + None, seed) + val exactInnerSplits = finalResults.zip(innerSplits).map { case (finalResult, fractions) => + StratifiedSamplingUtils.computeThresholdByKey(finalResult, fractions) } + left +: exactInnerSplits :+ right } else { normedCumWeightsByKey } - val splitsPointsAndBounds = leftBoundary +: splitPoints :+ rightBoundary - splitsPointsAndBounds.sliding(2).map { x => - (randomSampleByKeyWithRange(x(0), x(1), seed), - randomSampleByKeyWithRange(x(0), x(1), seed, complement = true)) + splitThresholds.sliding(2).map { case Array(lb, ub) => + (randomSampleByKeyWithRange(lb, ub, seed), + randomSampleByKeyWithRange(lb, ub, seed, complement = true)) }.toArray } @@ -370,11 +374,12 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * @param complement boolean specifying whether to return subsample or its complement * @return A random, stratified sub-sample of the RDD without replacement. */ - private[spark] def randomSampleByKeyWithRange(lb: Map[K, Double], + private[spark] def randomSampleByKeyWithRange( + lb: Map[K, Double], ub: Map[K, Double], seed: Long, complement: Boolean = false): RDD[(K, V)] = { - val samplingFunc = StratifiedSamplingUtils.getBernoulliCellSamplingFunction(self, + val samplingFunc = StratifiedSamplingUtils.getBernoulliCellSamplingFunction[K, V]( lb, ub, seed, complement) self.mapPartitionsWithIndex(samplingFunc, preservesPartitioning = true) } diff --git a/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala b/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala index dfba6b2b1481d..8cd8e1178281e 100644 --- a/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala @@ -50,7 +50,7 @@ import org.apache.spark.rdd.RDD * http://jmlr.org/proceedings/papers/v28/meng13a.html */ -private[spark] object StratifiedSamplingUtils extends Logging { +object StratifiedSamplingUtils extends Logging { /** * Count the number of items instantly accepted and generate the waitlist for each stratum. @@ -59,96 +59,121 @@ private[spark] object StratifiedSamplingUtils extends Logging { */ def getAcceptanceResults[K, V](rdd: RDD[(K, V)], withReplacement: Boolean, - fractions: Map[K, Double], + fractions: Array[Map[K, Double]], counts: Option[Map[K, Long]], - seed: Long): mutable.Map[K, AcceptanceResult] = { + seed: Long): Array[mutable.Map[K, AcceptanceResult]] = { val combOp = getCombOp[K] val mappedPartitionRDD = rdd.mapPartitionsWithIndex { case (partition, iter) => - val zeroU: mutable.Map[K, AcceptanceResult] = new mutable.HashMap[K, AcceptanceResult]() - val rng = new RandomDataGenerator() - rng.reSeed(seed + partition) - val seqOp = getSeqOp(withReplacement, fractions, rng, counts) + val zeroU: Array[mutable.Map[K, AcceptanceResult]] = Array.fill(fractions.length) { + new mutable.HashMap[K, AcceptanceResult]() + } + val rngs = Array.fill(fractions.length) { + val rng = new RandomDataGenerator() + rng.reSeed(seed + partition) + rng + } + val seqOp = getSeqOp(withReplacement, fractions, rngs, counts) Iterator(iter.aggregate(zeroU)(seqOp, combOp)) } mappedPartitionRDD.reduce(combOp) } + def getAcceptanceResults[K, V]( + rdd: RDD[(K, V)], + withReplacement: Boolean, + fractions: Map[K, Double], + counts: Option[Map[K, Long]], + seed: Long): mutable.Map[K, AcceptanceResult] = { + getAcceptanceResults(rdd, withReplacement, Array(fractions), counts, seed).head + } + /** * Returns the function used by aggregate to collect sampling statistics for each partition. */ def getSeqOp[K, V](withReplacement: Boolean, - fractions: Map[K, Double], - rng: RandomDataGenerator, + fractions: Array[Map[K, Double]], + rngs: Array[RandomDataGenerator], counts: Option[Map[K, Long]]): - (mutable.Map[K, AcceptanceResult], (K, V)) => mutable.Map[K, AcceptanceResult] = { + (Array[mutable.Map[K, AcceptanceResult]], (K, V)) => Array[mutable.Map[K, AcceptanceResult]] = { val delta = 5e-5 - (result: mutable.Map[K, AcceptanceResult], item: (K, V)) => { + (results: Array[mutable.Map[K, AcceptanceResult]], item: (K, V)) => { val key = item._1 - val fraction = fractions(key) - if (!result.contains(key)) { - result += (key -> new AcceptanceResult()) - } - val acceptResult = result(key) - if (withReplacement) { - // compute acceptBound and waitListBound only if they haven't been computed already - // since they don't change from iteration to iteration. - // TODO change this to the streaming version - if (acceptResult.areBoundsEmpty) { - val n = counts.get(key) - val sampleSize = math.ceil(n * fraction).toLong - val lmbd1 = PoissonBounds.getLowerBound(sampleSize) - val lmbd2 = PoissonBounds.getUpperBound(sampleSize) - acceptResult.acceptBound = lmbd1 / n - acceptResult.waitListBound = (lmbd2 - lmbd1) / n - } - val acceptBound = acceptResult.acceptBound - val copiesAccepted = if (acceptBound == 0.0) 0L else rng.nextPoisson(acceptBound) - if (copiesAccepted > 0) { - acceptResult.numAccepted += copiesAccepted + var j = 0 + while (j < fractions.length) { + val fraction = fractions(j)(key) + if (!results(j).contains(key)) { + results(j) += (key -> new AcceptanceResult()) } - val copiesWaitlisted = rng.nextPoisson(acceptResult.waitListBound) - if (copiesWaitlisted > 0) { - acceptResult.waitList ++= ArrayBuffer.fill(copiesWaitlisted)(rng.nextUniform()) - } - } else { - // We use the streaming version of the algorithm for sampling without replacement to avoid - // using an extra pass over the RDD for computing the count. - // Hence, acceptBound and waitListBound change on every iteration. - acceptResult.acceptBound = - BinomialBounds.getLowerBound(delta, acceptResult.numItems, fraction) - acceptResult.waitListBound = - BinomialBounds.getUpperBound(delta, acceptResult.numItems, fraction) + val acceptResult = results(j)(key) - val x = rng.nextUniform() - if (x < acceptResult.acceptBound) { - acceptResult.numAccepted += 1 - } else if (x < acceptResult.waitListBound) { - acceptResult.waitList += x + if (withReplacement) { + // compute acceptBound and waitListBound only if they haven't been computed already + // since they don't change from iteration to iteration. + // TODO change this to the streaming version + if (acceptResult.areBoundsEmpty) { + val n = counts.get(key) + val sampleSize = math.ceil(n * fraction).toLong + val lmbd1 = PoissonBounds.getLowerBound(sampleSize) + val lmbd2 = PoissonBounds.getUpperBound(sampleSize) + acceptResult.acceptBound = lmbd1 / n + acceptResult.waitListBound = (lmbd2 - lmbd1) / n + } + val acceptBound = acceptResult.acceptBound + val copiesAccepted = if (acceptBound == 0.0) 0L else rngs(j).nextPoisson(acceptBound) + if (copiesAccepted > 0) { + acceptResult.numAccepted += copiesAccepted + } + val copiesWaitlisted = rngs(j).nextPoisson(acceptResult.waitListBound) + if (copiesWaitlisted > 0) { + acceptResult.waitList ++= ArrayBuffer.fill(copiesWaitlisted)(rngs(j).nextUniform()) + } + } else { + // We use the streaming version of the algorithm for sampling without replacement to avoid + // using an extra pass over the RDD for computing the count. + // Hence, acceptBound and waitListBound change on every iteration. + acceptResult.acceptBound = + BinomialBounds.getLowerBound(delta, acceptResult.numItems, fraction) + acceptResult.waitListBound = + BinomialBounds.getUpperBound(delta, acceptResult.numItems, fraction) + + val x = rngs(j).nextUniform() + if (x < acceptResult.acceptBound) { + acceptResult.numAccepted += 1 + } else if (x < acceptResult.waitListBound) { + acceptResult.waitList += x + } } + acceptResult.numItems += 1 + + j += 1 } - acceptResult.numItems += 1 - result + results } } /** * Returns the function used combine results returned by seqOp from different partitions. */ - def getCombOp[K]: (mutable.Map[K, AcceptanceResult], mutable.Map[K, AcceptanceResult]) - => mutable.Map[K, AcceptanceResult] = { - (result1: mutable.Map[K, AcceptanceResult], result2: mutable.Map[K, AcceptanceResult]) => { - // take union of both key sets in case one partition doesn't contain all keys - result1.keySet.union(result2.keySet).foreach { key => - // Use result2 to keep the combined result since r1 is usual empty - val entry1 = result1.get(key) - if (result2.contains(key)) { - result2(key).merge(entry1) - } else { - if (entry1.isDefined) { - result2 += (key -> entry1.get) + def getCombOp[K]: (Array[mutable.Map[K, AcceptanceResult]], + Array[mutable.Map[K, AcceptanceResult]]) => Array[mutable.Map[K, AcceptanceResult]] = { + (result1: Array[mutable.Map[K, AcceptanceResult]], + result2: Array[mutable.Map[K, AcceptanceResult]]) => { + var j = 0 + while (j < result1.length) { + // take union of both key sets in case one partition doesn't contain all keys + result1(j).keySet.union(result2(j).keySet).foreach { key => + // Use result2 to keep the combined result since r1 is usual empty + val entry1 = result1(j).get(key) + if (result2(j).contains(key)) { + result2(j)(key).merge(entry1) + } else { + if (entry1.isDefined) { + result2(j) += (key -> entry1.get) + } } } + j += 1 } result2 } @@ -221,7 +246,7 @@ private[spark] object StratifiedSamplingUtils extends Logging { * * The sampling function has a unique seed per partition. */ - def getBernoulliCellSamplingFunction[K, V](rdd: RDD[(K, V)], + def getBernoulliCellSamplingFunction[K, V]( lb: Map[K, Double], ub: Map[K, Double], seed: Long, @@ -333,7 +358,7 @@ private[spark] object StratifiedSamplingUtils extends Logging { * * `[random]` here is necessary since it's in the return type signature of seqOp defined above */ -private[random] class AcceptanceResult(var numItems: Long = 0L, var numAccepted: Long = 0L) +class AcceptanceResult(var numItems: Long = 0L, var numAccepted: Long = 0L) extends Serializable { val waitList = new ArrayBuffer[Double] diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala index 4de7168b13adc..18f817b402ffb 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.rdd import java.io.IOException -import scala.collection.mutable.{ArrayBuffer, HashSet} +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} import scala.util.Random import org.apache.commons.math3.distribution.{BinomialDistribution, PoissonDistribution} @@ -220,6 +220,7 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { val splitWeights = (1 to numSplits).map(n => 1.toDouble).toArray // check normalization too val weights: Array[scala.collection.Map[String, Double]] = splitWeights.map(w => keys.map(k => (k, w)).toMap) +// println(weights.mkString("***")) StratifiedAuxiliary.testSplits(stratifiedData, weights, defaultSeed, n, true) } } @@ -228,7 +229,7 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { val defaultSeed = 1L // vary RDD size - for (n <- List(100, 1000, 10000)) { + for (n <- List(500, 1000, 10000)) { val data = sc.parallelize(1 to n, 2) val fractionPositive = 0.3 val stratifiedData = data.keyBy(StratifiedAuxiliary.stratifier(fractionPositive)) @@ -241,7 +242,7 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { // vary fractionPositive for (fractionPositive <- List(0.1, 0.3, 0.5, 0.7, 0.9)) { - val n = 100 + val n = 500 val data = sc.parallelize(1 to n, 2) val stratifiedData = data.keyBy(StratifiedAuxiliary.stratifier(fractionPositive)) val keys = stratifiedData.keys.distinct().collect() @@ -252,7 +253,7 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { } // use same data for remaining tests - val n = 100 + val n = 500 val fractionPositive = 0.3 val data = sc.parallelize(1 to n, 2) val stratifiedData = data.keyBy(StratifiedAuxiliary.stratifier(fractionPositive)) @@ -273,7 +274,7 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { // vary the number of splits for (numSplits <- 1 to 5) { - val splitWeights = (1 to numSplits).map(n => 1.toDouble).toArray // check normalization too + val splitWeights = Array.fill(numSplits)(1.0) // check normalization too val weights: Array[scala.collection.Map[String, Double]] = splitWeights.map(w => keys.map(k => (k, w)).toMap) StratifiedAuxiliary.testSplits(stratifiedData, weights, defaultSeed, n, false) @@ -325,7 +326,8 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { def error(est: Long, size: Long): Double = math.abs(est - size) / size.toDouble /* Since HyperLogLog unique counting is approximate, and the relative standard deviation is - * only a statistical bound, the tests can fail for large values of relativeSD. We will be using + * only a statistical bound, the tests can fail for large values of relativeSD. We will be + using * relatively tight error bounds to check correctness of functionality rather than checking * whether the approximation conforms with the requested bound. */ @@ -648,7 +650,8 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { assert(FakeOutputCommitter.ran, "OutputCommitter was never called") } - test("failure callbacks should be called before calling writer.close() in saveNewAPIHadoopFile") { + test("failure callbacks should be called before calling writer.close() in saveNewAPIHadoopFile") + { val pairs = sc.parallelize(Array((new Integer(1), new Integer(2))), 1) FakeWriterWithCallback.calledBy = "" @@ -761,14 +764,14 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { def checkSplitSize(exact: Boolean, expected: Long, actual: Long, - p: Double): Boolean = { + p: Double): Unit = { if (exact) { // all splits will not be exact, but must be within 1 of expected size - return math.abs(expected - actual) <= 1 + assert(math.abs(expected - actual) <= 1) } val stdev = math.sqrt(expected * p * (1 - p)) // Very forgiving margin since we're dealing with very small sample sizes most of the time - math.abs(actual - expected) <= 6 * stdev + assert(math.abs(actual - expected) <= 6 * stdev) } def testSampleExact(stratifiedData: RDD[(String, Int)], @@ -787,62 +790,66 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { testPoisson(stratifiedData, false, samplingRate, seed, n) } - def testSplits(stratifiedData: RDD[(String, Int)], - weights: Array[scala.collection.Map[String, Double]], + def testSplits( + stratifiedData: RDD[(String, Int)], + splitWeights: Array[scala.collection.Map[String, Double]], seed: Long, n: Int, exact: Boolean): Unit = { - val baseFold = weights(0).map(x => (x._1, 0.0)) - val totalWeightByKey = weights.foldLeft(baseFold) { case (accMap, iterMap) => - accMap.map { case (k, v) => (k, v + iterMap(k)) } - } - val normedWeights = weights.map(m => m.map { case(k, v) => (k, v / totalWeightByKey(k))}) - - val splits = stratifiedData.randomSplitByKey(weights, exact, seed) - val stratCounts = stratifiedData.countByKey() - - val expectedSampleSizes = normedWeights.map { m => - stratCounts.map { case (key, count) => - (key, math.ceil(count * m(key)).toLong) + def countByKey[K, V](xs: TraversableOnce[(K, V)]): Map[K, Int] = { + xs.foldLeft(HashMap.empty[K, Int].withDefaultValue(0)) { case (acc, (k, v)) => + acc(k) += 1 + acc }.toMap } - val expectedComplementSizes = normedWeights.map { m => - stratCounts.map { case (key, count) => - (key, math.ceil(count * (1 - m(key))).toLong) - }.toMap + + val baseFold = splitWeights.head.mapValues(_ => 0.0) + val totalWeightByKey = splitWeights.foldLeft(baseFold) { case (cumWeights, weights) => + cumWeights.map { case (k, sum) => (k, sum + weights(k)) } + } + val normedWeights = splitWeights.map{weights => + weights.map { case(k, v) => (k, v / totalWeightByKey(k))} } - val samples = splits.map{ case(subsample, complement) => subsample.collect()} - val complements = splits.map{ case(subsample, complement) => complement.collect()} + val splits = stratifiedData.randomSplitByKey(splitWeights, exact, seed) + val data = stratifiedData.collect() + val dataSet = data.toSet + val totalCounts = countByKey(data) + + val sampleSet = scala.collection.mutable.Set[(String, Int)]() + splits.zip(normedWeights).foreach { case ((sample, complement), fractions) => + val takeSample = sample.collect() + val takeComplement = complement.collect() + + // no duplicates in samples + assert(takeSample.length == takeSample.toSet.size) + assert(takeComplement.length == takeComplement.toSet.size) + + val sampleCounts = countByKey(takeSample) + val complementCounts = countByKey(takeComplement) +// println(sampleCounts, complementCounts, fractions, totalCounts) + val observedTotals = totalCounts.map { case (k, v) => + k -> (sampleCounts.getOrElse(k, 0) + complementCounts.getOrElse(k, 0)) + } + assert(observedTotals == totalCounts) - // check for the correct sample size for each split by key - (samples.map(_.groupBy(_._1).map(x => (x._1, x._2.length))) zip expectedSampleSizes) - .zipWithIndex.foreach { case ((actual, expected), idx) => - actual.foreach { case (k, v) => - checkSplitSize(exact, expected(k), v, normedWeights(idx)(k)) - } + sampleCounts.foreach { case (k, count) => + val expectedCount = math.ceil(totalCounts(k) * fractions(k)).toInt + checkSplitSize(exact, expectedCount, count, fractions(k)) } - (complements.map(_.groupBy(_._1).map(x => (x._1, x._2.length))) zip expectedComplementSizes) - .zipWithIndex.foreach { case ((actual, expected), idx) => - actual.foreach{ case (k, v) => - checkSplitSize(exact, expected(k), v, normedWeights(idx)(k)) - } + complementCounts.foreach { case (k, count) => + val expectedCount = math.ceil(totalCounts(k) * (1 - fractions(k))).toInt + checkSplitSize(exact, expectedCount, count, fractions(k)) } - // make sure samples ++ complements equals the original set - (samples zip complements).foreach { case (sample, complement) => - assert((sample ++ complement).sortBy(_._2).toList == stratifiedData.collect().toList) + sampleSet ++= takeSample + val samplesPlusComplements = (takeSample ++ takeComplement).toSet + assert(samplesPlusComplements == dataSet) } - // make sure the elements are members of the original set - samples.map(sample => sample.map(x => assert(x._2 >= 1 && x._2 <= n))) - - // make sure no duplicates in each sample - samples.map(sample => assert(sample.length == sample.toSet.size)) - - // make sure that union of all samples equals the original set - assert(samples.flatMap(x => x).sortBy(_._2).toList == stratifiedData.collect().toList) + // union of all samples equals original data + assert(sampleSet == dataSet) } // Without replacement validation diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index efd59da41fc43..8d0c7071f3378 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -48,23 +48,10 @@ private[ml] trait CrossValidatorParams extends ValidatorParams { val numFolds: IntParam = new IntParam(this, "numFolds", "number of folds for cross validation (>= 2)", ParamValidators.gtEq(2)) - /** - * Param for stratified sampling column name - * Default: "None" - * - * @group param - */ - val stratifiedCol: Param[String] = new Param[String](this, "stratifiedCol", - "stratified column name") - - /** @group getParam */ - def getStratifiedCol: String = $(stratifiedCol) - /** @group getParam */ def getNumFolds: Int = $(numFolds) setDefault(numFolds -> 3) - setDefault(stratifiedCol -> "None") } @@ -104,6 +91,7 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) /** @group setParam */ @Since("2.0.0") def setStratifiedCol(value: String): this.type = set(stratifiedCol, value) + setDefault(stratifiedCol -> "") @Since("2.0.0") override def fit(dataset: Dataset[_]): CrossValidatorModel = { @@ -116,16 +104,12 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) val numModels = epm.length val metrics = new Array[Double](epm.length) - val splits = if (dataset.columns.contains($(stratifiedCol))) { - val stratifiedColIndex = dataset.columns.indexOf($(stratifiedCol)) + val splits = if ($(stratifiedCol).nonEmpty) { + val stratifiedColIndex = schema.fieldNames.indexOf($(stratifiedCol)) val pairData = dataset.toDF.rdd.map(row => (row(stratifiedColIndex), row)) val splitsWithKeys = MLUtils.kFoldStratified(pairData, $(numFolds), 0) splitsWithKeys.map { case (training, validation) => (training.values, validation.values)} } else { - if (isSet(stratifiedCol)) { - logWarning(s"Stratified column does not exist. Performing regular k-fold subsampling.") - } - // regular kFold MLUtils.kFold(dataset.toDF.rdd, $(numFolds), $(seed)) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala index 87f73644b6aa5..986c9ce1781a6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala @@ -48,22 +48,10 @@ private[ml] trait TrainValidationSplitParams extends ValidatorParams { val trainRatio: DoubleParam = new DoubleParam(this, "trainRatio", "ratio between training set and validation set (>= 0 && <= 1)", ParamValidators.inRange(0, 1)) - /** - * Param for stratified sampling column name - * Default: "None" - * @group param - */ - val stratifiedCol: Param[String] = new Param[String](this, "stratifiedCol", - "stratified column name") - /** @group getParam */ def getTrainRatio: Double = $(trainRatio) - /** @group getParam */ - def getStratifiedCol: String = $(stratifiedCol) - setDefault(trainRatio -> 0.75) - setDefault(stratifiedCol -> "None") } /** @@ -101,10 +89,12 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St def setSeed(value: Long): this.type = set(seed, value) def setStratifiedCol(value: String): this.type = set(stratifiedCol, value) + setDefault(stratifiedCol -> "") @Since("2.0.0") override def fit(dataset: Dataset[_]): TrainValidationSplitModel = { val schema = dataset.schema + val sparkSession = dataset.sparkSession transformSchema(schema, logging = true) val est = $(estimator) val eval = $(evaluator) @@ -113,20 +103,19 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St val metrics = new Array[Double](epm.length) val Array(trainingDataset, validationDataset) = - if (dataset.columns.contains($(stratifiedCol))) { - val stratifiedColIndex = dataset.columns.indexOf($(stratifiedCol)) - val pairData = dataset.rdd.map(row => (row(stratifiedColIndex), row)) + if ($(stratifiedCol).nonEmpty) { + val stratifiedColIndex = schema.fieldNames.indexOf($(stratifiedCol)) + val pairData = dataset.toDF.rdd.map(row => (row(stratifiedColIndex), row)) val keys = pairData.keys.distinct.collect() val weights: Array[scala.collection.Map[Any, Double]] = - Array(keys.map(k => (k, $(trainRatio))).toMap, - keys.map(k => (k, 1 - $(trainRatio))).toMap) + Array(keys.map((_, $(trainRatio))).toMap, keys.map((_, 1 - $(trainRatio))).toMap) val splitsWithKeys = pairData.randomSplitByKey(weights, exact = true, 0) val Array(training, validation) = splitsWithKeys.map { case (subsample, complement) => subsample.values } - (sqlCtx.createDataFrame(training, schema).cache(), - sqlCtx.createDataFrame(validation, schema).cache()) + Array(sparkSession.createDataFrame(training, schema), + sparkSession.createDataFrame(validation, schema)) } else { - dataset.randomSplit(Array($(trainRatio), 1 - $(trainRatio)), $(seed)) + dataset.randomSplit(Array($(trainRatio), 1 - $(trainRatio))) } trainingDataset.cache() diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala index 26fd73814d70a..58bb951f0a11d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala @@ -67,6 +67,17 @@ private[ml] trait ValidatorParams extends HasSeed with Params { /** @group getParam */ def getEvaluator: Evaluator = $(evaluator) + /** + * Param for stratified sampling column name + * Default: empty + * @group param + */ + val stratifiedCol: Param[String] = new Param[String](this, "stratifiedCol", + "stratified column name") + + /** @group getParam */ + def getStratifiedCol: String = $(stratifiedCol) + protected def transformSchemaImpl(schema: StructType): StructType = { require($(estimatorParamMaps).nonEmpty, s"Validator requires non-empty estimatorParamMaps") val firstEstimatorParamMap = $(estimatorParamMaps).head diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala index 315ecf3a4e3d0..b020db2f0eb97 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -235,7 +235,8 @@ object MLUtils extends Logging { * ratios in the original data are maintained in each stratum of the train and validation * data. */ - def kFoldStratified[K: ClassTag, V: ClassTag](rdd: RDD[(K, V)], + def kFoldStratified[K: ClassTag, V: ClassTag]( + rdd: RDD[(K, V)], numFolds: Int, seed: Int): Array[(RDD[(K, V)], RDD[(K, V)])] = { val keys = rdd.keys.distinct().collect() diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index d407165a71ec6..5483add02e1de 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -28,6 +28,9 @@ import org.apache.spark.ml.param.{ParamMap, ParamPair} import org.apache.spark.ml.param.shared.HasInputCol import org.apache.spark.ml.regression.LinearRegression import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} import org.apache.spark.sql.Dataset import org.apache.spark.sql.types.StructType @@ -67,6 +70,34 @@ class CrossValidatorSuite assert(cvModel.avgMetrics.length === lrParamMaps.length) } + test("strat") { + val numFolds = 10 + // generate imbalanced data + val data = Seq.tabulate(100) { i => + if (i >= numFolds) { + LabeledPoint(0.0, Vectors.dense(1.0)) + } else { + LabeledPoint(1.0, Vectors.dense(1.0)) + } + } + val df = sqlContext.createDataFrame(data) + val lr = new LogisticRegression + val lrParamMaps = new ParamGridBuilder() + .addGrid(lr.maxIter, Array(0, 10)) + .build() + val eval = new BinaryClassificationEvaluator + val cv = new CrossValidator() + .setEstimator(lr) + .setEstimatorParamMaps(lrParamMaps) + .setEvaluator(eval) + .setNumFolds(numFolds) + .setStratifiedCol("label") + val cvModel = cv.fit(df) + // without stratified sampling, there is a 99.964% that one of the splits has + // no negative examples, so some of the metrics will be < 0.5, bringing down the avg metrics. + assert(cvModel.avgMetrics.forall(_ === 0.5)) + } + test("cross validation with linear regression") { val dataset = spark.createDataFrame( sc.parallelize(LinearDataGenerator.generateLinearInput( From 37be0b5c6a0a4bd6fcc4a0f59c5f575ef6f623ae Mon Sep 17 00:00:00 2001 From: sethah Date: Fri, 22 Jul 2016 10:16:31 -0700 Subject: [PATCH 4/4] Move some logic back into SSUtils --- .../apache/spark/rdd/PairRDDFunctions.scala | 71 ++++----------- .../util/random/StratifiedSamplingUtils.scala | 87 +++++++++++++------ .../spark/rdd/PairRDDFunctionsSuite.scala | 47 +++++----- .../spark/ml/tuning/CrossValidator.scala | 5 +- .../ml/tuning/TrainValidationSplit.scala | 6 +- .../org/apache/spark/mllib/util/MLUtils.scala | 12 ++- .../spark/ml/tuning/CrossValidatorSuite.scala | 63 +++++++------- .../ml/tuning/TrainValidationSplitSuite.scala | 38 +++++++- .../spark/mllib/util/MLUtilsSuite.scala | 4 +- 9 files changed, 191 insertions(+), 142 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index bbcbcc614341b..b1c890f644c52 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -265,7 +265,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val samplingFunc = if (withReplacement) { StratifiedSamplingUtils.getPoissonSamplingFunction(self, fractions, false, seed) } else { - StratifiedSamplingUtils.getBernoulliSamplingFunction(self, fractions, false, seed) + StratifiedSamplingUtils.getBernoulliSamplingFunction(self, fractions, false, seed)._1 } self.mapPartitionsWithIndex(samplingFunc, preservesPartitioning = true) } @@ -295,7 +295,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val samplingFunc = if (withReplacement) { StratifiedSamplingUtils.getPoissonSamplingFunction(self, fractions, true, seed) } else { - StratifiedSamplingUtils.getBernoulliSamplingFunction(self, fractions, true, seed) + StratifiedSamplingUtils.getBernoulliSamplingFunction(self, fractions, true, seed)._1 } self.mapPartitionsWithIndex(samplingFunc, preservesPartitioning = true) } @@ -324,66 +324,29 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) require(weights.flatMap(_.values).forall(v => v >= 0.0), "Negative sampling rates.") if (weights.length > 1) { require(weights.map(m => m.keys.toSet).sliding(2).forall(t => t(0) == t(1)), - "Inconsistent keys between splits.") + "randomSplitByKey(): Each split must specify fractions for each key.") } - - val zeroBoundary = weights.head.map { case (k, v) => (k, 0.0)} - - // normalize and cumulative sum - val sumWeightsByKey = weights.foldLeft(zeroBoundary) { case (acc, splitWeights) => - acc.map { case (k, v) => (k, v + splitWeights(k)) } - } - - val normedCumWeightsByKey = weights.scanLeft(zeroBoundary) { case (acc, splitWeights) => - splitWeights.map { case (key, fraction) => - val keySum = sumWeightsByKey(key) - val norm = if (keySum > 0.0) keySum else 1.0 - key -> (acc(key) + (fraction / norm)) - } + require(weights.nonEmpty, "randomSplitByKey(): Split weights cannot be empty.") + val sumWeights = weights.foldLeft(mutable.HashMap.empty[K, Double].withDefaultValue(0.0)) { + case (acc, fractions) => + fractions.foreach { case (k, v) => acc(k) += v } + acc } - - // compute exact thresholds for each stratum if required - val splitThresholds = if (exact) { - val left = normedCumWeightsByKey.head - val right = normedCumWeightsByKey.last - val innerSplits = normedCumWeightsByKey.drop(1).dropRight(1) - val finalResults = - StratifiedSamplingUtils.getAcceptanceResults(self, withReplacement = false, innerSplits, - None, seed) - val exactInnerSplits = finalResults.zip(innerSplits).map { case (finalResult, fractions) => - StratifiedSamplingUtils.computeThresholdByKey(finalResult, fractions) + val normedWeights = weights.map { case fractions => + fractions.map { case (k, v) => + val keySum = sumWeights(k) + k -> (if (keySum > 0.0) v / keySum else 0.0) } - left +: exactInnerSplits :+ right - } else { - normedCumWeightsByKey } + val samplingFuncs = + StratifiedSamplingUtils.getBernoulliCellSamplingFunctions(self, normedWeights, exact, seed) - splitThresholds.sliding(2).map { case Array(lb, ub) => - (randomSampleByKeyWithRange(lb, ub, seed), - randomSampleByKeyWithRange(lb, ub, seed, complement = true)) + samplingFuncs.map { case (func, complementFunc) => + (self.mapPartitionsWithIndex(func, preservesPartitioning = true), + self.mapPartitionsWithIndex(complementFunc, preservesPartitioning = true)) }.toArray } - /** - * Internal method exposed for Stratified Random Splits in DataFrames. Samples an RDD given - * probability bounds for each stratum. - * - * @param lb map of lower bound for each key to use for the Bernoulli cell sampler - * @param ub map of upper bound for each key to use for the Bernoulli cell sampler - * @param seed the seed for the Random number generator - * @param complement boolean specifying whether to return subsample or its complement - * @return A random, stratified sub-sample of the RDD without replacement. - */ - private[spark] def randomSampleByKeyWithRange( - lb: Map[K, Double], - ub: Map[K, Double], - seed: Long, - complement: Boolean = false): RDD[(K, V)] = { - val samplingFunc = StratifiedSamplingUtils.getBernoulliCellSamplingFunction[K, V]( - lb, ub, seed, complement) - self.mapPartitionsWithIndex(samplingFunc, preservesPartitioning = true) - } - /** * Merge the values for each key using an associative reduce function. This will also perform * the merging locally on each mapper before sending results to a reducer, similarly to a diff --git a/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala b/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala index 8cd8e1178281e..ff2a826471048 100644 --- a/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala @@ -50,7 +50,9 @@ import org.apache.spark.rdd.RDD * http://jmlr.org/proceedings/papers/v28/meng13a.html */ -object StratifiedSamplingUtils extends Logging { +private[spark] object StratifiedSamplingUtils extends Logging { + + type StratifiedSamplingFunc[K, V] = (Int, Iterator[(K, V)]) => Iterator[(K, V)] /** * Count the number of items instantly accepted and generate the waitlist for each stratum. @@ -59,9 +61,9 @@ object StratifiedSamplingUtils extends Logging { */ def getAcceptanceResults[K, V](rdd: RDD[(K, V)], withReplacement: Boolean, - fractions: Array[Map[K, Double]], + fractions: Seq[Map[K, Double]], counts: Option[Map[K, Long]], - seed: Long): Array[mutable.Map[K, AcceptanceResult]] = { + seed: Long): Seq[mutable.Map[K, AcceptanceResult]] = { val combOp = getCombOp[K] val mappedPartitionRDD = rdd.mapPartitionsWithIndex { case (partition, iter) => val zeroU: Array[mutable.Map[K, AcceptanceResult]] = Array.fill(fractions.length) { @@ -78,20 +80,23 @@ object StratifiedSamplingUtils extends Logging { mappedPartitionRDD.reduce(combOp) } + /** + * Convenience version of [[getAcceptanceResults()]] for a single sample. + */ def getAcceptanceResults[K, V]( rdd: RDD[(K, V)], withReplacement: Boolean, fractions: Map[K, Double], counts: Option[Map[K, Long]], seed: Long): mutable.Map[K, AcceptanceResult] = { - getAcceptanceResults(rdd, withReplacement, Array(fractions), counts, seed).head + getAcceptanceResults(rdd, withReplacement, Seq(fractions), counts, seed).head } /** * Returns the function used by aggregate to collect sampling statistics for each partition. */ def getSeqOp[K, V](withReplacement: Boolean, - fractions: Array[Map[K, Double]], + fractions: Seq[Map[K, Double]], rngs: Array[RandomDataGenerator], counts: Option[Map[K, Long]]): (Array[mutable.Map[K, AcceptanceResult]], (K, V)) => Array[mutable.Map[K, AcceptanceResult]] = { @@ -213,6 +218,18 @@ object StratifiedSamplingUtils extends Logging { thresholdByKey } + /** + * Convenience version of [[getBernoulliSamplingFunction()]] for a single split. + */ + def getBernoulliSamplingFunction[K: ClassTag, V: ClassTag]( + rdd: RDD[(K, V)], + fractions: Map[K, Double], + exact: Boolean, + seed: Long): (StratifiedSamplingFunc[K, V], StratifiedSamplingFunc[K, V]) = { + val complementFractions = fractions.map { case (k, v) => k -> (1.0 - v) } + getBernoulliCellSamplingFunctions(rdd, Seq(fractions, complementFractions), exact, seed).head + } + /** * Return the per partition sampling function used for sampling without replacement. * @@ -221,23 +238,43 @@ object StratifiedSamplingUtils extends Logging { * * The sampling function has a unique seed per partition. */ - def getBernoulliSamplingFunction[K, V](rdd: RDD[(K, V)], - fractions: Map[K, Double], + def getBernoulliCellSamplingFunctions[K: ClassTag, V: ClassTag]( + rdd: RDD[(K, V)], + fractions: Seq[Map[K, Double]], exact: Boolean, - seed: Long): (Int, Iterator[(K, V)]) => Iterator[(K, V)] = { - var samplingRateByKey = fractions - if (exact) { - // determine threshold for each stratum and resample - val finalResult = getAcceptanceResults(rdd, false, fractions, None, seed) - samplingRateByKey = computeThresholdByKey(finalResult, fractions) - } - (idx: Int, iter: Iterator[(K, V)]) => { - val rng = new RandomDataGenerator() - rng.reSeed(seed + idx) - // Must use the same invoke pattern on the rng as in getSeqOp for without replacement - // in order to generate the same sequence of random numbers when creating the sample - iter.filter(t => rng.nextUniform() < samplingRateByKey(t._1)) + seed: Long): Seq[(StratifiedSamplingFunc[K, V], StratifiedSamplingFunc[K, V])] = { + val thresholds = splitFractionsToSplitPoints(fractions) + val innerThresholds = if (exact) { + val finalResults = + getAcceptanceResults(rdd, withReplacement = false, thresholds, None, seed) + finalResults.zip(thresholds).map { case (finalResult, thresh) => + computeThresholdByKey(finalResult, thresh) + } + } else { + thresholds } + val leftBound = fractions.head.map { case (k, v) => (k, 0.0)} + val rightBound = fractions.head.map { case (k, v) => (k, 1.0)} + val outerThresholds = leftBound +: innerThresholds :+ rightBound + outerThresholds.sliding(2).map { case Seq(lb, ub) => + (getBernoulliCellSamplingFunction[K, V](lb, ub, seed), + getBernoulliCellSamplingFunction[K, V](lb, ub, seed, complement = true)) + }.toSeq + } + + /** + * Helper function to cumulative sum a sequence of Maps. + */ + private def splitFractionsToSplitPoints[K]( + fractions: Seq[Map[K, Double]]): Seq[Map[K, Double]] = { + val acc = new mutable.HashMap[K, Double]() + fractions.map { case splitWeights => + splitWeights.map { case (k, v) => + val thisKeySum = acc.getOrElseUpdate(k, 0.0) + acc(k) += v + k -> (v + thisKeySum) + } + }.dropRight(1) } /** @@ -250,18 +287,18 @@ object StratifiedSamplingUtils extends Logging { lb: Map[K, Double], ub: Map[K, Double], seed: Long, - complement: Boolean = false): (Int, Iterator[(K, V)]) => Iterator[(K, V)] = { + complement: Boolean = false): StratifiedSamplingFunc[K, V] = { (idx: Int, iter: Iterator[(K, V)]) => { val rng = new RandomDataGenerator() rng.reSeed(seed + idx) if (complement) { - iter.filter { case(k, _) => + iter.filter { case (k, _) => val x = rng.nextUniform() (x < lb(k)) || (x >= ub(k)) } } else { - iter.filter { case(k, _) => + iter.filter { case (k, _) => val x = rng.nextUniform() (x >= lb(k)) && (x < ub(k)) } @@ -282,7 +319,7 @@ object StratifiedSamplingUtils extends Logging { def getPoissonSamplingFunction[K: ClassTag, V: ClassTag](rdd: RDD[(K, V)], fractions: Map[K, Double], exact: Boolean, - seed: Long): (Int, Iterator[(K, V)]) => Iterator[(K, V)] = { + seed: Long): StratifiedSamplingFunc[K, V] = { // TODO implement the streaming version of sampling w/ replacement that doesn't require counts if (exact) { val counts = Some(rdd.countByKey()) @@ -358,7 +395,7 @@ object StratifiedSamplingUtils extends Logging { * * `[random]` here is necessary since it's in the return type signature of seqOp defined above */ -class AcceptanceResult(var numItems: Long = 0L, var numAccepted: Long = 0L) +private[random] class AcceptanceResult(var numItems: Long = 0L, var numAccepted: Long = 0L) extends Serializable { val waitList = new ArrayBuffer[Double] diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala index 18f817b402ffb..5851b1b8024e9 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala @@ -176,7 +176,7 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { val data = sc.parallelize(1 to n, 2) val fractionPositive = 0.3 val stratifiedData = data.keyBy(StratifiedAuxiliary.stratifier(fractionPositive)) - val keys = stratifiedData.keys.distinct().collect() + val keys = Array("0", "1") val splitWeights = Array(0.3, 0.2, 0.5) val weights: Array[scala.collection.Map[String, Double]] = splitWeights.map(w => keys.map(k => (k, w)).toMap) @@ -188,7 +188,7 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { val n = 100 val data = sc.parallelize(1 to n, 2) val stratifiedData = data.keyBy(StratifiedAuxiliary.stratifier(fractionPositive)) - val keys = stratifiedData.keys.distinct().collect() + val keys = Array("0", "1") val splitWeights = Array(0.3, 0.2, 0.5) val weights: Array[scala.collection.Map[String, Double]] = splitWeights.map(w => keys.map(k => (k, w)).toMap) @@ -200,7 +200,7 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { val fractionPositive = 0.3 val data = sc.parallelize(1 to n, 2) val stratifiedData = data.keyBy(StratifiedAuxiliary.stratifier(fractionPositive)) - val keys = stratifiedData.keys.distinct().collect() + val keys = Array("0", "1") // use different weights for each key in the split val unevenWeights: Array[scala.collection.Map[String, Double]] = @@ -217,12 +217,15 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { // vary the number of splits for (numSplits <- 1 to 3) { - val splitWeights = (1 to numSplits).map(n => 1.toDouble).toArray // check normalization too + val splitWeights = Array.fill(numSplits)(1.0) // check normalization too val weights: Array[scala.collection.Map[String, Double]] = splitWeights.map(w => keys.map(k => (k, w)).toMap) -// println(weights.mkString("***")) StratifiedAuxiliary.testSplits(stratifiedData, weights, defaultSeed, n, true) } + val thrown = intercept[IllegalArgumentException] { + stratifiedData.randomSplitByKey(Array.empty[scala.collection.Map[String, Double]], true, 42L) + } + assert(thrown.getMessage.contains("weights cannot be empty")) } test("randomSplitByKey") { @@ -233,7 +236,7 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { val data = sc.parallelize(1 to n, 2) val fractionPositive = 0.3 val stratifiedData = data.keyBy(StratifiedAuxiliary.stratifier(fractionPositive)) - val keys = stratifiedData.keys.distinct().collect() + val keys = Array("0", "1") val splitWeights = Array(0.3, 0.2, 0.5) val weights: Array[scala.collection.Map[String, Double]] = splitWeights.map(w => keys.map(k => (k, w)).toMap) @@ -245,7 +248,7 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { val n = 500 val data = sc.parallelize(1 to n, 2) val stratifiedData = data.keyBy(StratifiedAuxiliary.stratifier(fractionPositive)) - val keys = stratifiedData.keys.distinct().collect() + val keys = Array("0", "1") val splitWeights = Array(0.3, 0.2, 0.5) val weights: Array[scala.collection.Map[String, Double]] = splitWeights.map(w => keys.map(k => (k, w)).toMap) @@ -257,7 +260,7 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { val fractionPositive = 0.3 val data = sc.parallelize(1 to n, 2) val stratifiedData = data.keyBy(StratifiedAuxiliary.stratifier(fractionPositive)) - val keys = stratifiedData.keys.distinct().collect() + val keys = Array("0", "1") // use different weights for each key in the split val unevenWeights: Array[scala.collection.Map[String, Double]] = @@ -279,6 +282,10 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { splitWeights.map(w => keys.map(k => (k, w)).toMap) StratifiedAuxiliary.testSplits(stratifiedData, weights, defaultSeed, n, false) } + val thrown = intercept[IllegalArgumentException] { + stratifiedData.randomSplitByKey(Array.empty[scala.collection.Map[String, Double]], false, 42L) + } + assert(thrown.getMessage.contains("weights cannot be empty")) } test("reduceByKey") { @@ -326,8 +333,7 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { def error(est: Long, size: Long): Double = math.abs(est - size) / size.toDouble /* Since HyperLogLog unique counting is approximate, and the relative standard deviation is - * only a statistical bound, the tests can fail for large values of relativeSD. We will be - using + * only a statistical bound, the tests can fail for large values of relativeSD. We will be using * relatively tight error bounds to check correctness of functionality rather than checking * whether the approximation conforms with the requested bound. */ @@ -650,8 +656,7 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { assert(FakeOutputCommitter.ran, "OutputCommitter was never called") } - test("failure callbacks should be called before calling writer.close() in saveNewAPIHadoopFile") - { + test("failure callbacks should be called before calling writer.close() in saveNewAPIHadoopFile") { val pairs = sc.parallelize(Array((new Integer(1), new Integer(2))), 1) FakeWriterWithCallback.calledBy = "" @@ -768,10 +773,11 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { if (exact) { // all splits will not be exact, but must be within 1 of expected size assert(math.abs(expected - actual) <= 1) + } else { + val stdev = math.sqrt(expected * p * (1 - p)) + // Very forgiving margin since we're dealing with very small sample sizes most of the time + assert(math.abs(actual - expected) <= 6 * stdev) } - val stdev = math.sqrt(expected * p * (1 - p)) - // Very forgiving margin since we're dealing with very small sample sizes most of the time - assert(math.abs(actual - expected) <= 6 * stdev) } def testSampleExact(stratifiedData: RDD[(String, Int)], @@ -823,16 +829,15 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { val takeComplement = complement.collect() // no duplicates in samples - assert(takeSample.length == takeSample.toSet.size) - assert(takeComplement.length == takeComplement.toSet.size) + assert(takeSample.length === takeSample.toSet.size) + assert(takeComplement.length === takeComplement.toSet.size) val sampleCounts = countByKey(takeSample) val complementCounts = countByKey(takeComplement) -// println(sampleCounts, complementCounts, fractions, totalCounts) val observedTotals = totalCounts.map { case (k, v) => k -> (sampleCounts.getOrElse(k, 0) + complementCounts.getOrElse(k, 0)) } - assert(observedTotals == totalCounts) + assert(observedTotals === totalCounts) sampleCounts.foreach { case (k, count) => val expectedCount = math.ceil(totalCounts(k) * fractions(k)).toInt @@ -845,11 +850,11 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { sampleSet ++= takeSample val samplesPlusComplements = (takeSample ++ takeComplement).toSet - assert(samplesPlusComplements == dataSet) + assert(samplesPlusComplements === dataSet) } // union of all samples equals original data - assert(sampleSet == dataSet) + assert(sampleSet === dataSet) } // Without replacement validation diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 8d0c7071f3378..6d481d904e85b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -52,7 +52,6 @@ private[ml] trait CrossValidatorParams extends ValidatorParams { def getNumFolds: Int = $(numFolds) setDefault(numFolds -> 3) - } /** @@ -89,7 +88,7 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) def setSeed(value: Long): this.type = set(seed, value) /** @group setParam */ - @Since("2.0.0") + @Since("2.1.0") def setStratifiedCol(value: String): this.type = set(stratifiedCol, value) setDefault(stratifiedCol -> "") @@ -107,7 +106,7 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) val splits = if ($(stratifiedCol).nonEmpty) { val stratifiedColIndex = schema.fieldNames.indexOf($(stratifiedCol)) val pairData = dataset.toDF.rdd.map(row => (row(stratifiedColIndex), row)) - val splitsWithKeys = MLUtils.kFoldStratified(pairData, $(numFolds), 0) + val splitsWithKeys = MLUtils.kFoldStratified(pairData, $(numFolds), $(seed)) splitsWithKeys.map { case (training, validation) => (training.values, validation.values)} } else { MLUtils.kFold(dataset.toDF.rdd, $(numFolds), $(seed)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala index 986c9ce1781a6..2340f63f67067 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala @@ -88,6 +88,8 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St @Since("2.0.0") def setSeed(value: Long): this.type = set(seed, value) + /** @group setParam */ + @Since("2.1.0") def setStratifiedCol(value: String): this.type = set(stratifiedCol, value) setDefault(stratifiedCol -> "") @@ -109,13 +111,13 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St val keys = pairData.keys.distinct.collect() val weights: Array[scala.collection.Map[Any, Double]] = Array(keys.map((_, $(trainRatio))).toMap, keys.map((_, 1 - $(trainRatio))).toMap) - val splitsWithKeys = pairData.randomSplitByKey(weights, exact = true, 0) + val splitsWithKeys = pairData.randomSplitByKey(weights, exact = true, $(seed)) val Array(training, validation) = splitsWithKeys.map { case (subsample, complement) => subsample.values } Array(sparkSession.createDataFrame(training, schema), sparkSession.createDataFrame(validation, schema)) } else { - dataset.randomSplit(Array($(trainRatio), 1 - $(trainRatio))) + dataset.randomSplit(Array($(trainRatio), 1 - $(trainRatio)), $(seed)) } trainingDataset.cache() diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala index b020db2f0eb97..22fc19ff532bd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -27,9 +27,9 @@ import org.apache.spark.ml.linalg.{MatrixUDT => MLMatrixUDT, VectorUDT => MLVect import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.linalg.BLAS.dot import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD} import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions.{col, udf} -import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD, PairRDDFunctions} import org.apache.spark.storage.StorageLevel import org.apache.spark.util.random.BernoulliCellSampler @@ -239,6 +239,16 @@ object MLUtils extends Logging { rdd: RDD[(K, V)], numFolds: Int, seed: Int): Array[(RDD[(K, V)], RDD[(K, V)])] = { + kFoldStratified(rdd, numFolds, seed.toLong) + } + + /** + * Version of [[kFoldStratified()]] taking a Long seed. + */ + def kFoldStratified[K: ClassTag, V: ClassTag]( + rdd: RDD[(K, V)], + numFolds: Int, + seed: Long): Array[(RDD[(K, V)], RDD[(K, V)])] = { val keys = rdd.keys.distinct().collect() val weights: Array[scala.collection.Map[K, Double]] = (1 to numFolds).map { n => keys.map(k => (k, 1 / numFolds.toDouble)).toMap diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index 5483add02e1de..4765c4233e3a1 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -22,15 +22,12 @@ import org.apache.spark.ml.{Estimator, Model, Pipeline} import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel} import org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInput import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator} -import org.apache.spark.ml.feature.HashingTF +import org.apache.spark.ml.feature.{HashingTF, LabeledPoint} import org.apache.spark.ml.linalg.Vectors import org.apache.spark.ml.param.{ParamMap, ParamPair} import org.apache.spark.ml.param.shared.HasInputCol import org.apache.spark.ml.regression.LinearRegression import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} -import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput -import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} import org.apache.spark.sql.Dataset import org.apache.spark.sql.types.StructType @@ -70,34 +67,6 @@ class CrossValidatorSuite assert(cvModel.avgMetrics.length === lrParamMaps.length) } - test("strat") { - val numFolds = 10 - // generate imbalanced data - val data = Seq.tabulate(100) { i => - if (i >= numFolds) { - LabeledPoint(0.0, Vectors.dense(1.0)) - } else { - LabeledPoint(1.0, Vectors.dense(1.0)) - } - } - val df = sqlContext.createDataFrame(data) - val lr = new LogisticRegression - val lrParamMaps = new ParamGridBuilder() - .addGrid(lr.maxIter, Array(0, 10)) - .build() - val eval = new BinaryClassificationEvaluator - val cv = new CrossValidator() - .setEstimator(lr) - .setEstimatorParamMaps(lrParamMaps) - .setEvaluator(eval) - .setNumFolds(numFolds) - .setStratifiedCol("label") - val cvModel = cv.fit(df) - // without stratified sampling, there is a 99.964% that one of the splits has - // no negative examples, so some of the metrics will be < 0.5, bringing down the avg metrics. - assert(cvModel.avgMetrics.forall(_ === 0.5)) - } - test("cross validation with linear regression") { val dataset = spark.createDataFrame( sc.parallelize(LinearDataGenerator.generateLinearInput( @@ -153,6 +122,36 @@ class CrossValidatorSuite } } + test("stratified vs. not stratified cross validation") { + val numFolds = 10 + val data = Seq.tabulate(100) { i => + if (i >= numFolds) { + LabeledPoint(0.0, Vectors.dense(1.0)) // 1 per split + } else { + LabeledPoint(1.0, Vectors.dense(1.0)) + } + } + val df = spark.createDataFrame(data) + val lr = new LogisticRegression + val lrParamMaps = new ParamGridBuilder() + .addGrid(lr.maxIter, Array(2)) + .build() + val eval = new BinaryClassificationEvaluator + val cv = new CrossValidator() + .setEstimator(lr) + .setEstimatorParamMaps(lrParamMaps) + .setEvaluator(eval) + .setNumFolds(numFolds) + .setSeed(42L) + val notStratifiedModel = cv.fit(df) + cv.setStratifiedCol("label") + val stratifiedModel = cv.fit(df) + // without stratified sampling some of the splits will not contain both examples + // so some of the metrics will be < 0.5, bringing down the avg metrics. + assert(stratifiedModel.avgMetrics.forall(_ === 0.5)) + assert(notStratifiedModel.avgMetrics.exists(_ != 0.5)) + } + test("read/write: CrossValidator with simple estimator") { val lr = new LogisticRegression().setMaxIter(3) val evaluator = new BinaryClassificationEvaluator() diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala index 31b3ceaf04918..bb1c4d7e25e27 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala @@ -19,9 +19,10 @@ package org.apache.spark.ml.tuning import org.apache.spark.SparkFunSuite import org.apache.spark.ml.{Estimator, Model} -import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel} +import org.apache.spark.ml.classification.{DecisionTreeClassifier, LogisticRegression, LogisticRegressionModel} import org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInput -import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator} +import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, MulticlassClassificationEvaluator, RegressionEvaluator} +import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.Vectors import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared.HasInputCol @@ -59,6 +60,39 @@ class TrainValidationSplitSuite assert(cvModel.validationMetrics.length === lrParamMaps.length) } + test("stratified") { + val data = Seq( + List.fill(20)(LabeledPoint(0.0, Vectors.dense(0.0))), + List.fill(20)(LabeledPoint(1.0, Vectors.dense(1.0))), + List.fill(2)(LabeledPoint(2.0, Vectors.dense(2.0))) + ).flatten + val df = spark.createDataFrame(data) + val trainer = new DecisionTreeClassifier() + val dtParamMaps = new ParamGridBuilder() + .addGrid(trainer.maxDepth, Array(2)) + .build() + val eval = new MulticlassClassificationEvaluator() + val cv = new TrainValidationSplit() + .setEstimator(trainer) + .setEstimatorParamMaps(dtParamMaps) + .setEvaluator(eval) + .setTrainRatio(0.5) + val nTrials = 5 + val notStratifiedTrials = (0 until nTrials).map { i => + cv.setSeed(42L + i) + val cvModel = cv.fit(df) + cvModel.validationMetrics.head + } + val stratifiedTrials = (0 until nTrials).map { i => + cv.setSeed(42L + i).setStratifiedCol("label") + val cvModel = cv.fit(df) + cvModel.validationMetrics.head + } + + assert(!stratifiedTrials.exists(metric => math.abs(metric - 1.0) > 1e-6)) + assert(notStratifiedTrials.exists(metric => math.abs(metric - 1.0) > 1e-6)) + } + test("train validation with linear regression") { val dataset = spark.createDataFrame( sc.parallelize(LinearDataGenerator.generateLinearInput( diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala index c8993b7aa1193..d5dafd88e5379 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala @@ -211,7 +211,7 @@ class MLUtilsSuite extends SparkFunSuite with MLlibTestSparkContext { } test("kFoldStratified") { - /** + /* * Most of the functionality of [[kFoldStratified]] is tested in the PairRDD function * `randomSplitByKey`. All that needs to be checked here is that the folds are even * splits for each key. @@ -222,7 +222,7 @@ class MLUtilsSuite extends SparkFunSuite with MLlibTestSparkContext { val fractionPositive = 0.3 val keys = Array("0", "1") val stratifiedData = data.map { x => - if (x > n*fractionPositive) ("0", x) else ("1", x) + if (x > n * fractionPositive) ("0", x) else ("1", x) } val counts = stratifiedData.countByKey() for (numFolds <- 1 to 3) {