|
| 1 | +# Copyright (c) PyWhy contributors. All rights reserved. |
| 2 | +# Licensed under the MIT License. |
| 3 | +import pytest |
| 4 | +import unittest |
| 5 | +import numpy as np |
| 6 | +from sklearn.linear_model import LinearRegression |
| 7 | +from sklearn.pipeline import make_pipeline |
| 8 | +from sklearn.impute import SimpleImputer |
| 9 | + |
| 10 | +from econml._ortho_learner import _OrthoLearner |
| 11 | +from econml.dml import LinearDML, CausalForestDML |
| 12 | +from econml.panel.dml import DynamicDML |
| 13 | + |
| 14 | + |
class ModelNuisance:
    """First-stage nuisance model pairing a treatment model with an outcome model.

    Both wrapped models are fitted on the controls ``W``; ``predict`` returns
    the outcome and treatment residuals relative to those fits.
    """

    def __init__(self, model_t, model_y):
        # Stored as given; each is used through ``fit(W, target)`` and ``predict(W)``.
        self._model_t = model_t
        self._model_y = model_y

    def fit(self, Y, T, W=None):
        """Fit the treatment model on ``(W, T)``, then the outcome model on ``(W, Y)``.

        Returns ``self``.
        """
        for model, target in ((self._model_t, T), (self._model_y, Y)):
            model.fit(W, target)
        return self

    def predict(self, Y, T, W=None):
        """Return the residual pair ``(Y - model_y.predict(W), T - model_t.predict(W))``."""
        outcome_residual = Y - self._model_y.predict(W)
        treatment_residual = T - self._model_t.predict(W)
        return outcome_residual, treatment_residual
| 27 | + |
| 28 | + |
class ModelFinal:
    """Final-stage model: no-intercept linear regression of outcome residuals on treatment residuals."""

    def __init__(self):
        # Nothing to configure; ``self.model`` is created by ``fit``.
        pass

    def fit(self, Y, T, W=None, nuisances=None):
        """Fit the residual-on-residual regression.

        ``nuisances`` must unpack into ``(Y_res, T_res)``; ``Y``, ``T`` and ``W``
        are accepted but not used here. Returns ``self``.
        """
        outcome_res, treatment_res = nuisances
        regressor = LinearRegression(fit_intercept=False)
        # The treatment residuals are reshaped into a single feature column.
        self.model = regressor.fit(treatment_res.reshape(-1, 1), outcome_res)
        return self

    def predict(self):
        """Return the first fitted coefficient wrapped in a nested (1, 1)-style array."""
        # theta needs to be of dimension (1, d_t) if T is (n, d_t)
        theta = self.model.coef_[0]
        return np.array([[theta]])

    def score(self, Y, T, W=None, nuisances=None):
        """Return the mean squared error of the fitted model on ``(Y_res, T_res)`` from ``nuisances``."""
        outcome_res, treatment_res = nuisances
        fitted = self.model.predict(treatment_res.reshape(-1, 1))
        squared_error = (outcome_res - fitted) ** 2
        return np.mean(squared_error)
| 46 | + |
| 47 | + |
class OrthoLearner(_OrthoLearner):
    """``_OrthoLearner`` subclass that plugs in ``ModelNuisance`` and ``ModelFinal``."""

    def _gen_ortho_learner_model_nuisance(self):
        """Build a ``ModelNuisance`` whose two models mean-impute before a linear regression."""
        # Two independent pipelines so the treatment and outcome fits share no state.
        treatment_pipeline = make_pipeline(SimpleImputer(strategy='mean'), LinearRegression())
        outcome_pipeline = make_pipeline(SimpleImputer(strategy='mean'), LinearRegression())
        return ModelNuisance(treatment_pipeline, outcome_pipeline)

    def _gen_ortho_learner_model_final(self):
        """Build a fresh ``ModelFinal``."""
        return ModelFinal()
| 57 | + |
| 58 | + |
class TestTreatmentFeaturization(unittest.TestCase):

    def test_missing(self):
        """Smoke test: each estimator's ``fit`` is called with NaNs present in ``W``.

        There are no assertions; the test passes as long as no ``fit`` call raises.
        """
        n_samples = 1000

        # create data with missing values
        # (the order of the random draws below is significant for reproducibility)
        np.random.seed(123)
        X = np.random.normal(size=(n_samples, 1))
        W = np.random.normal(size=(n_samples, 5))
        x0 = X[:, 0]
        T = x0 + np.random.normal(size=(n_samples,))
        y = (1 + 0.5 * x0) * T + x0 + np.random.normal(size=(n_samples,))
        # each entry of W is blanked out when its uniform draw falls below 0.05
        missing = np.random.rand(*W.shape) < 0.05
        W_missing = np.where(missing, np.nan, W)
        # groups for dynamic dml: consecutive pairs of rows share a group id
        groups = np.repeat(np.arange(n_samples // 2), 2)

        # model that can handle missing values
        nuisance_model = make_pipeline(SimpleImputer(strategy='mean'), LinearRegression())
        nuisance_kwargs = dict(model_y=nuisance_model, model_t=nuisance_model)

        ortho_learner = OrthoLearner(discrete_treatment=False, treatment_featurizer=None,
                                     discrete_instrument=None, categories='auto', cv=3, random_state=1)
        ortho_learner.fit(y, T, W=W_missing)

        causal_forest = CausalForestDML(**nuisance_kwargs)
        causal_forest.fit(y, T, X=X, W=W_missing)

        dynamic_dml = DynamicDML(**nuisance_kwargs)
        dynamic_dml.fit(y, T, W=W_missing, groups=groups)

        # fitted through the estimator's ``dowhy`` attribute rather than directly
        linear_dml = LinearDML(**nuisance_kwargs)
        linear_dml.dowhy.fit(y, T, X=X, W=W_missing)
0 commit comments