diff --git a/qlib/contrib/model/pytorch_nn.py b/qlib/contrib/model/pytorch_nn.py index 969cca63dbb..83d217458cd 100644 --- a/qlib/contrib/model/pytorch_nn.py +++ b/qlib/contrib/model/pytorch_nn.py @@ -240,7 +240,7 @@ def fit( R.log_metrics(val_loss=loss_val.val, step=step) if verbose: self.logger.info( - "[Epoch {}]: train_loss {:.6f}, valid_loss {:.6f}".format(step, train_loss, loss_val.val) + "[Step {}]: train_loss {:.6f}, valid_loss {:.6f}".format(step, train_loss, loss_val.val) ) evals_result["train"].append(train_loss) evals_result["valid"].append(loss_val.val) diff --git a/qlib/model/trainer.py b/qlib/model/trainer.py index 187a1481764..a65801bbbf9 100644 --- a/qlib/model/trainer.py +++ b/qlib/model/trainer.py @@ -37,8 +37,8 @@ def _log_task_info(task_config: dict): def _exe_task(task_config: dict): rec = R.get_recorder() # model & dataset initiation - model: Model = init_instance_by_config(task_config["model"]) - dataset: Dataset = init_instance_by_config(task_config["dataset"]) + model: Model = init_instance_by_config(task_config["model"], accept_types=Model) + dataset: Dataset = init_instance_by_config(task_config["dataset"], accept_types=Dataset) reweighter: Reweighter = task_config.get("reweighter", None) # model training auto_filter_kwargs(model.fit)(dataset, reweighter=reweighter) diff --git a/qlib/workflow/__init__.py b/qlib/workflow/__init__.py index c2fd93fffed..47009b79218 100644 --- a/qlib/workflow/__init__.py +++ b/qlib/workflow/__init__.py @@ -226,7 +226,9 @@ def list_recorders(self, experiment_id=None, experiment_name=None): """ return self.get_exp(experiment_id=experiment_id, experiment_name=experiment_name).list_recorders() - def get_exp(self, *, experiment_id=None, experiment_name=None, create: bool = True) -> Experiment: + def get_exp( + self, *, experiment_id=None, experiment_name=None, create: bool = True, start: bool = False + ) -> Experiment: """ Method for retrieving an experiment with given id or name. Once the `create` argument is set to True, if no valid experiment is found, this method will create one for you. Otherwise, it will @@ -291,6 +293,10 @@ def get_exp(self, *, experiment_id=None, experiment_name=None, create: bool = Tr create : boolean an argument determines whether the method will automatically create a new experiment according to user's specification if the experiment hasn't been created before. + start : bool + when start is True, + if the experiment has not started(not activated), it will start + It is designed for R.log_params to auto start experiments Returns ------- @@ -300,7 +306,7 @@ def get_exp(self, *, experiment_id=None, experiment_name=None, create: bool = Tr experiment_id=experiment_id, experiment_name=experiment_name, create=create, - start=False, + start=start, ) def delete_exp(self, experiment_id=None, experiment_name=None): @@ -542,7 +548,7 @@ def log_params(self, **kwargs): keyword argument: name1=value1, name2=value2, ... """ - self.get_exp().get_recorder(start=True).log_params(**kwargs) + self.get_exp(start=True).get_recorder(start=True).log_params(**kwargs) def log_metrics(self, step=None, **kwargs): """ @@ -567,7 +573,7 @@ def log_metrics(self, step=None, **kwargs): keyword argument: name1=value1, name2=value2, ... """ - self.get_exp().get_recorder(start=True).log_metrics(step, **kwargs) + self.get_exp(start=True).get_recorder(start=True).log_metrics(step, **kwargs) def set_tags(self, **kwargs): """ @@ -592,7 +598,7 @@ def set_tags(self, **kwargs): keyword argument: name1=value1, name2=value2, ... """ - self.get_exp().get_recorder(start=True).set_tags(**kwargs) + self.get_exp(start=True).get_recorder(start=True).set_tags(**kwargs) class RecorderWrapper(Wrapper): diff --git a/qlib/workflow/expm.py b/qlib/workflow/expm.py index 01844b52828..f16c58ddbed 100644 --- a/qlib/workflow/expm.py +++ b/qlib/workflow/expm.py @@ -178,7 +178,7 @@ def get_exp(self, *, experiment_id=None, experiment_name=None, create: bool = Tr self._get_exp(experiment_id=experiment_id, experiment_name=experiment_name), False, ) - if is_new and start: + if self.active_experiment is None and start: self.active_experiment = exp # start the recorder self.active_experiment.start()