Skip to content

Commit 1e257f9

Browse files
authored
Add support for CausalML models (#282)
* Add new manager for CausalML models * Update requirements-dev1.txt * Add unit tests for CausalML model manager * Add examples for Causal ML model manager * Fix formatting * Add verbose -1 param for LGBM to filter training logs * Update unit test for managers * Add tf-keras as dev dependency to resolve transformers compatibility issue * Pin numpy version compatible with mxnet * Update optional dependencies for causal ML manager
1 parent b061c8b commit 1e257f9

File tree

10 files changed

+332
-4
lines changed

10 files changed

+332
-4
lines changed
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
# Copyright 2024 Neal Lathia
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import numpy as np
15+
import pandas as pd
16+
17+
from libraries.util.datasets import load_causal_regression_dataset
18+
from libraries.util.domains import DIABETES_DOMAIN
19+
from causalml.inference.meta import XGBTRegressor, BaseSRegressor
20+
from causalml.metrics import qini_score
21+
from lightgbm.sklearn import LGBMRegressor
22+
23+
from modelstore.model_store import ModelStore
24+
25+
26+
def _train_example_model() -> BaseSRegressor:
    """Train an example causalml S-learner and print its Qini score.

    Fits a ``BaseSRegressor`` that wraps a LightGBM regressor on the training
    split returned by ``load_causal_regression_dataset``, then scores the
    held-out split with ``qini_score`` and prints the result.

    Returns:
        The fitted ``BaseSRegressor`` (the object that is later uploaded).
    """
    (
        X_train,
        X_test,
        y_train,
        y_test,
        treatment_vector_train,
        treatment_vector_test,
    ) = load_causal_regression_dataset()

    params = {
        "n_estimators": 250,
        "max_depth": 4,
        "learning_rate": 0.01,
        "n_jobs": 1,
        "verbose": -1,  # filter LightGBM's training logs
    }

    # Train causal regressor
    lgbm = LGBMRegressor(**params)
    model = BaseSRegressor(learner=lgbm)
    model.fit(X_train, treatment_vector_train, y_train)

    # Predict on the feature-only frame first; the scoring columns are added
    # afterwards so they are never fed to the model as features.
    X_test = pd.DataFrame(X_test)
    X_test["causal_scores"] = model.predict(X_test)
    X_test["outcomes"] = y_test
    X_test["treatment"] = treatment_vector_test

    result = qini_score(
        X_test[["causal_scores", "outcomes", "treatment"]],
        outcome_col="outcomes",
        treatment_col="treatment",
    )
    print(f"🔍 Trained model Qini score={result}.")
    return model
54+
55+
56+
def train_and_upload(modelstore: ModelStore) -> dict:
    """Train the example causalml regressor and upload it to the model store.

    Args:
        modelstore: The model store the trained model is uploaded to.

    Returns:
        Whatever ``modelstore.upload`` returns for the uploaded model.
    """
    trained_model = _train_example_model()

    # Upload under the diabetes domain
    print(f'⤴️ Uploading the causalml model to the "{DIABETES_DOMAIN}" domain.')
    return modelstore.upload(DIABETES_DOMAIN, model=trained_model)
64+
65+
66+
def load_and_test(modelstore: ModelStore, model_domain: str, model_id: str):
    """Load a stored causalml model and print its Qini score on the test split.

    Args:
        modelstore: The model store to load the model from.
        model_domain: Domain the model was uploaded to.
        model_id: Identifier of the model to load.
    """
    # Bring the model back into memory
    print(f'⤵️ Loading the causalml "{model_domain}" domain model={model_id}')
    loaded_model = modelstore.load(model_domain, model_id)

    # Only the held-out portion of the dataset is needed here
    _, raw_features, _, outcomes, _, treatment = load_causal_regression_dataset()

    # Predict on the feature-only frame before the scoring columns are added
    scored = pd.DataFrame(raw_features)
    scored["causal_scores"] = loaded_model.predict(scored)
    scored["outcomes"] = outcomes
    scored["treatment"] = treatment

    scoring_columns = ["causal_scores", "outcomes", "treatment"]
    qini = qini_score(
        scored[scoring_columns],
        outcome_col="outcomes",
        treatment_col="treatment",
    )
    print(f"🔍 Loaded model Qini score={qini}.")

examples/examples-by-ml-library/libraries/util/datasets.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import pandas as pd
15+
import numpy as np
1516
import torch
1617
from sklearn.datasets import (
1718
fetch_20newsgroups,
@@ -50,6 +51,23 @@ def load_regression_dataframe():
5051
return df
5152

5253

54+
def load_causal_regression_dataset():
    """Return the regression dataset plus dummy treatment vectors.

    The splits come straight from ``load_regression_dataset``. For each split
    a treatment vector is built to simulate an experiment: the first 50% of
    rows are marked 1 (treatment) and the remainder 0 (control).

    Returns:
        ``(X_train, X_test, y_train, y_test, treatment_vector_train,
        treatment_vector_test)``.
    """
    X_train, X_test, y_train, y_test = load_regression_dataset()

    def _half_treated(n_rows):
        # Dummy assignment: leading half of the rows is treated, rest is control
        flags = np.zeros(n_rows)
        flags[: int(n_rows * 0.5)] = 1
        return flags

    return (
        X_train,
        X_test,
        y_train,
        y_test,
        _half_treated(X_train.shape[0]),
        _half_treated(X_test.shape[0]),
    )
69+
70+
5371
def load_classification_dataset():
5472
print(f"🔍 Loading the breast cancer dataset")
5573
databunch = load_breast_cancer()

examples/examples-by-ml-library/main.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from libraries import (
1818
annoy_example,
1919
catboost_example,
20+
causalml_example,
2021
fastai_example,
2122
gensim_example,
2223
keras_example,
@@ -50,6 +51,7 @@
5051
EXAMPLES = {
5152
"annoy": annoy_example,
5253
"catboost": catboost_example,
54+
"causalml": causalml_example,
5355
"dpt": dpt,
5456
"fastai": fastai_example,
5557
"file": raw_file_example,

examples/examples-by-ml-library/run-all.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
set -e
22
backends=( filesystem aws-s3 google-cloud-storage azure-container minio )
3-
frameworks=( annoy catboost fastai file gensim keras lightgbm \
3+
frameworks=( annoy catboost causalml fastai file gensim keras lightgbm \
44
onnx-sklearn onnx-lightgbm prophet pyspark pytorch pytorch-lightning \
55
sklearn sklearn-with-explainer sklearn-with-extras skorch xgboost xgboost-booster \
66
tensorflow hf-distilbert hf-gpt2-pt hf-gpt2-tf segment-anything yolov5 )

modelstore/models/causalml.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
# Copyright 2024 Neal Lathia
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import os
15+
from functools import partial
16+
from typing import Any
17+
18+
from modelstore.metadata import metadata
19+
from modelstore.models.common import load_joblib, save_joblib
20+
from modelstore.models.model_manager import ModelManager
21+
from modelstore.storage.storage import CloudStorage
22+
23+
MODEL_FILE = "model.joblib"
24+
25+
26+
class CausalMLManager(ModelManager):

    """
    Model persistence for Causal ML models:
    https://causalml.readthedocs.io/en/latest/index.html
    """

    NAME = "causalml"

    def __init__(self, storage: CloudStorage = None):
        """Create the manager, registering it under ``NAME``."""
        super().__init__(self.NAME, storage)

    def required_dependencies(self) -> list:
        """Return the packages that must be installed to use this manager."""
        return ["causalml"]

    def optional_dependencies(self) -> list:
        """Return the base class's optional dependencies plus Cython and joblib."""
        deps = super().optional_dependencies()
        return deps + ["Cython", "joblib"]

    def _required_kwargs(self):
        """Return the keyword arguments an upload must provide."""
        return ["model"]

    def matches_with(self, **kwargs) -> bool:
        """Return True if ``kwargs["model"]`` is a Causal ML model.

        A model matches when it is a meta-learner (``BaseLearner``) or a
        ``PropensityModel``. A missing ``model`` kwarg yields False.
        """
        # The classes are imported from their submodules explicitly: a bare
        # `import causalml` does not guarantee that `causalml.inference` or
        # `causalml.propensity` are bound as attributes of the package.
        # pylint: disable=import-outside-toplevel
        from causalml.inference.meta.base import BaseLearner
        from causalml.propensity import PropensityModel

        return isinstance(kwargs.get("model"), (BaseLearner, PropensityModel))

    def _get_functions(self, **kwargs) -> list:
        """Return the callables that persist the model.

        Raises:
            TypeError: If the supplied model is not a Causal ML model.
        """
        if not self.matches_with(**kwargs):
            raise TypeError("This model is not a Causal ML model!")

        return [partial(save_joblib, model=kwargs["model"], file_name=MODEL_FILE)]

    def load(self, model_path: str, meta_data: metadata.Summary) -> Any:
        """Load the model stored as ``MODEL_FILE`` inside ``model_path``.

        Args:
            model_path: Directory containing the saved model file.
            meta_data: Metadata summary recorded for the model.

        Returns:
            The object returned by ``load_joblib``.
        """
        super().load(model_path, meta_data)

        # @Future: check if loading into same version of joblib
        # as was used for saving
        # NOTE(review): presumably load_joblib unpickles via joblib, so only
        # load artifacts that come from trusted storage - confirm.
        file_name = os.path.join(model_path, MODEL_FILE)
        return load_joblib(file_name)

modelstore/models/managers.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from modelstore.metadata.code.dependencies import module_exists
1717
from modelstore.models.annoy import AnnoyManager
1818
from modelstore.models.catboost import CatBoostManager
19+
from modelstore.models.causalml import CausalMLManager
1920
from modelstore.models.fastai import FastAIManager
2021
from modelstore.models.gensim import GensimManager
2122
from modelstore.models.lightgbm import LightGbmManager
@@ -41,6 +42,7 @@
4142
for m in [
4243
AnnoyManager,
4344
CatBoostManager,
45+
CausalMLManager,
4446
FastAIManager,
4547
ModelFileManager,
4648
GensimManager,
@@ -97,4 +99,3 @@ def get_manager(name: str, storage: CloudStorage = None) -> ModelManager:
9799
if not module_exists(x):
98100
raise ValueError(f"could not create manager for {name}: {x} not installed")
99101
return manager
100-

requirements-dev1.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ minio
1111
# Machine Learning
1212
annoy
1313
catboost
14+
causalml
1415
fastai # Note: 1.0.61 has different import paths!
1516
gensim
1617
Keras-Preprocessing
@@ -26,6 +27,7 @@ skl2onnx
2627
skorch
2728
tf-keras
2829
tensorflow
30+
tf-keras
2931
torch
3032
torchvision
3133
transformers

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,4 @@ joblib>=1.0.0
55
requests>=2.23.0
66
tqdm>=4.54.1
77
click>=7.0
8-
numpy
8+
numpy==1.23.1

0 commit comments

Comments
 (0)