Skip to content

Commit 01afd06

Browse files
authored
fix workflow bug (#882)
* fix workflow bug * Fix output of pytorch NN * Fix parameter bug
1 parent d533219 commit 01afd06

File tree

4 files changed

+15
-9
lines changed

4 files changed

+15
-9
lines changed

qlib/contrib/model/pytorch_nn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ def fit(
240240
R.log_metrics(val_loss=loss_val.val, step=step)
241241
if verbose:
242242
self.logger.info(
243-
"[Epoch {}]: train_loss {:.6f}, valid_loss {:.6f}".format(step, train_loss, loss_val.val)
243+
"[Step {}]: train_loss {:.6f}, valid_loss {:.6f}".format(step, train_loss, loss_val.val)
244244
)
245245
evals_result["train"].append(train_loss)
246246
evals_result["valid"].append(loss_val.val)

qlib/model/trainer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,8 @@ def _log_task_info(task_config: dict):
3737
def _exe_task(task_config: dict):
3838
rec = R.get_recorder()
3939
# model & dataset initiation
40-
model: Model = init_instance_by_config(task_config["model"])
41-
dataset: Dataset = init_instance_by_config(task_config["dataset"])
40+
model: Model = init_instance_by_config(task_config["model"], accept_types=Model)
41+
dataset: Dataset = init_instance_by_config(task_config["dataset"], accept_types=Dataset)
4242
reweighter: Reweighter = task_config.get("reweighter", None)
4343
# model training
4444
auto_filter_kwargs(model.fit)(dataset, reweighter=reweighter)

qlib/workflow/__init__.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,9 @@ def list_recorders(self, experiment_id=None, experiment_name=None):
226226
"""
227227
return self.get_exp(experiment_id=experiment_id, experiment_name=experiment_name).list_recorders()
228228

229-
def get_exp(self, *, experiment_id=None, experiment_name=None, create: bool = True) -> Experiment:
229+
def get_exp(
230+
self, *, experiment_id=None, experiment_name=None, create: bool = True, start: bool = False
231+
) -> Experiment:
230232
"""
231233
Method for retrieving an experiment with given id or name. Once the `create` argument is set to
232234
True, if no valid experiment is found, this method will create one for you. Otherwise, it will
@@ -291,6 +293,10 @@ def get_exp(self, *, experiment_id=None, experiment_name=None, create: bool = Tr
291293
create : boolean
292294
an argument determines whether the method will automatically create a new experiment
293295
according to user's specification if the experiment hasn't been created before.
296+
start : bool
297+
when start is True,
298+
if the experiment has not started(not activated), it will start
299+
It is designed for R.log_params to auto start experiments
294300
295301
Returns
296302
-------
@@ -300,7 +306,7 @@ def get_exp(self, *, experiment_id=None, experiment_name=None, create: bool = Tr
300306
experiment_id=experiment_id,
301307
experiment_name=experiment_name,
302308
create=create,
303-
start=False,
309+
start=start,
304310
)
305311

306312
def delete_exp(self, experiment_id=None, experiment_name=None):
@@ -542,7 +548,7 @@ def log_params(self, **kwargs):
542548
keyword argument:
543549
name1=value1, name2=value2, ...
544550
"""
545-
self.get_exp().get_recorder(start=True).log_params(**kwargs)
551+
self.get_exp(start=True).get_recorder(start=True).log_params(**kwargs)
546552

547553
def log_metrics(self, step=None, **kwargs):
548554
"""
@@ -567,7 +573,7 @@ def log_metrics(self, step=None, **kwargs):
567573
keyword argument:
568574
name1=value1, name2=value2, ...
569575
"""
570-
self.get_exp().get_recorder(start=True).log_metrics(step, **kwargs)
576+
self.get_exp(start=True).get_recorder(start=True).log_metrics(step, **kwargs)
571577

572578
def set_tags(self, **kwargs):
573579
"""
@@ -592,7 +598,7 @@ def set_tags(self, **kwargs):
592598
keyword argument:
593599
name1=value1, name2=value2, ...
594600
"""
595-
self.get_exp().get_recorder(start=True).set_tags(**kwargs)
601+
self.get_exp(start=True).get_recorder(start=True).set_tags(**kwargs)
596602

597603

598604
class RecorderWrapper(Wrapper):

qlib/workflow/expm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ def get_exp(self, *, experiment_id=None, experiment_name=None, create: bool = Tr
178178
self._get_exp(experiment_id=experiment_id, experiment_name=experiment_name),
179179
False,
180180
)
181-
if is_new and start:
181+
if self.active_experiment is None and start:
182182
self.active_experiment = exp
183183
# start the recorder
184184
self.active_experiment.start()

0 commit comments

Comments
 (0)