@@ -184,6 +184,20 @@ def test_effect_shape(self):
184
184
185
185
def test_nuisance_model_has_weights (self ):
186
186
"""Test whether the correct exception is being raised if model_final doesn't have weights."""
187
+
188
+ # in sklearn 0.23 Lasso now supports weights
189
+ # so create a wrapper that doesn't
190
+ class NoWeightModel :
191
+ def __init__ (self ):
192
+ self .model = Lasso ()
193
+
194
+ def fit (self , X , y ):
195
+ self .model .fit (X , y )
196
+ return self
197
+
198
+ def predict (self , X ):
199
+ return self .model .predict (X )
200
+
187
201
# Generate data with continuous treatments
188
202
T = np .dot (TestOrthoForest .W [:, TestOrthoForest .support ], TestOrthoForest .coefs_T ) + \
189
203
TestOrthoForest .eta_sample (TestOrthoForest .n )
@@ -192,14 +206,14 @@ def test_nuisance_model_has_weights(self):
192
206
T * TE + TestOrthoForest .epsilon_sample (TestOrthoForest .n )
193
207
# Instantiate model with most of the default parameters
194
208
est = ContinuousTreatmentOrthoForest (n_jobs = 4 , n_trees = 10 ,
195
- model_T = Lasso (),
196
- model_Y = Lasso ())
209
+ model_T = NoWeightModel (),
210
+ model_Y = NoWeightModel ())
197
211
est .fit (Y = Y , T = T , X = TestOrthoForest .X , W = TestOrthoForest .W )
198
212
weights_error_msg = (
199
213
"Estimators of type {} do not accept weights. "
200
214
"Consider using the class WeightedModelWrapper from econml.utilities to build a weighted model."
201
215
)
202
- self .assertRaisesRegexp (TypeError , weights_error_msg .format ("Lasso " ),
216
+ self .assertRaisesRegexp (TypeError , weights_error_msg .format ("NoWeightModel " ),
203
217
est .effect , X = TestOrthoForest .X )
204
218
205
219
def _test_te (self , learner_instance , expected_te , tol , treatment_type = 'continuous' ):
0 commit comments