29
29
from ...sklearn_extensions .linear_model import StatsModelsLinearRegression , DebiasedLasso , WeightedLassoCVWrapper
30
30
from ...sklearn_extensions .model_selection import WeightedStratifiedKFold
31
31
from ...utilities import (_deprecate_positional , add_intercept , filter_none_kwargs ,
32
- inverse_onehot , get_feature_names_or_default , check_high_dimensional )
32
+ inverse_onehot , get_feature_names_or_default , check_high_dimensional , check_input_arrays )
33
33
from ...grf import RegressionForest
34
34
from ...dml .dml import _FirstStageWrapper , _FinalWrapper
35
35
from ..._shap import _shap_explain_model_cate
@@ -176,7 +176,6 @@ def predict(self, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None)
176
176
Y_pred = Y_pred .reshape (Y .shape )
177
177
T_pred = T_pred .reshape (T .shape )
178
178
TZ_pred = TZ_pred .reshape (T .shape )
179
- prel_theta = prel_theta .reshape (Y .shape )
180
179
181
180
Y_res = Y - Y_pred
182
181
T_res = T - T_pred
@@ -196,8 +195,7 @@ def predict(self, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None)
196
195
cov = TZ_pred - T_pred * Z_pred
197
196
198
197
# check nuisances outcome shape
199
- assert prel_theta .ndim == 1 , "Nuisance outcome should be vector!"
200
- assert Y_res .ndim == 1 , "Nuisance outcome should be vector!"
198
+ # Y_res could be a vector or 1-dimensional 2d-array
201
199
assert T_res .ndim == 1 , "Nuisance outcome should be vector!"
202
200
assert Z_res .ndim == 1 , "Nuisance outcome should be vector!"
203
201
assert cov .ndim == 1 , "Nuisance outcome should be vector!"
@@ -340,21 +338,19 @@ def _gen_ortho_learner_model_final(self):
340
338
self .cov_clip , self .opt_reweighted )
341
339
342
340
def _check_inputs (self , Y , T , Z , X , W ):
343
- if len (Y .shape ) > 1 and Y .shape [1 ] > 1 :
341
+ Y1 , T1 , Z1 , = check_input_arrays (Y , T , Z )
342
+ if len (Y1 .shape ) > 1 and Y1 .shape [1 ] > 1 :
344
343
raise AssertionError ("DRIV only supports single dimensional outcome" )
345
- if len (T .shape ) > 1 and T .shape [1 ] > 1 :
344
+ if len (T1 .shape ) > 1 and T1 .shape [1 ] > 1 :
346
345
if self .discrete_treatment :
347
346
raise AttributeError ("DRIV only supports binary treatments" )
348
347
else :
349
348
raise AttributeError ("DRIV only supports single-dimensional continuous treatments" )
350
- if len (Z .shape ) > 1 and Z .shape [1 ] > 1 :
349
+ if len (Z1 .shape ) > 1 and Z1 .shape [1 ] > 1 :
351
350
if self .discrete_instrument :
352
351
raise AttributeError ("DRIV only supports binary instruments" )
353
352
else :
354
353
raise AttributeError ("DRIV only supports single-dimensional continuous instruments" )
355
- Z = Z .ravel ()
356
- T = T .ravel ()
357
- Y = Y .ravel ()
358
354
return Y , T , Z , X , W
359
355
360
356
@_deprecate_positional ("X and W should be passed by keyword only. In a future release "
@@ -1746,19 +1742,14 @@ def predict(self, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None)
1746
1742
Z_pred = Z_pred .reshape (Z .shape )
1747
1743
T_pred_one = T_pred_one .reshape (T .shape )
1748
1744
T_pred_zero = T_pred_zero .reshape (T .shape )
1749
- prel_theta = prel_theta .reshape (Y .shape )
1750
1745
1746
+ # T_res, Z_res, beta expect shape to be (n,1)
1751
1747
beta = Z_pred * (1 - Z_pred ) * (T_pred_one - T_pred_zero )
1752
1748
T_pred = T_pred_one * Z_pred + T_pred_zero * (1 - Z_pred )
1753
1749
Y_res = Y - Y_pred
1754
1750
T_res = T - T_pred
1755
1751
Z_res = Z - Z_pred
1756
1752
1757
- # check nuisances outcome shape
1758
- # T_res, Z_res, beta expect shape to be (n,1)
1759
- assert prel_theta .ndim == 1 , "Nuisance outcome should be vector!"
1760
- assert Y_res .ndim == 1 , "Nuisance outcome should be vector!"
1761
-
1762
1753
return prel_theta , Y_res , T_res , Z_res , beta
1763
1754
1764
1755
0 commit comments