Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 81 additions & 53 deletions merlin/models/xgb/__init__.py
Original file line number Diff line number Diff line change
@@ -1,53 +1,81 @@
import warnings
from typing import Dict, List, Optional, Union

import distributed
import numpy as np
import xgboost as xgb

from merlin.core.utils import global_dask_client
from merlin.io import Dataset
from merlin.schema import Tags
from merlin.schema import Schema, Tags


class XGBoost:
"""Create an XGBoost model.
"""Create an XGBoost model from a merlin dataset.
The class adapts an XGBoost model to work with the high level merlin-models API.

Example usage::

# get the movielens dataset
from merlin.datasets.entertainment import get_movielens

train, valid = get_movielens()
train, valid = get_movielens(variant="ml-1m")

# Train an XGBoost model
from merlin.schema import Tags
from merlin.core.utils import Distributed
from merlin.models.xgb import XGBoost

model = XGBoost(objective="binary:logistic")
model.fit(train)

model.evaluate(valid)
with Distributed():
model = XGBoost(train.schema, objective="binary:logistic")
model.fit(train)
metrics = model.evaluate(valid)
"""

def __init__(
    self,
    schema: Schema,
    *,
    target_columns: Optional[Union[str, list]] = None,
    qid_column: Optional[str] = None,
    objective: str = "reg:squarederror",
    booster: Optional[xgb.Booster] = None,
    **params,
):
    """Configure the model from a schema and XGBoost parameters.

    Parameters
    ----------
    schema : merlin.schema.Schema
        The schema of the data that will be used to train and evaluate the model.
    target_columns : Optional[Union[list, str]]
        The target columns to use. If provided, will be used as the label(s).
        Otherwise the targets are automatically inferred from the objective and column tags.
    qid_column : Optional[str]
        For ranking objectives. The query ID column. If not provided will use
        the user ID (tagged with merlin.schema.Tags.USER_ID) column.
    objective : str
        The XGBoost objective to use. List of XGBoost objective functions:
        https://xgboost.readthedocs.io/en/stable/gpu/index.html#objective-functions
    booster : Optional[xgb.Booster]
        An existing booster to store on the model. When left as None, ``evaluate``
        and ``predict`` raise until ``fit`` has been called.
    **params
        The parameters to use for the XGBoost train method

    Raises
    ------
    ValueError
        If the objective is not supported, if no target columns can be found in the
        schema, or if a ranking objective is used without ``qid_column`` and the
        schema has no column tagged USER_ID.
    """
    self.params = {**params, "objective": objective}

    target_tag = get_target_tag(objective)
    self.target_columns = target_columns or get_targets(schema, target_tag)

    # Ranking objectives need a grouping column; default to the user ID column.
    if objective.startswith("rank") and qid_column is None:
        user_id_columns = schema.select_by_tag(Tags.USER_ID).column_names
        if not user_id_columns:
            raise ValueError(
                f"Objective '{objective}' requires a query ID column. "
                "Provide qid_column or tag a column in the schema with USER_ID."
            )
        qid_column = user_id_columns[0]
    self.qid_column = qid_column

    self.booster = booster

@property
def dask_client(self) -> Optional[distributed.Client]:
    """The global Dask client, as returned by ``global_dask_client()``."""
    client = global_dask_client()
    return client

def fit(
self,
train: Dataset,
*,
target_columns: Optional[Union[str, list]] = None,
qid_column: Optional[str] = None,
**train_kwargs,
) -> xgb.Booster:
"""Trains the XGBoost Model.
Expand All @@ -62,12 +90,6 @@ def fit(
The training dataset to use to fit the model.
We will use the column(s) tagged with merlin.schema.Tags.TARGET that match the
objective as the label(s).
target_columns : Optional[Union[list, str]]
The target columns to use. If provided, will be used as the label(s).
Otherwise the targets are automatically inferred from the objective and column tags.
qid_column : Optional[str]
For ranking objectives. The query ID column. If not provided will use
the user ID (tagged with merlin.schema.Tags.USER_ID) column.
**train_kwargs
Additional keyword arguments passed to the xgboost.train function

Expand All @@ -80,21 +102,21 @@ def fit(
ValueError
If objective is not supported. Or if the target columns cannot be found.
"""
objective = self.params["objective"]
target_tag = get_target_tag(objective)
self.target_columns = target_columns or get_targets(train, target_tag)

# for ranking objectives, set the grouping
if objective.startswith("rank") and qid_column is None:
qid_column = train.schema.select_by_tag(Tags.USER_ID).column_names[0]
self.qid_column = qid_column
X, y, qid = dataset_to_xy(
train,
self.target_columns,
self.qid_column,
)

dtrain = dataset_to_dmatrix(train, self.target_columns, self.qid_column)
dtrain = xgb.dask.DaskDMatrix(self.dask_client, X, label=y, qid=qid)
watchlist = [(dtrain, "train")]

self.bst = xgb.train(self.params, dtrain, evals=watchlist, **train_kwargs)
booster: xgb.Booster = xgb.dask.train(
self.dask_client, self.params, dtrain, evals=watchlist, **train_kwargs
)["booster"]
self.booster = booster

return self.bst
return booster

def evaluate(self, dataset: Dataset, **predict_kwargs) -> Dict[str, float]:
"""Evaluates the model on the dataset provided.
Expand All @@ -111,14 +133,20 @@ def evaluate(self, dataset: Dataset, **predict_kwargs) -> Dict[str, float]:
Dict[str, float]
Dictionary of metrics of the form {metric_name: value}.
"""
if self.bst is None:
if self.booster is None:
raise ValueError("The fit method must be called before evaluate.")

data: xgb.DMatrix = dataset_to_dmatrix(dataset, self.target_columns, self.qid_column)
preds = self.bst.predict(data, **predict_kwargs)
data.set_label(preds)
X, _, qid = dataset_to_xy(dataset, self.target_columns, self.qid_column)
data = xgb.dask.DaskDMatrix(self.dask_client, X, qid=qid)
preds = xgb.dask.predict(self.dask_client, self.booster, data, **predict_kwargs)

# convert to DMatrix
# (eval doesn't have dask support currently)
if qid is not None:
qid = qid.compute()
eval_data = xgb.DMatrix(X.compute(), label=preds.compute(), qid=qid)

metrics_str = self.bst.eval(data)
metrics_str = self.booster.eval(eval_data)
metrics = {}
for metric in metrics_str.split("\t")[1:]:
metric_name, metric_value = metric.split(":")
Expand All @@ -141,11 +169,12 @@ def predict(self, dataset: Dataset, **predict_kwargs) -> np.ndarray:
numpy.ndarray
The predictions data
"""
if self.bst is None:
if self.booster is None:
raise ValueError("The fit method must be called before predict.")

data: xgb.DMatrix = dataset_to_dmatrix(dataset, self.target_columns, self.qid_column)
preds = self.bst.predict(data, **predict_kwargs)
X, _, qid = dataset_to_xy(dataset, self.target_columns, self.qid_column)
data = xgb.dask.DaskDMatrix(self.dask_client, X, qid=qid)
preds = xgb.dask.predict(self.dask_client, self.booster, data, **predict_kwargs).compute()

return preds

Expand All @@ -164,28 +193,29 @@ def get_target_tag(objective: str) -> Tags:
"""Get the target tag from the specified objective"""
try:
return OBJECTIVES[objective]
except KeyError:
except KeyError as exc:
target_options_str = str(list(OBJECTIVES.keys()))
raise ValueError(f"Objective not supported. Must be one of: {target_options_str}")
raise ValueError(f"Objective not supported. Must be one of: {target_options_str}") from exc


def get_targets(schema: Schema, target_tag: Tags) -> List[str]:
    """Return the names of the schema columns usable as targets.

    Parameters
    ----------
    schema : merlin.schema.Schema
        The schema to search.
    target_tag : merlin.schema.Tags
        The tag a column must carry in addition to Tags.TARGET.

    Returns
    -------
    List[str]
        The names of the columns tagged with both Tags.TARGET and ``target_tag``.

    Raises
    ------
    ValueError
        If no column in the schema carries both tags.
    """
    matching = schema.select_by_tag(Tags.TARGET).select_by_tag(target_tag)

    if len(matching) < 1:
        raise ValueError(
            f"No target columns in the dataset schema with tags TARGET and {target_tag.name}"
        )

    return matching.column_names


def dataset_to_dmatrix(
dataset: Dataset, target_columns: Union[str, list], qid_column: Optional[str]
) -> xgb.DMatrix:
def dataset_to_xy(
dataset: Dataset,
target_columns: Union[str, list],
qid_column: Optional[str],
):
"""Convert a Merlin Dataset into the features, labels and query IDs (X, y, qid) used to build an XGBoost DMatrix"""
df = dataset.to_ddf().compute()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Switching to Dask avoids the need to call compute on the dataset here.

df = dataset.to_ddf()

qid = None
if qid_column:
Expand All @@ -210,6 +240,4 @@ def dataset_to_dmatrix(
# Ensure columns are in a consistent order
X = X[sorted(X.columns)]

data = xgb.DMatrix(X, label=y, qid=qid)

return data
return X, y, qid
8 changes: 8 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@

from pathlib import Path

import distributed
import pytest

from merlin.core.utils import Distributed
from merlin.datasets.synthetic import generate_data
from merlin.io import Dataset

Expand Down Expand Up @@ -60,6 +62,12 @@ def testing_data() -> Dataset:
return data


@pytest.fixture(scope="module")
def dask_client() -> distributed.Client:
    """Yield the client of a CPU ``Distributed`` cluster (module-scoped fixture)."""
    with Distributed(cluster_type="cpu") as cluster:
        yield cluster.client


try:
import tensorflow as tf # noqa

Expand Down
124 changes: 78 additions & 46 deletions tests/unit/xgb/test_xgboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,53 +19,85 @@
from merlin.models.xgb import XGBoost


def test_music_regression(music_streaming_data: Dataset):
model = XGBoost(objective="reg:logistic")
model.fit(music_streaming_data)
model.predict(music_streaming_data)
metrics = model.evaluate(music_streaming_data)

assert "rmse" in metrics


def test_unsupported_objective(music_streaming_data: Dataset):
with pytest.raises(ValueError) as excinfo:
model = XGBoost(objective="reg:unknown")
def test_without_dask_client(music_streaming_data: Dataset):
    """Fitting without a global Dask client fails with a ValueError."""
    model = XGBoost(music_streaming_data.schema, objective="reg:logistic")

    # The constructor never touches the Dask client; only ``fit`` does, so it is
    # the only call that belongs inside the ``raises`` block.
    with pytest.raises(ValueError) as exc_info:
        model.fit(music_streaming_data)

    assert "No global client found" in str(exc_info.value)

def test_social_click(social_data: Dataset):
model = XGBoost(objective="binary:logistic", eval_metric=["auc"])
model.fit(social_data, target_columns=["click"])
model.predict(social_data)
metrics = model.evaluate(social_data)

assert "auc" in metrics
@pytest.mark.usefixtures("dask_client")
class TestXGBoost:
    """Fit / predict / evaluate checks that run with the ``dask_client`` fixture."""

    def test_unsupported_objective(self, music_streaming_data: Dataset):
        # The objective is validated in the constructor, so the constructor call
        # is the only statement needed inside the ``raises`` block.
        with pytest.raises(ValueError) as excinfo:
            XGBoost(music_streaming_data.schema, objective="reg:unknown")
        assert "Objective not supported" in str(excinfo.value)

    def test_music_regression(self, music_streaming_data: Dataset):
        schema = music_streaming_data.schema
        model = XGBoost(schema, objective="reg:logistic")
        model.fit(music_streaming_data)
        model.predict(music_streaming_data)
        metrics = model.evaluate(music_streaming_data)

        assert "rmse" in metrics

    def test_ecommerce_click(self, ecommerce_data: Dataset):
        schema = ecommerce_data.schema
        model = XGBoost(
            schema, target_columns=["click"], objective="binary:logistic", eval_metric="auc"
        )
        model.fit(ecommerce_data)
        model.predict(ecommerce_data)
        metrics = model.evaluate(ecommerce_data)

        assert "auc" in metrics

    def test_social_click(self, social_data: Dataset):
        schema = social_data.schema
        model = XGBoost(
            schema, target_columns=["click"], objective="binary:logistic", eval_metric=["auc"]
        )
        model.fit(social_data)
        model.predict(social_data)
        metrics = model.evaluate(social_data)

        assert "auc" in metrics

    def test_logistic(self, criteo_data: Dataset):
        schema = criteo_data.schema
        model = XGBoost(schema, objective="binary:logistic", eval_metric=["auc"])
        model.fit(criteo_data)
        model.predict(criteo_data)
        metrics = model.evaluate(criteo_data)

        assert "auc" in metrics

    def test_ndcg(self, social_data: Dataset):
        schema = social_data.schema
        model = XGBoost(
            schema,
            target_columns="click",
            qid_column="user_id",
            objective="rank:ndcg",
            eval_metric=["auc", "ndcg", "map"],
        )
        model.fit(social_data)
        model.predict(social_data)
        metrics = model.evaluate(social_data)

        assert "map" in metrics

    def test_pairwise(self, social_data: Dataset):
        schema = social_data.schema
        model = XGBoost(
            schema,
            target_columns=["click"],
            qid_column="user_id",
            objective="rank:pairwise",
            eval_metric=["ndcg", "auc", "map"],
        )
        model.fit(social_data)
        model.predict(social_data)
        model.evaluate(social_data)