Skip to content

Commit 994f893

Browse files
authored
Optimize the implementation of uri & Fix async log bug (#1364)
* Optimize the implementation of uri * remove redundant func * Set the right order of _set_client_uri * Update qlib/workflow/expm.py * Simplify client & add test.Add docs; Fix async bug * Fix comments & pylint * Improve README
1 parent b51e881 commit 994f893

File tree

7 files changed

+94
-63
lines changed

7 files changed

+94
-63
lines changed

qlib/workflow/__init__.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from .recorder import Recorder
99
from ..utils import Wrapper
1010
from ..utils.exceptions import RecorderInitializationError
11-
from qlib.config import C
1211

1312

1413
class QlibRecorder:
@@ -347,14 +346,14 @@ def get_uri(self):
347346

348347
def set_uri(self, uri: Optional[Text]):
349348
"""
350-
Method to reset the current uri of current experiment manager.
349+
Method to reset the **default** uri of current experiment manager.
351350
352351
NOTE:
353352
354353
- When the uri refers to a file path, please use the absolute path instead of strings like "~/mlruns/"
355354
The backend doesn't support strings like this.
356355
"""
357-
self.exp_manager.set_uri(uri)
356+
self.exp_manager.default_uri = uri
358357

359358
@contextmanager
360359
def uri_context(self, uri: Text):
@@ -370,11 +369,11 @@ def uri_context(self, uri: Text):
370369
the temporal uri
371370
"""
372371
prev_uri = self.exp_manager.default_uri
373-
C.exp_manager["kwargs"]["uri"] = uri
372+
self.exp_manager.default_uri = uri
374373
try:
375374
yield
376375
finally:
377-
C.exp_manager["kwargs"]["uri"] = prev_uri
376+
self.exp_manager.default_uri = prev_uri
378377

379378
def get_recorder(
380379
self,

qlib/workflow/exp.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,6 @@ class MLflowExperiment(Experiment):
249249
def __init__(self, id, name, uri):
250250
super(MLflowExperiment, self).__init__(id, name)
251251
self._uri = uri
252-
self._default_name = None
253252
self._default_rec_name = "mlflow_recorder"
254253
self._client = mlflow.tracking.MlflowClient(tracking_uri=self._uri)
255254

qlib/workflow/expm.py

Lines changed: 50 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -15,23 +15,32 @@
1515
from ..log import get_module_logger
1616
from ..utils.exceptions import ExpAlreadyExistError
1717

18+
1819
logger = get_module_logger("workflow")
1920

2021

2122
class ExpManager:
2223
"""
23-
This is the `ExpManager` class for managing experiments. The API is designed similar to mlflow.
24-
(The link: https://mlflow.org/docs/latest/python_api/mlflow.html)
24+
This is the `ExpManager` class for managing experiments. The API is designed similar to mlflow.
25+
(The link: https://mlflow.org/docs/latest/python_api/mlflow.html)
26+
27+
The `ExpManager` is expected to be a singleton (by the way, we can have multiple `Experiment`s with different uris; users can get different experiments from different uris and then compare their records). The global config (i.e. `C`) is also a singleton.
28+
So we try to align them together. They share the same variable, which is called **default uri**. Please refer to `ExpManager.default_uri` for details of variable sharing.
29+
30+
When the user starts an experiment, the user may want to set the uri to a specific uri (it will override the **default uri** during this period), and then unset the **specific uri** and fall back to the **default uri**. `ExpManager._active_exp_uri` is that **specific uri**.
2531
"""
2632

33+
active_experiment: Optional[Experiment]
34+
2735
def __init__(self, uri: Text, default_exp_name: Optional[Text]):
28-
self._current_uri = uri
36+
self.default_uri = uri
37+
self._active_exp_uri = None # No active experiments. So it is set to None
2938
self._default_exp_name = default_exp_name
3039
self.active_experiment = None # only one experiment can be active each time
31-
logger.info(f"experiment manager uri is at {self._current_uri}")
40+
logger.info(f"experiment manager uri is at {self.uri}")
3241

3342
def __repr__(self):
34-
return "{name}(current_uri={curi})".format(name=self.__class__.__name__, curi=self._current_uri)
43+
return "{name}(uri={uri})".format(name=self.__class__.__name__, uri=self.uri)
3544

3645
def start_exp(
3746
self,
@@ -43,11 +52,13 @@ def start_exp(
4352
uri: Optional[Text] = None,
4453
resume: bool = False,
4554
**kwargs,
46-
):
55+
) -> Experiment:
4756
"""
4857
Start an experiment. This method includes first get_or_create an experiment, and then
4958
set it to be active.
5059
60+
Maintaining `_active_exp_uri` is included in start_exp; the remaining implementation should be included in `_start_exp` in the subclass.
61+
5162
Parameters
5263
----------
5364
experiment_id : str
@@ -67,19 +78,41 @@ def start_exp(
6778
-------
6879
An active experiment.
6980
"""
81+
self._active_exp_uri = uri
82+
# The subclass may set the underlying uri back.
83+
# So setting `_active_exp_uri` come before `_start_exp`
84+
return self._start_exp(
85+
experiment_id=experiment_id,
86+
experiment_name=experiment_name,
87+
recorder_id=recorder_id,
88+
recorder_name=recorder_name,
89+
resume=resume,
90+
**kwargs,
91+
)
92+
93+
def _start_exp(self, *args, **kwargs) -> Experiment:
94+
"""Please refer to the doc of `start_exp`"""
7095
raise NotImplementedError(f"Please implement the `start_exp` method.")
7196

7297
def end_exp(self, recorder_status: Text = Recorder.STATUS_S, **kwargs):
7398
"""
7499
End an active experiment.
75100
101+
Maintaining `_active_exp_uri` is included in end_exp; the remaining implementation should be included in `_end_exp` in the subclass.
102+
76103
Parameters
77104
----------
78105
experiment_name : str
79106
name of the active experiment.
80107
recorder_status : str
81108
the status of the active recorder of the experiment.
82109
"""
110+
self._active_exp_uri = None
111+
# The subclass may set the underlying uri back.
112+
# So setting `_active_exp_uri` come before `_end_exp`
113+
self._end_exp(recorder_status=recorder_status, **kwargs)
114+
115+
def _end_exp(self, recorder_status: Text = Recorder.STATUS_S, **kwargs):
83116
raise NotImplementedError(f"Please implement the `end_exp` method.")
84117

85118
def create_exp(self, experiment_name: Optional[Text] = None):
@@ -254,6 +287,10 @@ def default_uri(self):
254287
raise ValueError("The default URI is not set in qlib.config.C")
255288
return C.exp_manager["kwargs"]["uri"]
256289

290+
@default_uri.setter
291+
def default_uri(self, value):
292+
C.exp_manager.setdefault("kwargs", {})["uri"] = value
293+
257294
@property
258295
def uri(self):
259296
"""
@@ -263,33 +300,7 @@ def uri(self):
263300
-------
264301
The tracking URI string.
265302
"""
266-
return self._current_uri or self.default_uri
267-
268-
def set_uri(self, uri: Optional[Text] = None):
269-
"""
270-
Set the current tracking URI and the corresponding variables.
271-
272-
Parameters
273-
----------
274-
uri : str
275-
276-
"""
277-
if uri is None:
278-
if self._current_uri is None:
279-
logger.debug("No tracking URI is provided. Use the default tracking URI.")
280-
self._current_uri = self.default_uri
281-
else:
282-
# Temporarily re-set the current uri as the uri argument.
283-
self._current_uri = uri
284-
# Customized features for subclasses.
285-
self._set_uri()
286-
287-
def _set_uri(self):
288-
"""
289-
Customized features for subclasses' set_uri function.
290-
This method is designed for the underlying experiment backend storage.
291-
"""
292-
raise NotImplementedError(f"Please implement the `_set_uri` method.")
303+
return self._active_exp_uri or self.default_uri
293304

294305
def list_experiments(self):
295306
"""
@@ -307,33 +318,21 @@ class MLflowExpManager(ExpManager):
307318
Use mlflow to implement ExpManager.
308319
"""
309320

310-
def __init__(self, uri: Text, default_exp_name: Optional[Text]):
311-
super(MLflowExpManager, self).__init__(uri, default_exp_name)
312-
self._client = None
313-
314-
def _set_uri(self):
315-
self._client = mlflow.tracking.MlflowClient(tracking_uri=self.uri)
316-
logger.info("{:}".format(self._client))
317-
318321
@property
319322
def client(self):
320-
# Delay the creation of mlflow client in case of creating `mlruns` folder when importing qlib
321-
if self._client is None:
322-
self._client = mlflow.tracking.MlflowClient(tracking_uri=self.uri)
323-
return self._client
323+
# Please refer to `tests/dependency_tests/test_mlflow.py::MLflowTest::test_creating_client`
324+
# The test ensures that creating a new client is fast
325+
return mlflow.tracking.MlflowClient(tracking_uri=self.uri)
324326

325-
def start_exp(
327+
def _start_exp(
326328
self,
327329
*,
328330
experiment_id: Optional[Text] = None,
329331
experiment_name: Optional[Text] = None,
330332
recorder_id: Optional[Text] = None,
331333
recorder_name: Optional[Text] = None,
332-
uri: Optional[Text] = None,
333334
resume: bool = False,
334335
):
335-
# Set the tracking uri
336-
self.set_uri(uri)
337336
# Create experiment
338337
if experiment_name is None:
339338
experiment_name = self._default_exp_name
@@ -345,12 +344,10 @@ def start_exp(
345344

346345
return self.active_experiment
347346

348-
def end_exp(self, recorder_status: Text = Recorder.STATUS_S):
347+
def _end_exp(self, recorder_status: Text = Recorder.STATUS_S):
349348
if self.active_experiment is not None:
350349
self.active_experiment.end(recorder_status)
351350
self.active_experiment = None
352-
# When an experiment end, we will release the current uri.
353-
self._current_uri = None
354351

355352
def create_exp(self, experiment_name: Optional[Text] = None):
356353
assert experiment_name is not None
@@ -362,9 +359,7 @@ def create_exp(self, experiment_name: Optional[Text] = None):
362359
raise ExpAlreadyExistError() from e
363360
raise e
364361

365-
experiment = MLflowExperiment(experiment_id, experiment_name, self.uri)
366-
experiment._default_name = self._default_exp_name
367-
return experiment
362+
return MLflowExperiment(experiment_id, experiment_name, self.uri)
368363

369364
def _get_exp(self, experiment_id=None, experiment_name=None):
370365
"""

qlib/workflow/recorder.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -378,14 +378,15 @@ def end_run(self, status: str = Recorder.STATUS_S):
378378
Recorder.STATUS_FI,
379379
Recorder.STATUS_FA,
380380
], f"The status type {status} is not supported."
381-
mlflow.end_run(status)
382381
self.end_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
383382
if self.status != Recorder.STATUS_S:
384383
self.status = status
385384
if self.async_log is not None:
385+
# Waiting Queue should go before mlflow.end_run. Otherwise mlflow will raise error
386386
with TimeInspector.logt("waiting `async_log`"):
387387
self.async_log.wait()
388388
self.async_log = None
389+
mlflow.end_run(status)
389390

390391
def save_objects(self, local_path=None, artifact_path=None, **kwargs):
391392
assert self.uri is not None, "Please start the experiment and recorder first before using recorder directly."

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def get_version(rel_path: str) -> str:
6262
"matplotlib>=3.3",
6363
"tables>=3.6.1",
6464
"pyyaml>=5.3.1",
65-
"mlflow>=1.12.1",
65+
"mlflow>=1.12.1, <=1.30.0",
6666
"tqdm",
6767
"loguru",
6868
"lightgbm>=3.3.0",

tests/dependency_tests/README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
Some implementations of Qlib depend on some assumptions of its dependencies.
2+
3+
So some tests are required to ensure that these assumptions are valid.
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
import unittest
4+
import mlflow
5+
import time
6+
from pathlib import Path
7+
import shutil
8+
9+
10+
class MLflowTest(unittest.TestCase):
11+
TMP_PATH = Path("./.mlruns_tmp/")
12+
13+
def tearDown(self) -> None:
14+
if self.TMP_PATH.exists():
15+
shutil.rmtree(self.TMP_PATH)
16+
17+
def test_creating_client(self):
18+
"""
19+
Please refer to qlib/workflow/expm.py:MLflowExpManager._client
20+
we don't cache _client (this is helpful to reduce maintenance work when MLflowExpManager's uri is changed)
21+
22+
This implementation is based on the assumption creating a client is fast
23+
"""
24+
start = time.time()
25+
for i in range(10):
26+
_ = mlflow.tracking.MlflowClient(tracking_uri=str(self.TMP_PATH))
27+
end = time.time()
28+
elasped = end - start
29+
self.assertLess(elasped, 1e-2) # it can be done in less than 10ms
30+
print(elasped)
31+
32+
33+
if __name__ == "__main__":
34+
unittest.main()

0 commit comments

Comments
 (0)