11import warnings
22from typing import Dict , List , Optional , Union
33
4+ import distributed
45import numpy as np
56import xgboost as xgb
67
8+ from merlin .core .utils import global_dask_client
79from merlin .io import Dataset
8- from merlin .schema import Tags
10+ from merlin .schema import Schema , Tags
911
1012
1113class XGBoost :
12- """Create an XGBoost model.
14+ """Create an XGBoost model from a merlin dataset .
1315 The class adapts an XGBoost model to work with the high level merlin-models API.
1416
1517 Example usage::
1618
1719 # get the movielens dataset
1820 from merlin.datasets.entertainment import get_movielens
1921
20- train, valid = get_movielens()
22+ train, valid = get_movielens(variant="ml-1m" )
2123
2224 # Train an XGBoost model
23- from merlin.schema import Tags
25+ from merlin.core.utils import Distributed
2426 from merlin.models.xgb import XGBoost
2527
26- model = XGBoost(objective="binary:logistic")
27- model.fit (train)
28-
29- model.evaluate(valid)
28+ with Distributed():
29+ model = XGBoost (train.schema, objective="binary:logistic" )
30+ model.fit(train)
31+ metrics = model.evaluate(valid)
3032 """
3133
32- def __init__ (self , * , objective = "reg:squarederror" , ** params ):
34+ def __init__ (
35+ self ,
36+ schema : Schema ,
37+ * ,
38+ target_columns : Optional [Union [str , list ]] = None ,
39+ qid_column : Optional [str ] = None ,
40+ objective : str = "reg:squarederror" ,
41+ booster : Optional [xgb .Booster ] = None ,
42+ ** params ,
43+ ):
3344 """
3445 Parameters
3546 ----------
47+ schema : merlin.schema.Schema
48+ The schema of the data that will be used to train and evaluate the model.
49+ target_columns : Optional[Union[list, str]]
50+ The target columns to use. If provided, will be used as the label(s).
51+ Otherwise the targets are automatically inferred from the objective and column tags.
52+ qid_column : Optional[str]
53+ For ranking objectives. The query ID column. If not provided will use
54+ the user ID (tagged with merlin.schema.Tags.USER_ID) column.
3655 objective : str
3756 The XGBoost objective to use. List of XGBoost objective functions:
3857 https://xgboost.readthedocs.io/en/stable/gpu/index.html#objective-functions
3958 **params
4059 The parameters to use for the XGBoost train method
4160 """
4261 self .params = {** params , "objective" : objective }
43- self .bst = None
62+
63+ target_tag = get_target_tag (objective )
64+ self .target_columns = target_columns or get_targets (schema , target_tag )
65+
66+ if objective .startswith ("rank" ) and qid_column is None :
67+ qid_column = schema .select_by_tag (Tags .USER_ID ).column_names [0 ]
68+ self .qid_column = qid_column
69+
70+ self .booster = booster
71+
72+ @property
73+ def dask_client (self ) -> Optional [distributed .Client ]:
74+ return global_dask_client ()
4475
4576 def fit (
4677 self ,
4778 train : Dataset ,
48- * ,
49- target_columns : Optional [Union [str , list ]] = None ,
50- qid_column : Optional [str ] = None ,
5179 ** train_kwargs ,
5280 ) -> xgb .Booster :
5381 """Trains the XGBoost Model.
@@ -62,12 +90,6 @@ def fit(
6290 The training dataset to use to fit the model.
6391 We will use the column(s) tagged with merlin.schema.Tags.TARGET that match the
6492 objective as the label(s).
65- target_columns : Optional[Union[list, str]]
66- The target columns to use. If provided, will be used as the label(s).
67- Otherwise the targets are automatically inferred from the objective and column tags.
68- qid_column : Optional[str]
69- For ranking objectives. The query ID column. If not provided will use
70- the user ID (tagged with merlin.schema.Tags.USER_ID) column.
7193 **train_kwargs
7294 Additional keyword arguments passed to the xgboost.train function
7395
@@ -80,21 +102,21 @@ def fit(
80102 ValueError
81103 If objective is not supported. Or if the target columns cannot be found.
82104 """
83- objective = self .params ["objective" ]
84- target_tag = get_target_tag (objective )
85- self .target_columns = target_columns or get_targets (train , target_tag )
86-
87- # for ranking objectives, set the grouping
88- if objective .startswith ("rank" ) and qid_column is None :
89- qid_column = train .schema .select_by_tag (Tags .USER_ID ).column_names [0 ]
90- self .qid_column = qid_column
105+ X , y , qid = dataset_to_xy (
106+ train ,
107+ self .target_columns ,
108+ self .qid_column ,
109+ )
91110
92- dtrain = dataset_to_dmatrix ( train , self .target_columns , self . qid_column )
111+ dtrain = xgb . dask . DaskDMatrix ( self .dask_client , X , label = y , qid = qid )
93112 watchlist = [(dtrain , "train" )]
94113
95- self .bst = xgb .train (self .params , dtrain , evals = watchlist , ** train_kwargs )
114+ booster : xgb .Booster = xgb .dask .train (
115+ self .dask_client , self .params , dtrain , evals = watchlist , ** train_kwargs
116+ )["booster" ]
117+ self .booster = booster
96118
97- return self . bst
119+ return booster
98120
99121 def evaluate (self , dataset : Dataset , ** predict_kwargs ) -> Dict [str , float ]:
100122 """Evaluates the model on the dataset provided.
@@ -111,14 +133,20 @@ def evaluate(self, dataset: Dataset, **predict_kwargs) -> Dict[str, float]:
111133 Dict[str, float]
112134 Dictionary of metrics of the form {metric_name: value}.
113135 """
114- if self .bst is None :
136+ if self .booster is None :
115137 raise ValueError ("The fit method must be called before evaluate." )
116138
117- data : xgb .DMatrix = dataset_to_dmatrix (dataset , self .target_columns , self .qid_column )
118- preds = self .bst .predict (data , ** predict_kwargs )
119- data .set_label (preds )
139+ X , _ , qid = dataset_to_xy (dataset , self .target_columns , self .qid_column )
140+ data = xgb .dask .DaskDMatrix (self .dask_client , X , qid = qid )
141+ preds = xgb .dask .predict (self .dask_client , self .booster , data , ** predict_kwargs )
142+
143+ # convert to DMatrix
144+ # (eval doesn't have dask support currently)
145+ if qid is not None :
146+ qid = qid .compute ()
147+ eval_data = xgb .DMatrix (X .compute (), label = preds .compute (), qid = qid )
120148
121- metrics_str = self .bst .eval (data )
149+ metrics_str = self .booster .eval (eval_data )
122150 metrics = {}
123151 for metric in metrics_str .split ("\t " )[1 :]:
124152 metric_name , metric_value = metric .split (":" )
@@ -141,11 +169,12 @@ def predict(self, dataset: Dataset, **predict_kwargs) -> np.ndarray:
141169 numpy.ndarray
142170 The predicions data
143171 """
144- if self .bst is None :
172+ if self .booster is None :
145173 raise ValueError ("The fit method must be called before predict." )
146174
147- data : xgb .DMatrix = dataset_to_dmatrix (dataset , self .target_columns , self .qid_column )
148- preds = self .bst .predict (data , ** predict_kwargs )
175+ X , _ , qid = dataset_to_xy (dataset , self .target_columns , self .qid_column )
176+ data = xgb .dask .DaskDMatrix (self .dask_client , X , qid = qid )
177+ preds = xgb .dask .predict (self .dask_client , self .booster , data , ** predict_kwargs ).compute ()
149178
150179 return preds
151180
@@ -164,28 +193,29 @@ def get_target_tag(objective: str) -> Tags:
164193 """Get the target tag from the specified objective"""
165194 try :
166195 return OBJECTIVES [objective ]
167- except KeyError :
196+ except KeyError as exc :
168197 target_options_str = str (list (OBJECTIVES .keys ()))
169- raise ValueError (f"Objective not supported. Must be one of: { target_options_str } " )
198+ raise ValueError (f"Objective not supported. Must be one of: { target_options_str } " ) from exc
170199
171200
172- def get_targets (dataset : Dataset , target_tag : Tags ) -> List [str ]:
201+ def get_targets (schema : Schema , target_tag : Tags ) -> List [str ]:
173202 """Find target columns from dataset or specified target_column"""
174- targets = dataset . schema .select_by_tag (Tags .TARGET ).select_by_tag (target_tag )
203+ targets = schema .select_by_tag (Tags .TARGET ).select_by_tag (target_tag )
175204
176205 if len (targets ) >= 1 :
177206 return targets .column_names
178- else :
179- raise ValueError (
180- f"No target columns in the dataset schema with tags TARGET and { target_tag .name } "
181- )
207+ raise ValueError (
208+ f"No target columns in the dataset schema with tags TARGET and { target_tag .name } "
209+ )
182210
183211
184- def dataset_to_dmatrix (
185- dataset : Dataset , target_columns : Union [str , list ], qid_column : Optional [str ]
186- ) -> xgb .DMatrix :
212+ def dataset_to_xy (
213+ dataset : Dataset ,
214+ target_columns : Union [str , list ],
215+ qid_column : Optional [str ],
216+ ):
187217 """Convert Merlin Dataset to XGBoost DMatrix"""
188- df = dataset .to_ddf (). compute ()
218+ df = dataset .to_ddf ()
189219
190220 qid = None
191221 if qid_column :
@@ -210,6 +240,4 @@ def dataset_to_dmatrix(
210240 # Ensure columns are in a consistent order
211241 X = X [sorted (X .columns )]
212242
213- data = xgb .DMatrix (X , label = y , qid = qid )
214-
215- return data
243+ return X , y , qid
0 commit comments