|
4 | 4 | import lightgbm as lgb |
5 | 5 | import numpy as np |
6 | 6 | import pandas as pd |
7 | | - |
| 7 | +from typing import Text, Union |
8 | 8 | from ...model.base import Model |
9 | 9 | from ...data.dataset import DatasetH |
10 | 10 | from ...data.dataset.handler import DataHandlerLP |
@@ -40,6 +40,10 @@ def __init__( |
40 | 40 | self.bins_sr = bins_sr |
41 | 41 | self.bins_fs = bins_fs |
42 | 42 | self.decay = decay |
| 43 | + if sample_ratios is None: # the default values for sample_ratios |
| 44 | + sample_ratios = [0.8, 0.7, 0.6, 0.5, 0.4] |
| 45 | + if sub_weights is None: # the default values for sub_weights |
| 46 | + sub_weights = [1.0, 0.2, 0.2, 0.2, 0.2, 0.2] |
43 | 47 | if not len(sample_ratios) == bins_fs: |
44 | 48 | raise ValueError("The length of sample_ratios should be equal to bins_fs.") |
45 | 49 | self.sample_ratios = sample_ratios |
@@ -228,10 +232,10 @@ def retrieve_loss_curve(self, model, df_train, features): |
228 | 232 | raise ValueError("not implemented yet") |
229 | 233 | return loss_curve |
230 | 234 |
|
231 | | - def predict(self, dataset): |
| 235 | + def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"): |
232 | 236 | if self.ensemble is None: |
233 | 237 | raise ValueError("model is not fitted yet!") |
234 | | - x_test = dataset.prepare("test", col_set="feature", data_key=DataHandlerLP.DK_I) |
| 238 | + x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I) |
235 | 239 | pred = pd.Series(np.zeros(x_test.shape[0]), index=x_test.index) |
236 | 240 | for i_sub, submodel in enumerate(self.ensemble): |
237 | 241 | feat_sub = self.sub_features[i_sub] |
|
0 commit comments