Skip to content

Commit 398cc66

Browse files
authored
Warn when loading a model with a different Python version (#192)
1 parent a414b66 commit 398cc66

29 files changed

+86
-30
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
**🆕 New functionality**
66

7+
`modelstore` will issue a warning if you `load()` a model with a different version of Python than the version that was used to train the model ([#192](https://github.com/operatorai/modelstore/pull/192)).
8+
79
You can now add any extra metadata to your model when uploading it, using `upload(domain, model, extra_metadata={ ... })` ([#185](https://github.com/operatorai/modelstore/pull/185)); if you want to upload extra _files_ with your model, then you should now use `extra_files=` instead of `extras=` ([#187](https://github.com/operatorai/modelstore/pull/187)).
810

911
**🐛 Bug fixes & general updates**

bin/_config

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
export PYTHON_VERSION=3.7.13
2-
#export PYTHON_VERSION=3.8.12
1+
#export PYTHON_VERSION=3.7.13
2+
export PYTHON_VERSION=3.8.12
33
#export PYTHON_VERSION=3.9.9
44

55
export VIRTUALENV_NAME=$(pyenv local) || ""

modelstore/metadata/code/code.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def generate(cls, deps_list: list, created: datetime = None) -> "Code":
4242
# control time stamps of mock model objects
4343
created = datetime.now()
4444
return Code(
45-
runtime=f"python:{runtime.get_python_version()}",
45+
runtime=runtime.get_python_version(),
4646
user=runtime.get_user(),
4747
created=created.strftime("%Y/%m/%d/%H:%M:%S"),
4848
dependencies=remove_nones(versioned_deps),

modelstore/metadata/code/runtime.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@
1818
def get_python_version() -> str:
1919
""" Returns the current python version """
2020
vers = sys.version_info
21-
return ".".join(str(x) for x in [vers.major, vers.minor, vers.micro])
21+
version = ".".join(str(x) for x in [vers.major, vers.minor, vers.micro])
22+
return f"python:{version}"
2223

2324

2425
def get_user() -> str:

modelstore/models/annoy.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,8 @@ def get_params(self, **kwargs) -> dict:
6666
}
6767

6868
def load(self, model_path: str, meta_data: metadata.Summary) -> Any:
69+
super().load(model_path, meta_data)
70+
6971
# pylint: disable=import-outside-toplevel
7072
from annoy import AnnoyIndex
7173

modelstore/models/catboost.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,8 @@ def get_params(self, **kwargs) -> dict:
9292
return kwargs["model"].get_params()
9393

9494
def load(self, model_path: str, meta_data: metadata.Summary) -> Any:
95+
super().load(model_path, meta_data)
96+
9597
# pylint: disable=import-outside-toplevel
9698
import catboost
9799

modelstore/models/fastai.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from functools import partial
1616
from pathlib import Path
1717
from typing import Any
18+
import warnings
1819

1920
from modelstore.metadata import metadata
2021
from modelstore.models.model_manager import ModelManager
@@ -78,6 +79,8 @@ def _get_functions(self, **kwargs) -> list:
7879
]
7980

8081
def load(self, model_path: str, meta_data: metadata.Summary) -> Any:
82+
super().load(model_path, meta_data)
83+
8184
# pylint: disable=import-outside-toplevel
8285
import fastai
8386

@@ -88,11 +91,8 @@ def load(self, model_path: str, meta_data: metadata.Summary) -> Any:
8891

8992
version = meta_data.code.dependencies.get(FastAIManager.NAME, "?")
9093
if version != fastai.__version__:
91-
logger.warn(
92-
"Model was saved with fastai==%s, trying to load it with fastai==%s",
93-
version,
94-
fastai.__version__,
95-
)
94+
msg = f"Model was saved with fastai=={version}, loading it with fastai=={fastai.__version__}"
95+
warnings.warn(msg, RuntimeWarning)
9696

9797
model_file = _model_file_path(model_path)
9898
return load_learner(model_file)

modelstore/models/gensim.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,8 @@ def get_params(self, **kwargs) -> dict:
6868
return params
6969

7070
def load(self, model_path: str, meta_data: metadata.Summary) -> Any:
71+
super().load(model_path, meta_data)
72+
7173
# pylint: disable=import-outside-toplevel
7274
from gensim.models import Word2Vec
7375

modelstore/models/lightgbm.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ def get_params(self, **kwargs) -> dict:
6262
return kwargs["model"].params
6363

6464
def load(self, model_path: str, meta_data: metadata.Summary) -> Any:
65+
super().load(model_path, meta_data)
66+
6567
# pylint: disable=import-outside-toplevel
6668
import lightgbm as lgb
6769

modelstore/models/model_manager.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
import numpy as np
2323
from modelstore.metadata import metadata
24+
from modelstore.metadata.code.runtime import get_python_version
2425
from modelstore.storage.storage import CloudStorage
2526

2627

@@ -88,10 +89,16 @@ def _required_kwargs(self) -> list:
8889

8990
@abstractmethod
9091
def load(self, model_path: str, meta_data: metadata.Summary) -> Any:
91-
"""
92-
Loads a model, stored in model_path, back into memory
93-
"""
94-
raise NotImplementedError()
92+
""" Loads a model, stored in model_path, back into memory """
93+
version = get_python_version()
94+
if meta_data is not None and meta_data.code is not None:
95+
if version != meta_data.code.runtime:
96+
train = f"model was trained with {meta_data.code.runtime}"
97+
load = f"but is being loaded with {version}"
98+
warnings.warn(
99+
f"{train}, {load}",
100+
category=RuntimeWarning,
101+
)
95102

96103
def _validate_kwargs(self, **kwargs):
97104
"""Ensures that the required kwargs are set"""

0 commit comments

Comments
 (0)