@@ -121,26 +121,24 @@ def test_pairwise(self, social_data: Dataset):
121121 ],
122122)
123123@patch ("xgboost.dask.train" , side_effect = xgboost .dask .train )
124- def test_gpu_hist_dmatrix (mock_train , fit_kwargs , expected_dtrain_cls , dask_client ):
125- train , valid = generate_data ("music-streaming" , num_rows = 100 , set_sizes = (0.5 , 0.5 ))
126- schema = train .schema
124+ def test_gpu_hist_dmatrix (
125+ mock_train , fit_kwargs , expected_dtrain_cls , dask_client , music_streaming_data : Dataset
126+ ):
127+ schema = music_streaming_data .schema
127128 model = XGBoost (schema , objective = "reg:logistic" , tree_method = "gpu_hist" )
128- model .fit (train , evals = [( valid , "valid" )] , ** fit_kwargs )
129- model .predict (valid )
130- metrics = model .evaluate (valid )
129+ model .fit (music_streaming_data , ** fit_kwargs )
130+ model .predict (music_streaming_data )
131+ metrics = model .evaluate (music_streaming_data )
131132 assert "rmse" in metrics
132133
133134 assert mock_train .called
134135 assert mock_train .call_count == 1
135136
136137 train_call = mock_train .call_args_list [0 ]
137138 client , params , dtrain = train_call .args
138- evals = train_call .kwargs ["evals" ]
139139 assert dask_client == client
140140 assert params ["tree_method" ] == "gpu_hist"
141141 assert params ["objective" ] == "reg:logistic"
142- # check that we don't use quantile dmatrix for non-training eval data
143- assert not isinstance (evals [0 ][0 ], xgboost .dask .DaskDeviceQuantileDMatrix )
144142 assert isinstance (dtrain , expected_dtrain_cls )
145143
146144
0 commit comments