Skip to content

Commit 00dd506

Browse files
committed
add tests for ests that handle missing in W
Signed-off-by: Fabio Vera <fabiovera@microsoft.com>
1 parent 782e252 commit 00dd506

1 file changed

Lines changed: 82 additions & 0 deletions

File tree

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
# Copyright (c) PyWhy contributors. All rights reserved.
2+
# Licensed under the MIT License.
3+
import pytest
4+
import unittest
5+
import numpy as np
6+
from sklearn.linear_model import LinearRegression
7+
from sklearn.pipeline import make_pipeline
8+
from sklearn.impute import SimpleImputer
9+
10+
from econml._ortho_learner import _OrthoLearner
11+
from econml.dml import LinearDML, CausalForestDML
12+
from econml.panel.dml import DynamicDML
13+
14+
15+
class ModelNuisance:
16+
def __init__(self, model_t, model_y):
17+
self._model_t = model_t
18+
self._model_y = model_y
19+
20+
def fit(self, Y, T, W=None):
21+
self._model_t.fit(W, T)
22+
self._model_y.fit(W, Y)
23+
return self
24+
25+
def predict(self, Y, T, W=None):
26+
return Y - self._model_y.predict(W), T - self._model_t.predict(W)
27+
28+
29+
class ModelFinal:
30+
31+
def __init__(self):
32+
return
33+
34+
def fit(self, Y, T, W=None, nuisances=None):
35+
Y_res, T_res = nuisances
36+
self.model = LinearRegression(fit_intercept=False).fit(T_res.reshape(-1, 1), Y_res)
37+
return self
38+
39+
def predict(self):
40+
# theta needs to be of dimension (1, d_t) if T is (n, d_t)
41+
return np.array([[self.model.coef_[0]]])
42+
43+
def score(self, Y, T, W=None, nuisances=None):
44+
Y_res, T_res = nuisances
45+
return np.mean((Y_res - self.model.predict(T_res.reshape(-1, 1)))**2)
46+
47+
48+
class OrthoLearner(_OrthoLearner):
49+
def _gen_ortho_learner_model_nuisance(self):
50+
return ModelNuisance(
51+
make_pipeline(SimpleImputer(strategy='mean'), LinearRegression()),
52+
make_pipeline(SimpleImputer(strategy='mean'), LinearRegression())
53+
)
54+
55+
def _gen_ortho_learner_model_final(self):
56+
return ModelFinal()
57+
58+
59+
class TestTreatmentFeaturization(unittest.TestCase):
60+
61+
def test_missing(self):
62+
# create data with missing values
63+
np.random.seed(123)
64+
X = np.random.normal(size=(1000, 1))
65+
W = np.random.normal(size=(1000, 5))
66+
T = X[:, 0] + np.random.normal(size=(1000,))
67+
y = (1 + 0.5 * X[:, 0]) * T + X[:, 0] + np.random.normal(size=(1000,))
68+
mask = np.random.rand(*W.shape) < 0.05
69+
W_missing = W.copy()
70+
W_missing[mask] = np.nan
71+
groups = np.repeat(np.arange(500), 2) # groups for dynamic dml
72+
73+
# model that can handle missing values
74+
nuisance_model = make_pipeline(SimpleImputer(strategy='mean'), LinearRegression())
75+
OrthoLearner(discrete_treatment=False, treatment_featurizer=None, discrete_instrument=None,
76+
categories='auto', cv=3, random_state=1).fit(y, T, W=W_missing)
77+
78+
CausalForestDML(model_y=nuisance_model, model_t=nuisance_model).fit(y, T, X=X, W=W_missing)
79+
80+
DynamicDML(model_y=nuisance_model, model_t=nuisance_model).fit(y, T, W=W_missing, groups=groups)
81+
82+
LinearDML(model_y=nuisance_model, model_t=nuisance_model).dowhy.fit(y, T, X=X, W=W_missing)

0 commit comments

Comments
 (0)