Skip to content

Commit 00a4a94

Browse files
committed
fix test error
1 parent 5ceedec commit 00a4a94

File tree

5 files changed

+333
-392
lines changed

5 files changed

+333
-392
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -290,8 +290,8 @@ from econml.iv.dr import LinearIntentToTreatDRIV
290290
from sklearn.ensemble import GradientBoostingRegressor, GradientBoostingClassifier
291291
from sklearn.linear_model import LinearRegression
292292

293-
est = LinearIntentToTreatDRIV(model_Y_X=GradientBoostingRegressor(),
294-
model_T_XZ=GradientBoostingClassifier(),
293+
est = LinearIntentToTreatDRIV(model_y_xw=GradientBoostingRegressor(),
294+
model_t_xwz=GradientBoostingClassifier(),
295295
flexible_model_effect=GradientBoostingRegressor())
296296
est.fit(Y, T, Z=Z, X=X) # OLS inference by default
297297
treatment_effects = est.effect(X_test)

econml/iv/dr/_dr.py

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from ...sklearn_extensions.linear_model import StatsModelsLinearRegression, DebiasedLasso, WeightedLassoCVWrapper
3030
from ...sklearn_extensions.model_selection import WeightedStratifiedKFold
3131
from ...utilities import (_deprecate_positional, add_intercept, filter_none_kwargs,
32-
inverse_onehot, get_feature_names_or_default, check_high_dimensional)
32+
inverse_onehot, get_feature_names_or_default, check_high_dimensional, check_input_arrays)
3333
from ...grf import RegressionForest
3434
from ...dml.dml import _FirstStageWrapper, _FinalWrapper
3535
from ..._shap import _shap_explain_model_cate
@@ -176,7 +176,6 @@ def predict(self, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None)
176176
Y_pred = Y_pred.reshape(Y.shape)
177177
T_pred = T_pred.reshape(T.shape)
178178
TZ_pred = TZ_pred.reshape(T.shape)
179-
prel_theta = prel_theta.reshape(Y.shape)
180179

181180
Y_res = Y - Y_pred
182181
T_res = T - T_pred
@@ -196,8 +195,7 @@ def predict(self, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None)
196195
cov = TZ_pred - T_pred * Z_pred
197196

198197
# check nuisances outcome shape
199-
assert prel_theta.ndim == 1, "Nuisance outcome should be vector!"
200-
assert Y_res.ndim == 1, "Nuisance outcome should be vector!"
198+
# Y_res could be a vector or 1-dimensional 2d-array
201199
assert T_res.ndim == 1, "Nuisance outcome should be vector!"
202200
assert Z_res.ndim == 1, "Nuisance outcome should be vector!"
203201
assert cov.ndim == 1, "Nuisance outcome should be vector!"
@@ -340,21 +338,19 @@ def _gen_ortho_learner_model_final(self):
340338
self.cov_clip, self.opt_reweighted)
341339

342340
def _check_inputs(self, Y, T, Z, X, W):
343-
if len(Y.shape) > 1 and Y.shape[1] > 1:
341+
Y1, T1, Z1, = check_input_arrays(Y, T, Z)
342+
if len(Y1.shape) > 1 and Y1.shape[1] > 1:
344343
raise AssertionError("DRIV only supports single dimensional outcome")
345-
if len(T.shape) > 1 and T.shape[1] > 1:
344+
if len(T1.shape) > 1 and T1.shape[1] > 1:
346345
if self.discrete_treatment:
347346
raise AttributeError("DRIV only supports binary treatments")
348347
else:
349348
raise AttributeError("DRIV only supports single-dimensional continuous treatments")
350-
if len(Z.shape) > 1 and Z.shape[1] > 1:
349+
if len(Z1.shape) > 1 and Z1.shape[1] > 1:
351350
if self.discrete_instrument:
352351
raise AttributeError("DRIV only supports binary instruments")
353352
else:
354353
raise AttributeError("DRIV only supports single-dimensional continuous instruments")
355-
Z = Z.ravel()
356-
T = T.ravel()
357-
Y = Y.ravel()
358354
return Y, T, Z, X, W
359355

360356
@_deprecate_positional("X and W should be passed by keyword only. In a future release "
@@ -1746,19 +1742,14 @@ def predict(self, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None)
17461742
Z_pred = Z_pred.reshape(Z.shape)
17471743
T_pred_one = T_pred_one.reshape(T.shape)
17481744
T_pred_zero = T_pred_zero.reshape(T.shape)
1749-
prel_theta = prel_theta.reshape(Y.shape)
17501745

1746+
# T_res, Z_res, beta expect shape to be (n,1)
17511747
beta = Z_pred * (1 - Z_pred) * (T_pred_one - T_pred_zero)
17521748
T_pred = T_pred_one * Z_pred + T_pred_zero * (1 - Z_pred)
17531749
Y_res = Y - Y_pred
17541750
T_res = T - T_pred
17551751
Z_res = Z - Z_pred
17561752

1757-
# check nuisances outcome shape
1758-
# T_res, Z_res, beta expect shape to be (n,1)
1759-
assert prel_theta.ndim == 1, "Nuisance outcome should be vector!"
1760-
assert Y_res.ndim == 1, "Nuisance outcome should be vector!"
1761-
17621753
return prel_theta, Y_res, T_res, Z_res, beta
17631754

17641755

econml/ortho_iv.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,15 @@
66
from .utilities import deprecated
77

88

9-
@deprecated("The econml.ortho_iv.DMLATEIV class has been moved to econml.iv.dml.DMLATEIV; "
9+
@deprecated("The econml.ortho_iv.DMLATEIV class has been moved to econml.iv.dml.OrthoIV; "
1010
"an upcoming release will remove support for the old name")
11-
class DMLATEIV(dmliv.DMLATEIV):
11+
class DMLATEIV(dmliv.OrthoIV):
1212
pass
1313

1414

15-
@deprecated("The econml.ortho_iv.ProjectedDMLATEIV class has been moved to econml.iv.dml.ProjectedDMLATEIV; "
15+
@deprecated("The econml.ortho_iv.ProjectedDMLATEIV class has been moved to econml.iv.dml.OrthoIV; "
1616
"an upcoming release will remove support for the old name")
17-
class ProjectedDMLATEIV(dmliv.ProjectedDMLATEIV):
17+
class ProjectedDMLATEIV(dmliv.OrthoIV):
1818
pass
1919

2020

econml/tests/test_integration.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from sklearn.ensemble import GradientBoostingRegressor, GradientBoostingClassifier
1717
from sklearn.linear_model import LinearRegression, MultiTaskLasso, LassoCV
1818
from sklearn.preprocessing import PolynomialFeatures, FunctionTransformer
19-
from econml.ortho_iv import LinearIntentToTreatDRIV
19+
from econml.iv.dr import LinearIntentToTreatDRIV
2020
from econml.deepiv import DeepIVEstimator
2121

2222

notebooks/CustomerScenarios/Case Study - Recommendation AB Testing at An Online Travel Company - EconML + DoWhy.ipynb

Lines changed: 319 additions & 369 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)