Skip to content

Commit 61a0874

Browse files
Revert "XGBoost - Use DaskDMatrix for evals data to ensure metrics in logs match result of evaluate (#682)"
This reverts commit 489137e.
1 parent 8c999c8 commit 61a0874

File tree

2 files changed

+8
-13
lines changed

2 files changed

+8
-13
lines changed

merlin/models/xgb/__init__.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -171,10 +171,7 @@ def fit(
171171
self.target_columns,
172172
self.qid_column,
173173
)
174-
# using the quantile DMatrix as part of evals results in a
175-
# discrepancy between metrics reported in logs and result
176-
# of evaluate
177-
d_eval = xgb.dask.DaskDMatrix(self.dask_client, X, label=y, qid=qid)
174+
d_eval = dmatrix_cls(self.dask_client, X, label=y, qid=qid)
178175
watchlist.append((d_eval, name))
179176

180177
train_res = xgb.dask.train(

tests/unit/xgb/test_xgboost.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -121,26 +121,24 @@ def test_pairwise(self, social_data: Dataset):
121121
],
122122
)
123123
@patch("xgboost.dask.train", side_effect=xgboost.dask.train)
124-
def test_gpu_hist_dmatrix(mock_train, fit_kwargs, expected_dtrain_cls, dask_client):
125-
train, valid = generate_data("music-streaming", num_rows=100, set_sizes=(0.5, 0.5))
126-
schema = train.schema
124+
def test_gpu_hist_dmatrix(
125+
mock_train, fit_kwargs, expected_dtrain_cls, dask_client, music_streaming_data: Dataset
126+
):
127+
schema = music_streaming_data.schema
127128
model = XGBoost(schema, objective="reg:logistic", tree_method="gpu_hist")
128-
model.fit(train, evals=[(valid, "valid")], **fit_kwargs)
129-
model.predict(valid)
130-
metrics = model.evaluate(valid)
129+
model.fit(music_streaming_data, **fit_kwargs)
130+
model.predict(music_streaming_data)
131+
metrics = model.evaluate(music_streaming_data)
131132
assert "rmse" in metrics
132133

133134
assert mock_train.called
134135
assert mock_train.call_count == 1
135136

136137
train_call = mock_train.call_args_list[0]
137138
client, params, dtrain = train_call.args
138-
evals = train_call.kwargs["evals"]
139139
assert dask_client == client
140140
assert params["tree_method"] == "gpu_hist"
141141
assert params["objective"] == "reg:logistic"
142-
# check that we don't use quantile dmatrix for non-training eval data
143-
assert not isinstance(evals[0][0], xgboost.dask.DaskDeviceQuantileDMatrix)
144142
assert isinstance(dtrain, expected_dtrain_cls)
145143

146144

0 commit comments

Comments
 (0)