diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 02e2384afe530..24dba78d284ab 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -39,6 +39,7 @@ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.CholeskyDecomposition import org.apache.spark.mllib.optimization.NNLS +import org.apache.spark.mllib.recommendation.MatrixFactorizationModel import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions._ @@ -236,6 +237,8 @@ class ALSModel private[ml] ( @transient val itemFactors: DataFrame) extends Model[ALSModel] with ALSModelParams with MLWritable { + import org.apache.spark.ml.recommendation.ALS.Rating + /** @group setParam */ @Since("1.4.0") def setUserCol(value: String): this.type = set(userCol, value) @@ -269,6 +272,44 @@ class ALSModel private[ml] ( predict(userFactors("features"), itemFactors("features")).as($(predictionCol))) } + /** + * Recommends top items for all users. + * + * @param num how many items to return for every user. + * @return a DataFrame that stores recommendations in two columns: `user` and `ratings`, where + * every row contains a userID and an array of [[Rating]] objects which contains the + * same userId, recommended itemID and "score". + */ + @Since("2.1.0") + def recommendItemsForUsers(num: Int): DataFrame = { + val spark = userFactors.sparkSession + import spark.implicits._ + toMLlibModel.recommendProductsForUsers(num).toDF("user", "ratings") + } + + /** + * Recommends top users for all items. + * + * @param num how many users to return for every item. + * @return a DataFrame that stores recommendations in two columns: `item` and `ratings`, where + * every row contains a itemID and an array of [[Rating]] objects which contains the + * same itemID, recommended userID and "score". + */ + @Since("2.1.0") + def recommendUsersForItems(num: Int): DataFrame = { + val spark = userFactors.sparkSession + import spark.implicits._ + toMLlibModel.recommendProductsForUsers(num).toDF("item", "ratings") + } + + private def toMLlibModel: MatrixFactorizationModel = { + val userFeatures = userFactors.select("id", "features").rdd + .map(r => (r.getInt(0), r.getSeq[Float](1).toArray.map(_.toDouble))) + val itemFeatures = itemFactors.select("id", "features").rdd + .map(r => (r.getInt(0), r.getSeq[Float](1).toArray.map(_.toDouble))) + new MatrixFactorizationModel(rank, userFeatures, itemFeatures) + } + @Since("1.3.0") override def transformSchema(schema: StructType): StructType = { // user and item will be cast to Int diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index d0aa2cdfe0fd1..e419a965dc1e6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -450,6 +450,20 @@ class ALSSuite implicitPrefs = true, seed = 0) } + test("recommend for all") { + val spark = this.spark + import spark.implicits._ + val (ratings, _) = genExplicitTestData(numUsers = 4, numItems = 4, rank = 1) + val model = new ALS().fit(ratings.toDF()) + val items = model.recommendItemsForUsers(2) + assert(items.count() == 4 + && items.select("ratings").rdd.collect().forall(_.getSeq[Rating[Int]](0).length == 2)) + + val users = model.recommendUsersForItems(2) + assert(users.count() == 4 + && users.select("ratings").rdd.collect().forall(_.getSeq[Rating[Int]](0).length == 2)) + } + test("read/write") { import ALSSuite._ val (ratings, _) = genExplicitTestData(numUsers = 4, numItems = 4, rank = 1)