Skip to content
87 changes: 77 additions & 10 deletions mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
Original file line number Diff line number Diff line change
Expand Up @@ -234,11 +234,24 @@ private[recommendation] trait ALSParams extends ALSModelParams with HasMaxIter w
/** @group expertGetParam */
def getFinalStorageLevel: String = $(finalStorageLevel)

/**
* Param for threshold in computation of dst factors to decide
* if stacking factors to speed up the computation.(>= 1).
* Default: 1024
* @group expertParam
*/
val threshold = new IntParam(this, "threshold", "threshold in computation of dst factors " +
"to decide if stacking factors to speed up the computation.",
ParamValidators.gtEq(1))

/** @group expertGetParam */
def getThreshold: Int = $(threshold)

setDefault(rank -> 10, maxIter -> 10, regParam -> 0.1, numUserBlocks -> 10, numItemBlocks -> 10,
implicitPrefs -> false, alpha -> 1.0, userCol -> "user", itemCol -> "item",
ratingCol -> "rating", nonnegative -> false, checkpointInterval -> 10,
intermediateStorageLevel -> "MEMORY_AND_DISK", finalStorageLevel -> "MEMORY_AND_DISK",
coldStartStrategy -> "nan")
coldStartStrategy -> "nan", threshold -> 1024)

/**
* Validates and transforms the input schema.
Expand Down Expand Up @@ -589,6 +602,9 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel]
@Since("2.2.0")
def setColdStartStrategy(value: String): this.type = set(coldStartStrategy, value)

@Since("2.3.0")
def setThreshold(value: Int): this.type = set(threshold, value)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure callers can meaningfully understand and set this. Can't we pick a threshold programmatically?

@mpjlu mpjlu Oct 19, 2017

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, my test result is better than the previous result, especially for native BLAS. I will update my test results here soon, and I will change this set.


/**
* Sets both numUserBlocks and numItemBlocks to the specific value.
*
Expand Down Expand Up @@ -617,15 +633,16 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel]
val instr = Instrumentation.create(this, ratings)
instr.logParams(rank, numUserBlocks, numItemBlocks, implicitPrefs, alpha, userCol,
itemCol, ratingCol, predictionCol, maxIter, regParam, nonnegative, checkpointInterval,
seed, intermediateStorageLevel, finalStorageLevel)
threshold, seed, intermediateStorageLevel, finalStorageLevel)

val (userFactors, itemFactors) = ALS.train(ratings, rank = $(rank),
numUserBlocks = $(numUserBlocks), numItemBlocks = $(numItemBlocks),
maxIter = $(maxIter), regParam = $(regParam), implicitPrefs = $(implicitPrefs),
alpha = $(alpha), nonnegative = $(nonnegative),
intermediateRDDStorageLevel = StorageLevel.fromString($(intermediateStorageLevel)),
finalRDDStorageLevel = StorageLevel.fromString($(finalStorageLevel)),
checkpointInterval = $(checkpointInterval), seed = $(seed))
checkpointInterval = $(checkpointInterval),
seed = $(seed), threshold = $(threshold))
val userDF = userFactors.toDF("id", "features")
val itemDF = itemFactors.toDF("id", "features")
val model = new ALSModel(uid, $(rank), userDF, itemDF).setParent(this)
Expand Down Expand Up @@ -783,6 +800,7 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
val atb = new Array[Double](k)

private val da = new Array[Double](k)
private val ata2 = new Array[Double](k * k)
private val upper = "U"

private def copyToDouble(a: Array[Float]): Unit = {
Expand All @@ -793,6 +811,22 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
}
}

private def copyToTri(): Unit = {
var i = 0
var j = 0
var ii = 0
while (i < k) {
val temp = i * k
j = 0
while (j <= i) {
ata(ii) += ata2(temp + j)
j += 1
ii += 1
}
i += 1
}
}

/** Adds an observation. */
def add(a: Array[Float], b: Double, c: Double = 1.0): this.type = {
require(c >= 0.0)
Expand All @@ -805,6 +839,15 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
this
}

/** Adds a stack of observations. */
def addStack(a: Array[Double], b: Array[Double], n: Int): this.type = {
require(a.length == n * k)
blas.dsyrk(upper, "N", k, n, 1.0, a, k, 1.0, ata2, k)
copyToTri()
blas.dgemv("N", k, n, 1.0, a, k, b, 1, 1.0, atb, 1)
this
}

