Skip to content
Merged
5 changes: 4 additions & 1 deletion merlin/models/xgb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,10 @@ def fit(
self.target_columns,
self.qid_column,
)
d_eval = dmatrix_cls(self.dask_client, X, label=y, qid=qid)
# using the quantile DMatrix as part of evals results in a
# discrepancy between metrics reported in logs and result
# of evaluate
d_eval = xgb.dask.DaskDMatrix(self.dask_client, X, label=y, qid=qid)
watchlist.append((d_eval, name))

train_res = xgb.dask.train(
Expand Down
16 changes: 9 additions & 7 deletions tests/unit/xgb/test_xgboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,24 +121,26 @@ def test_pairwise(self, social_data: Dataset):
],
)
@patch("xgboost.dask.train", side_effect=xgboost.dask.train)
def test_gpu_hist_dmatrix(
mock_train, fit_kwargs, expected_dtrain_cls, dask_client, music_streaming_data: Dataset
):
schema = music_streaming_data.schema
def test_gpu_hist_dmatrix(mock_train, fit_kwargs, expected_dtrain_cls, dask_client):
train, valid = generate_data("music-streaming", num_rows=100, set_sizes=(0.5, 0.5))
schema = train.schema
model = XGBoost(schema, objective="reg:logistic", tree_method="gpu_hist")
model.fit(music_streaming_data, **fit_kwargs)
model.predict(music_streaming_data)
metrics = model.evaluate(music_streaming_data)
model.fit(train, evals=[(valid, "valid")], **fit_kwargs)
model.predict(valid)
metrics = model.evaluate(valid)
assert "rmse" in metrics

assert mock_train.called
assert mock_train.call_count == 1

train_call = mock_train.call_args_list[0]
client, params, dtrain = train_call.args
evals = train_call.kwargs["evals"]
assert dask_client == client
assert params["tree_method"] == "gpu_hist"
assert params["objective"] == "reg:logistic"
# check that we don't use quantile dmatrix for non-training eval data
assert not isinstance(evals[0][0], xgboost.dask.DaskDeviceQuantileDMatrix)
assert isinstance(dtrain, expected_dtrain_cls)


Expand Down