diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 3d5fd1794de2..27cda5e95f7a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -234,11 +234,24 @@ private[recommendation] trait ALSParams extends ALSModelParams with HasMaxIter w /** @group expertGetParam */ def getFinalStorageLevel: String = $(finalStorageLevel) + /** + * Param for threshold in computation of dst factors to decide + * if stacking factors to speed up the computation.(>= 1). + * Default: 1024 + * @group expertParam + */ + val threshold = new IntParam(this, "threshold", "threshold in computation of dst factors " + + "to decide if stacking factors to speed up the computation.", + ParamValidators.gtEq(1)) + + /** @group expertGetParam */ + def getThreshold: Int = $(threshold) + setDefault(rank -> 10, maxIter -> 10, regParam -> 0.1, numUserBlocks -> 10, numItemBlocks -> 10, implicitPrefs -> false, alpha -> 1.0, userCol -> "user", itemCol -> "item", ratingCol -> "rating", nonnegative -> false, checkpointInterval -> 10, intermediateStorageLevel -> "MEMORY_AND_DISK", finalStorageLevel -> "MEMORY_AND_DISK", - coldStartStrategy -> "nan") + coldStartStrategy -> "nan", threshold -> 1024) /** * Validates and transforms the input schema. @@ -589,6 +602,9 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel] @Since("2.2.0") def setColdStartStrategy(value: String): this.type = set(coldStartStrategy, value) + @Since("2.3.0") + def setThreshold(value: Int): this.type = set(threshold, value) + /** * Sets both numUserBlocks and numItemBlocks to the specific value. * @@ -617,7 +633,7 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel] val instr = Instrumentation.create(this, ratings) instr.logParams(rank, numUserBlocks, numItemBlocks, implicitPrefs, alpha, userCol, itemCol, ratingCol, predictionCol, maxIter, regParam, nonnegative, checkpointInterval, - seed, intermediateStorageLevel, finalStorageLevel) + threshold, seed, intermediateStorageLevel, finalStorageLevel) val (userFactors, itemFactors) = ALS.train(ratings, rank = $(rank), numUserBlocks = $(numUserBlocks), numItemBlocks = $(numItemBlocks), @@ -625,7 +641,8 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel] alpha = $(alpha), nonnegative = $(nonnegative), intermediateRDDStorageLevel = StorageLevel.fromString($(intermediateStorageLevel)), finalRDDStorageLevel = StorageLevel.fromString($(finalStorageLevel)), - checkpointInterval = $(checkpointInterval), seed = $(seed)) + checkpointInterval = $(checkpointInterval), + seed = $(seed), threshold = $(threshold)) val userDF = userFactors.toDF("id", "features") val itemDF = itemFactors.toDF("id", "features") val model = new ALSModel(uid, $(rank), userDF, itemDF).setParent(this) @@ -783,6 +800,7 @@ object ALS extends DefaultParamsReadable[ALS] with Logging { val atb = new Array[Double](k) private val da = new Array[Double](k) + private val ata2 = new Array[Double](k * k) private val upper = "U" private def copyToDouble(a: Array[Float]): Unit = { @@ -793,6 +811,22 @@ object ALS extends DefaultParamsReadable[ALS] with Logging { } } + private def copyToTri(): Unit = { + var i = 0 + var j = 0 + var ii = 0 + while (i < k) { + val temp = i * k + j = 0 + while (j <= i) { + ata(ii) += ata2(temp + j) + j += 1 + ii += 1 + } + i += 1 + } + } + /** Adds an observation. */ def add(a: Array[Float], b: Double, c: Double = 1.0): this.type = { require(c >= 0.0) @@ -805,6 +839,15 @@ object ALS extends DefaultParamsReadable[ALS] with Logging { this } + /** Adds a stack of observations. */ + def addStack(a: Array[Double], b: Array[Double], n: Int): this.type = { + require(a.length == n * k) + blas.dsyrk(upper, "N", k, n, 1.0, a, k, 1.0, ata2, k) + copyToTri() + blas.dgemv("N", k, n, 1.0, a, k, b, 1, 1.0, atb, 1) + this + } + /** Merges another normal equation object. */ def merge(other: NormalEquation): this.type = { require(other.k == k) @@ -816,6 +859,7 @@ object ALS extends DefaultParamsReadable[ALS] with Logging { /** Resets everything to zero, which should be called after each solve. */ def reset(): Unit = { ju.Arrays.fill(ata, 0.0) + ju.Arrays.fill(ata2, 0.0) ju.Arrays.fill(atb, 0.0) } } @@ -860,7 +904,8 @@ object ALS extends DefaultParamsReadable[ALS] with Logging { intermediateRDDStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK, finalRDDStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK, checkpointInterval: Int = 10, - seed: Long = 0L)( + seed: Long = 0L, + threshold: Int = 1024)( implicit ord: Ordering[ID]): (RDD[(ID, Array[Float])], RDD[(ID, Array[Float])]) = { require(!ratings.isEmpty(), s"No ratings available from $ratings") @@ -918,7 +963,7 @@ object ALS extends DefaultParamsReadable[ALS] with Logging { userFactors.setName(s"userFactors-$iter").persist(intermediateRDDStorageLevel) val previousItemFactors = itemFactors itemFactors = computeFactors(userFactors, userOutBlocks, itemInBlocks, rank, regParam, - userLocalIndexEncoder, implicitPrefs, alpha, solver) + userLocalIndexEncoder, implicitPrefs, alpha, solver, threshold) previousItemFactors.unpersist() itemFactors.setName(s"itemFactors-$iter").persist(intermediateRDDStorageLevel) // TODO: Generalize PeriodicGraphCheckpointer and use it here. @@ -928,7 +973,7 @@ object ALS extends DefaultParamsReadable[ALS] with Logging { } val previousUserFactors = userFactors userFactors = computeFactors(itemFactors, itemOutBlocks, userInBlocks, rank, regParam, - itemLocalIndexEncoder, implicitPrefs, alpha, solver) + itemLocalIndexEncoder, implicitPrefs, alpha, solver, threshold) if (shouldCheckpoint(iter)) { ALS.cleanShuffleDependencies(sc, deps) deletePreviousCheckpointFile() @@ -939,7 +984,7 @@ object ALS extends DefaultParamsReadable[ALS] with Logging { } else { for (iter <- 0 until maxIter) { itemFactors = computeFactors(userFactors, userOutBlocks, itemInBlocks, rank, regParam, - userLocalIndexEncoder, solver = solver) + userLocalIndexEncoder, solver = solver, threshold = threshold) if (shouldCheckpoint(iter)) { val deps = itemFactors.dependencies itemFactors.checkpoint() @@ -949,7 +994,7 @@ object ALS extends DefaultParamsReadable[ALS] with Logging { previousCheckpointFile = itemFactors.getCheckpointFile } userFactors = computeFactors(itemFactors, itemOutBlocks, userInBlocks, rank, regParam, - itemLocalIndexEncoder, solver = solver) + itemLocalIndexEncoder, solver = solver, threshold = threshold) } } val userIdAndFactors = userInBlocks @@ -1595,7 +1640,8 @@ object ALS extends DefaultParamsReadable[ALS] with Logging { srcEncoder: LocalIndexEncoder, implicitPrefs: Boolean = false, alpha: Double = 1.0, - solver: LeastSquaresNESolver): RDD[(Int, FactorBlock)] = { + solver: LeastSquaresNESolver, + threshold: Int): RDD[(Int, FactorBlock)] = { val numSrcBlocks = srcFactorBlocks.partitions.length val YtY = if (implicitPrefs) Some(computeYtY(srcFactorBlocks, rank)) else None val srcOut = srcOutBlocks.join(srcFactorBlocks).flatMap { @@ -1621,6 +1667,13 @@ object ALS extends DefaultParamsReadable[ALS] with Logging { } var i = srcPtrs(j) var numExplicits = 0 + // Stacking factors(vectors) in matrices to speed up the computation, + // when the number of factors and the rank is large enough. + val doStack = srcPtrs(j + 1) - srcPtrs(j) > threshold && rank > threshold + val srcFactorBuffer = new Array[Double]((srcPtrs(j + 1) - srcPtrs(j)) * rank) + val bBuffer = new Array[Double](srcPtrs(j + 1) - srcPtrs(j)) + var srcIndex = 0 + var bIndex = 0 while (i < srcPtrs(j + 1)) { val encoded = srcEncodedIndices(i) val blockId = srcEncoder.blockId(encoded) @@ -1638,11 +1691,25 @@ object ALS extends DefaultParamsReadable[ALS] with Logging { } ls.add(srcFactor, if (rating > 0.0) 1.0 + c1 else 0.0, c1) } else { - ls.add(srcFactor, rating) numExplicits += 1 + if (doStack) { + bBuffer(bIndex) = rating + bIndex += 1 + var ii = 0 + while(ii < rank) { + srcFactorBuffer(srcIndex) = srcFactor(ii) + srcIndex += 1 + ii += 1 + } + } else { + ls.add(srcFactor, rating) + } } i += 1 } + if (!implicitPrefs && doStack) { + ls.addStack(srcFactorBuffer, bBuffer, numExplicits) + } // Weight lambda by the number of explicit ratings based on the ALS-WR paper. dstFactors(j) = solver.solve(ls, numExplicits * regParam) j += 1 diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index ac7319110159..f225f385095a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -401,7 +401,8 @@ class ALSSuite implicitPrefs: Boolean = false, numUserBlocks: Int = 2, numItemBlocks: Int = 3, - targetRMSE: Double = 0.05): Unit = { + targetRMSE: Double = 0.05, + threshold: Int = 1024): Unit = { val spark = this.spark import spark.implicits._ val als = new ALS() @@ -411,6 +412,7 @@ class ALSSuite .setNumUserBlocks(numUserBlocks) .setNumItemBlocks(numItemBlocks) .setSeed(0) + .setThreshold(threshold) val alpha = als.getAlpha val model = als.fit(training.toDF()) val predictions = model.transform(test.toDF()).select("rating", "prediction").rdd.map { @@ -481,6 +483,12 @@ class ALSSuite numItemBlocks = 5, numUserBlocks = 5) } + test("do stacking factors in matrices") { + val (training, test) = genExplicitTestData(numUsers = 200, numItems = 20, rank = 1) + testALS(training, test, maxIter = 1, rank = 129, regParam = 0.01, targetRMSE = 0.02, + threshold = 128) + } + test("implicit feedback") { val (training, test) = genImplicitTestData(numUsers = 20, numItems = 40, rank = 2, noiseStd = 0.01) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index dd299e074535..ebab952f2af4 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -1043,6 +1043,9 @@ object MimaExcludes { // [SPARK-21680][ML][MLLIB]optimzie Vector coompress ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.mllib.linalg.Vector.toSparseWithSize"), ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Vector.toSparseWithSize") + ) ++ Seq( + // [SPARK-6685][ML]Use DSYRK to compute AtA in ALS + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.recommendation.ALS.train") ) }