/** Merges another normal equation object. */
def merge(other: NormalEquation): this.type = {
require(other.k == k)
Expand All @@ -816,6 +859,7 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
/** Resets everything to zero, which should be called after each solve. */
def reset(): Unit = {
ju.Arrays.fill(ata, 0.0)
ju.Arrays.fill(ata2, 0.0)
ju.Arrays.fill(atb, 0.0)
}
}
Expand Down Expand Up @@ -860,7 +904,8 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
intermediateRDDStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK,
finalRDDStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK,
checkpointInterval: Int = 10,
seed: Long = 0L)(
seed: Long = 0L,
threshold: Int = 1024)(
implicit ord: Ordering[ID]): (RDD[(ID, Array[Float])], RDD[(ID, Array[Float])]) = {

require(!ratings.isEmpty(), s"No ratings available from $ratings")
Expand Down Expand Up @@ -918,7 +963,7 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
userFactors.setName(s"userFactors-$iter").persist(intermediateRDDStorageLevel)
val previousItemFactors = itemFactors
itemFactors = computeFactors(userFactors, userOutBlocks, itemInBlocks, rank, regParam,
userLocalIndexEncoder, implicitPrefs, alpha, solver)
userLocalIndexEncoder, implicitPrefs, alpha, solver, threshold)
previousItemFactors.unpersist()
itemFactors.setName(s"itemFactors-$iter").persist(intermediateRDDStorageLevel)
// TODO: Generalize PeriodicGraphCheckpointer and use it here.
Expand All @@ -928,7 +973,7 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
}
val previousUserFactors = userFactors
userFactors = computeFactors(itemFactors, itemOutBlocks, userInBlocks, rank, regParam,
itemLocalIndexEncoder, implicitPrefs, alpha, solver)
itemLocalIndexEncoder, implicitPrefs, alpha, solver, threshold)
if (shouldCheckpoint(iter)) {
ALS.cleanShuffleDependencies(sc, deps)
deletePreviousCheckpointFile()
Expand All @@ -939,7 +984,7 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
} else {
for (iter <- 0 until maxIter) {
itemFactors = computeFactors(userFactors, userOutBlocks, itemInBlocks, rank, regParam,
userLocalIndexEncoder, solver = solver)
userLocalIndexEncoder, solver = solver, threshold = threshold)
if (shouldCheckpoint(iter)) {
val deps = itemFactors.dependencies
itemFactors.checkpoint()
Expand All @@ -949,7 +994,7 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
previousCheckpointFile = itemFactors.getCheckpointFile
}
userFactors = computeFactors(itemFactors, itemOutBlocks, userInBlocks, rank, regParam,
itemLocalIndexEncoder, solver = solver)
itemLocalIndexEncoder, solver = solver, threshold = threshold)
}
}
val userIdAndFactors = userInBlocks
Expand Down Expand Up @@ -1595,7 +1640,8 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
srcEncoder: LocalIndexEncoder,
implicitPrefs: Boolean = false,
alpha: Double = 1.0,
solver: LeastSquaresNESolver): RDD[(Int, FactorBlock)] = {
solver: LeastSquaresNESolver,
threshold: Int): RDD[(Int, FactorBlock)] = {
val numSrcBlocks = srcFactorBlocks.partitions.length
val YtY = if (implicitPrefs) Some(computeYtY(srcFactorBlocks, rank)) else None
val srcOut = srcOutBlocks.join(srcFactorBlocks).flatMap {
Expand All @@ -1621,6 +1667,13 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
}
var i = srcPtrs(j)
var numExplicits = 0
// Stacking factors(vectors) in matrices to speed up the computation,
// when the number of factors and the rank is large enough.
val doStack = srcPtrs(j + 1) - srcPtrs(j) > threshold && rank > threshold
val srcFactorBuffer = new Array[Double]((srcPtrs(j + 1) - srcPtrs(j)) * rank)
val bBuffer = new Array[Double](srcPtrs(j + 1) - srcPtrs(j))
var srcIndex = 0
var bIndex = 0
while (i < srcPtrs(j + 1)) {
val encoded = srcEncodedIndices(i)
val blockId = srcEncoder.blockId(encoded)
Expand All @@ -1638,11 +1691,25 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
}
ls.add(srcFactor, if (rating > 0.0) 1.0 + c1 else 0.0, c1)
} else {
ls.add(srcFactor, rating)
numExplicits += 1
if (doStack) {
bBuffer(bIndex) = rating
bIndex += 1
var ii = 0
while(ii < rank) {
srcFactorBuffer(srcIndex) = srcFactor(ii)
srcIndex += 1
ii += 1
}
} else {
ls.add(srcFactor, rating)
}
}
i += 1
}
if (!implicitPrefs && doStack) {
ls.addStack(srcFactorBuffer, bBuffer, numExplicits)
}
// Weight lambda by the number of explicit ratings based on the ALS-WR paper.
dstFactors(j) = solver.solve(ls, numExplicits * regParam)
j += 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,8 @@ class ALSSuite
implicitPrefs: Boolean = false,
numUserBlocks: Int = 2,
numItemBlocks: Int = 3,
targetRMSE: Double = 0.05): Unit = {
targetRMSE: Double = 0.05,
threshold: Int = 1024): Unit = {
val spark = this.spark
import spark.implicits._
val als = new ALS()
Expand All @@ -411,6 +412,7 @@ class ALSSuite
.setNumUserBlocks(numUserBlocks)
.setNumItemBlocks(numItemBlocks)
.setSeed(0)
.setThreshold(threshold)
val alpha = als.getAlpha
val model = als.fit(training.toDF())
val predictions = model.transform(test.toDF()).select("rating", "prediction").rdd.map {
Expand Down Expand Up @@ -481,6 +483,12 @@ class ALSSuite
numItemBlocks = 5, numUserBlocks = 5)
}

test("do stacking factors in matrices") {
val (training, test) = genExplicitTestData(numUsers = 200, numItems = 20, rank = 1)
testALS(training, test, maxIter = 1, rank = 129, regParam = 0.01, targetRMSE = 0.02,
threshold = 128)
}

test("implicit feedback") {
val (training, test) =
genImplicitTestData(numUsers = 20, numItems = 40, rank = 2, noiseStd = 0.01)
Expand Down
3 changes: 3 additions & 0 deletions project/MimaExcludes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1043,6 +1043,9 @@ object MimaExcludes {
// [SPARK-21680][ML][MLLIB]optimzie Vector coompress
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.mllib.linalg.Vector.toSparseWithSize"),
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Vector.toSparseWithSize")
) ++ Seq(
// [SPARK-6685][ML]Use DSYRK to compute AtA in ALS
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.recommendation.ALS.train")

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This breaks an API unnecessarily, even though it's a dev API. I think we instead need to remove this as a user-facing param and avoid it altogether.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am ok to remove this, and use a loose threshold (e.g. 100), which is helpful for most cases. How about it?

)
}

Expand Down