Skip to content

Commit 62628d6

Browse files
oliverholworthykarlhigleyrnyakbenfred
authored
XGBoost - Switch to dask API (#466)
* Use raise-from exception chaining for objective ValueError * Add dask_client fixture to conftest * Remove unnecessary else block in get_targets * Add schema to XGBoost constructor and move fit parameters. Moving target_columns and qid_column from fit method to __init__. This makes the relationship between these attributes and the schema clearer. * Switch to using dask API for XGBoost. Requiring use of dask client. * Remove dask_cudf import and typehints * Setup dask_client fixture to use CPU * Use cluster_type parameter instead of new_cluster Co-authored-by: Karl Higley <kmhigley@gmail.com> Co-authored-by: rnyak <ronayak@hotmail.com> Co-authored-by: Ben Frederickson <github@benfrederickson.com>
1 parent e11cb55 commit 62628d6

File tree

3 files changed

+167
-99
lines changed

3 files changed

+167
-99
lines changed

merlin/models/xgb/__init__.py

Lines changed: 81 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,53 +1,81 @@
11
import warnings
22
from typing import Dict, List, Optional, Union
33

4+
import distributed
45
import numpy as np
56
import xgboost as xgb
67

8+
from merlin.core.utils import global_dask_client
79
from merlin.io import Dataset
8-
from merlin.schema import Tags
10+
from merlin.schema import Schema, Tags
911

1012

1113
class XGBoost:
12-
"""Create an XGBoost model.
14+
"""Create an XGBoost model from a merlin dataset.
1315
The class adapts an XGBoost model to work with the high level merlin-models API.
1416
1517
Example usage::
1618
1719
# get the movielens dataset
1820
from merlin.datasets.entertainment import get_movielens
1921
20-
train, valid = get_movielens()
22+
train, valid = get_movielens(variant="ml-1m")
2123
2224
# Train an XGBoost model
23-
from merlin.schema import Tags
25+
from merlin.core.utils import Distributed
2426
from merlin.models.xgb import XGBoost
2527
26-
model = XGBoost(objective="binary:logistic")
27-
model.fit(train)
28-
29-
model.evaluate(valid)
28+
with Distributed():
29+
model = XGBoost(train.schema, objective="binary:logistic")
30+
model.fit(train)
31+
metrics = model.evaluate(valid)
3032
"""
3133

32-
def __init__(self, *, objective="reg:squarederror", **params):
34+
def __init__(
35+
self,
36+
schema: Schema,
37+
*,
38+
target_columns: Optional[Union[str, list]] = None,
39+
qid_column: Optional[str] = None,
40+
objective: str = "reg:squarederror",
41+
booster: Optional[xgb.Booster] = None,
42+
**params,
43+
):
3344
"""
3445
Parameters
3546
----------
47+
schema : merlin.schema.Schema
48+
The schema of the data that will be used to train and evaluate the model.
49+
target_columns : Optional[Union[list, str]]
50+
The target columns to use. If provided, will be used as the label(s).
51+
Otherwise the targets are automatically inferred from the objective and column tags.
52+
qid_column : Optional[str]
53+
For ranking objectives. The query ID column. If not provided will use
54+
the user ID (tagged with merlin.schema.Tags.USER_ID) column.
3655
objective : str
3756
The XGBoost objective to use. List of XGBoost objective functions:
3857
https://xgboost.readthedocs.io/en/stable/gpu/index.html#objective-functions
3958
**params
4059
The parameters to use for the XGBoost train method
4160
"""
4261
self.params = {**params, "objective": objective}
43-
self.bst = None
62+
63+
target_tag = get_target_tag(objective)
64+
self.target_columns = target_columns or get_targets(schema, target_tag)
65+
66+
if objective.startswith("rank") and qid_column is None:
67+
qid_column = schema.select_by_tag(Tags.USER_ID).column_names[0]
68+
self.qid_column = qid_column
69+
70+
self.booster = booster
71+
72+
@property
73+
def dask_client(self) -> Optional[distributed.Client]:
74+
return global_dask_client()
4475

4576
def fit(
4677
self,
4778
train: Dataset,
48-
*,
49-
target_columns: Optional[Union[str, list]] = None,
50-
qid_column: Optional[str] = None,
5179
**train_kwargs,
5280
) -> xgb.Booster:
5381
"""Trains the XGBoost Model.
@@ -62,12 +90,6 @@ def fit(
6290
The training dataset to use to fit the model.
6391
We will use the column(s) tagged with merlin.schema.Tags.TARGET that match the
6492
objective as the label(s).
65-
target_columns : Optional[Union[list, str]]
66-
The target columns to use. If provided, will be used as the label(s).
67-
Otherwise the targets are automatically inferred from the objective and column tags.
68-
qid_column : Optional[str]
69-
For ranking objectives. The query ID column. If not provided will use
70-
the user ID (tagged with merlin.schema.Tags.USER_ID) column.
7193
**train_kwargs
7294
Additional keyword arguments passed to the xgboost.train function
7395
@@ -80,21 +102,21 @@ def fit(
80102
ValueError
81103
If objective is not supported. Or if the target columns cannot be found.
82104
"""
83-
objective = self.params["objective"]
84-
target_tag = get_target_tag(objective)
85-
self.target_columns = target_columns or get_targets(train, target_tag)
86-
87-
# for ranking objectives, set the grouping
88-
if objective.startswith("rank") and qid_column is None:
89-
qid_column = train.schema.select_by_tag(Tags.USER_ID).column_names[0]
90-
self.qid_column = qid_column
105+
X, y, qid = dataset_to_xy(
106+
train,
107+
self.target_columns,
108+
self.qid_column,
109+
)
91110

92-
dtrain = dataset_to_dmatrix(train, self.target_columns, self.qid_column)
111+
dtrain = xgb.dask.DaskDMatrix(self.dask_client, X, label=y, qid=qid)
93112
watchlist = [(dtrain, "train")]
94113

95-
self.bst = xgb.train(self.params, dtrain, evals=watchlist, **train_kwargs)
114+
booster: xgb.Booster = xgb.dask.train(
115+
self.dask_client, self.params, dtrain, evals=watchlist, **train_kwargs
116+
)["booster"]
117+
self.booster = booster
96118

97-
return self.bst
119+
return booster
98120

99121
def evaluate(self, dataset: Dataset, **predict_kwargs) -> Dict[str, float]:
100122
"""Evaluates the model on the dataset provided.
@@ -111,14 +133,20 @@ def evaluate(self, dataset: Dataset, **predict_kwargs) -> Dict[str, float]:
111133
Dict[str, float]
112134
Dictionary of metrics of the form {metric_name: value}.
113135
"""
114-
if self.bst is None:
136+
if self.booster is None:
115137
raise ValueError("The fit method must be called before evaluate.")
116138

117-
data: xgb.DMatrix = dataset_to_dmatrix(dataset, self.target_columns, self.qid_column)
118-
preds = self.bst.predict(data, **predict_kwargs)
119-
data.set_label(preds)
139+
X, _, qid = dataset_to_xy(dataset, self.target_columns, self.qid_column)
140+
data = xgb.dask.DaskDMatrix(self.dask_client, X, qid=qid)
141+
preds = xgb.dask.predict(self.dask_client, self.booster, data, **predict_kwargs)
142+
143+
# convert to DMatrix
144+
# (eval doesn't have dask support currently)
145+
if qid is not None:
146+
qid = qid.compute()
147+
eval_data = xgb.DMatrix(X.compute(), label=preds.compute(), qid=qid)
120148

121-
metrics_str = self.bst.eval(data)
149+
metrics_str = self.booster.eval(eval_data)
122150
metrics = {}
123151
for metric in metrics_str.split("\t")[1:]:
124152
metric_name, metric_value = metric.split(":")
@@ -141,11 +169,12 @@ def predict(self, dataset: Dataset, **predict_kwargs) -> np.ndarray:
141169
numpy.ndarray
142170
The predictions data
143171
"""
144-
if self.bst is None:
172+
if self.booster is None:
145173
raise ValueError("The fit method must be called before predict.")
146174

147-
data: xgb.DMatrix = dataset_to_dmatrix(dataset, self.target_columns, self.qid_column)
148-
preds = self.bst.predict(data, **predict_kwargs)
175+
X, _, qid = dataset_to_xy(dataset, self.target_columns, self.qid_column)
176+
data = xgb.dask.DaskDMatrix(self.dask_client, X, qid=qid)
177+
preds = xgb.dask.predict(self.dask_client, self.booster, data, **predict_kwargs).compute()
149178

150179
return preds
151180

@@ -164,28 +193,29 @@ def get_target_tag(objective: str) -> Tags:
164193
"""Get the target tag from the specified objective"""
165194
try:
166195
return OBJECTIVES[objective]
167-
except KeyError:
196+
except KeyError as exc:
168197
target_options_str = str(list(OBJECTIVES.keys()))
169-
raise ValueError(f"Objective not supported. Must be one of: {target_options_str}")
198+
raise ValueError(f"Objective not supported. Must be one of: {target_options_str}") from exc
170199

171200

172-
def get_targets(dataset: Dataset, target_tag: Tags) -> List[str]:
201+
def get_targets(schema: Schema, target_tag: Tags) -> List[str]:
173202
"""Find target columns from dataset or specified target_column"""
174-
targets = dataset.schema.select_by_tag(Tags.TARGET).select_by_tag(target_tag)
203+
targets = schema.select_by_tag(Tags.TARGET).select_by_tag(target_tag)
175204

176205
if len(targets) >= 1:
177206
return targets.column_names
178-
else:
179-
raise ValueError(
180-
f"No target columns in the dataset schema with tags TARGET and {target_tag.name}"
181-
)
207+
raise ValueError(
208+
f"No target columns in the dataset schema with tags TARGET and {target_tag.name}"
209+
)
182210

183211

184-
def dataset_to_dmatrix(
185-
dataset: Dataset, target_columns: Union[str, list], qid_column: Optional[str]
186-
) -> xgb.DMatrix:
212+
def dataset_to_xy(
213+
dataset: Dataset,
214+
target_columns: Union[str, list],
215+
qid_column: Optional[str],
216+
):
187217
"""Convert Merlin Dataset to XGBoost DMatrix"""
188-
df = dataset.to_ddf().compute()
218+
df = dataset.to_ddf()
189219

190220
qid = None
191221
if qid_column:
@@ -210,6 +240,4 @@ def dataset_to_dmatrix(
210240
# Ensure columns are in a consistent order
211241
X = X[sorted(X.columns)]
212242

213-
data = xgb.DMatrix(X, label=y, qid=qid)
214-
215-
return data
243+
return X, y, qid

tests/conftest.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,10 @@
1919

2020
from pathlib import Path
2121

22+
import distributed
2223
import pytest
2324

25+
from merlin.core.utils import Distributed
2426
from merlin.datasets.synthetic import generate_data
2527
from merlin.io import Dataset
2628

@@ -60,6 +62,12 @@ def testing_data() -> Dataset:
6062
return data
6163

6264

65+
@pytest.fixture(scope="module")
66+
def dask_client() -> distributed.Client:
67+
with Distributed(cluster_type="cpu") as dist:
68+
yield dist.client
69+
70+
6371
try:
6472
import tensorflow as tf # noqa
6573

tests/unit/xgb/test_xgboost.py

Lines changed: 78 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -19,53 +19,85 @@
1919
from merlin.models.xgb import XGBoost
2020

2121

22-
def test_music_regression(music_streaming_data: Dataset):
23-
model = XGBoost(objective="reg:logistic")
24-
model.fit(music_streaming_data)
25-
model.predict(music_streaming_data)
26-
metrics = model.evaluate(music_streaming_data)
27-
28-
assert "rmse" in metrics
29-
30-
31-
def test_unsupported_objective(music_streaming_data: Dataset):
32-
with pytest.raises(ValueError) as excinfo:
33-
model = XGBoost(objective="reg:unknown")
22+
def test_without_dask_client(music_streaming_data: Dataset):
23+
with pytest.raises(ValueError) as exc_info:
24+
model = XGBoost(music_streaming_data.schema, objective="reg:logistic")
3425
model.fit(music_streaming_data)
35-
assert "Objective not supported" in str(excinfo.value)
36-
37-
38-
def test_ecommerce_click(ecommerce_data: Dataset):
39-
model = XGBoost(objective="binary:logistic", eval_metric="auc")
40-
model.fit(ecommerce_data, target_columns=["click"])
41-
model.predict(ecommerce_data)
42-
metrics = model.evaluate(ecommerce_data)
43-
44-
assert "auc" in metrics
45-
26+
assert "No global client found" in str(exc_info.value)
4627

47-
def test_social_click(social_data: Dataset):
48-
model = XGBoost(objective="binary:logistic", eval_metric=["auc"])
49-
model.fit(social_data, target_columns=["click"])
50-
model.predict(social_data)
51-
metrics = model.evaluate(social_data)
5228

53-
assert "auc" in metrics
29+
@pytest.mark.usefixtures("dask_client")
30+
class TestXGBoost:
31+
def test_unsupported_objective(self, music_streaming_data: Dataset):
32+
with pytest.raises(ValueError) as excinfo:
33+
model = XGBoost(music_streaming_data.schema, objective="reg:unknown")
34+
model.fit(music_streaming_data)
35+
assert "Objective not supported" in str(excinfo.value)
5436

55-
56-
def test_criteo(criteo_data: Dataset):
57-
model = XGBoost(objective="binary:logistic", eval_metric=["auc"])
58-
model.fit(criteo_data)
59-
model.predict(criteo_data)
60-
metrics = model.evaluate(criteo_data)
61-
62-
assert "auc" in metrics
63-
64-
65-
def test_social_ranking(social_data: Dataset):
66-
model = XGBoost(objective="rank:ndcg", eval_metric=["auc", "ndcg", "map"])
67-
model.fit(social_data, target_columns="click")
68-
model.predict(social_data)
69-
metrics = model.evaluate(social_data)
70-
71-
assert "map" in metrics
37+
def test_music_regression(self, music_streaming_data: Dataset):
38+
schema = music_streaming_data.schema
39+
model = XGBoost(schema, objective="reg:logistic")
40+
model.fit(music_streaming_data)
41+
model.predict(music_streaming_data)
42+
metrics = model.evaluate(music_streaming_data)
43+
44+
assert "rmse" in metrics
45+
46+
def test_ecommerce_click(self, ecommerce_data: Dataset):
47+
schema = ecommerce_data.schema
48+
model = XGBoost(
49+
schema, target_columns=["click"], objective="binary:logistic", eval_metric="auc"
50+
)
51+
model.fit(ecommerce_data)
52+
model.predict(ecommerce_data)
53+
metrics = model.evaluate(ecommerce_data)
54+
55+
assert "auc" in metrics
56+
57+
def test_social_click(self, social_data: Dataset):
58+
schema = social_data.schema
59+
model = XGBoost(
60+
schema, target_columns=["click"], objective="binary:logistic", eval_metric=["auc"]
61+
)
62+
model.fit(social_data)
63+
model.predict(social_data)
64+
metrics = model.evaluate(social_data)
65+
66+
assert "auc" in metrics
67+
68+
def test_logistic(self, criteo_data: Dataset):
69+
schema = criteo_data.schema
70+
model = XGBoost(schema, objective="binary:logistic", eval_metric=["auc"])
71+
model.fit(criteo_data)
72+
model.predict(criteo_data)
73+
metrics = model.evaluate(criteo_data)
74+
75+
assert "auc" in metrics
76+
77+
def test_ndcg(self, social_data: Dataset):
78+
schema = social_data.schema
79+
model = XGBoost(
80+
schema,
81+
target_columns="click",
82+
qid_column="user_id",
83+
objective="rank:ndcg",
84+
eval_metric=["auc", "ndcg", "map"],
85+
)
86+
model.fit(social_data)
87+
model.predict(social_data)
88+
metrics = model.evaluate(social_data)
89+
90+
assert "map" in metrics
91+
92+
def test_pairwise(self, social_data: Dataset):
93+
schema = social_data.schema
94+
model = XGBoost(
95+
schema,
96+
target_columns=["click"],
97+
qid_column="user_id",
98+
objective="rank:pairwise",
99+
eval_metric=["ndcg", "auc", "map"],
100+
)
101+
model.fit(social_data)
102+
model.predict(social_data)
103+
model.evaluate(social_data)

0 commit comments

Comments
 (0)