Skip to content

Commit 58fdc9f

Browse files
committed
Fix tests
1 parent 63dccf2 commit 58fdc9f

File tree

3 files changed

+29
-10
lines changed

3 files changed

+29
-10
lines changed

econml/tests/test_dml.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,18 +34,22 @@ class TestDML(unittest.TestCase):
3434

3535
def test_cate_api(self):
3636
"""Test that we correctly implement the CATE API."""
37-
n = 20
37+
n_c = 20 # number of rows for continuous models
38+
n_d = 30 # number of rows for discrete models
3839

39-
def make_random(is_discrete, d):
40+
def make_random(n, is_discrete, d):
4041
if d is None:
4142
return None
4243
sz = (n, d) if d >= 0 else (n,)
4344
if is_discrete:
4445
while True:
4546
arr = np.random.choice(['a', 'b', 'c'], size=sz)
46-
# ensure that we've got at least two of every element
47+
# ensure that we've got at least 6 of every element
48+
# 2 outer splits, 3 inner splits when model_t is 'auto' and treatment is discrete
49+
# NOTE: this number may need to change if the default number of folds in
50+
# WeightedStratifiedKFold changes
4751
_, counts = np.unique(arr, return_counts=True)
48-
if len(counts) == 3 and counts.min() > 1:
52+
if len(counts) == 3 and counts.min() > 5:
4953
return arr
5054
else:
5155
return np.random.normal(size=sz)
@@ -55,7 +59,8 @@ def make_random(is_discrete, d):
5559
for d_y in [3, 1, -1]:
5660
for d_x in [2, None]:
5761
for d_w in [2, None]:
58-
W, X, Y, T = [make_random(is_discrete, d)
62+
n = n_d if is_discrete else n_c
63+
W, X, Y, T = [make_random(n, is_discrete, d)
5964
for is_discrete, d in [(False, d_w),
6065
(False, d_x),
6166
(False, d_y),
@@ -699,7 +704,7 @@ def test_can_custom_splitter(self):
699704
def test_can_use_featurizer(self):
700705
"Test that we can use a featurizer, and that fit is only called during training"
701706
dml = LinearDMLCateEstimator(LinearRegression(), LinearRegression(),
702-
fit_cate_intercept=False, featurizer=OneHotEncoder(n_values='auto', sparse=False))
707+
fit_cate_intercept=False, featurizer=OneHotEncoder(sparse=False))
703708

704709
T = np.tile([1, 2, 3], 6)
705710
Y = np.array([1, 2, 3, 1, 2, 3])

econml/tests/test_drlearner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -713,7 +713,7 @@ def test_sparse(self):
713713
y_lower, y_upper = sparse_dml.effect_interval(x_test, T0=0, T1=1)
714714
in_CI = ((y_lower < true_eff) & (true_eff < y_upper))
715715
# Check that a majority of true effects lie in the 5-95% CI
716-
self.assertTrue(in_CI.mean() > 0.8)
716+
self.assertGreater(in_CI.mean(), 0.8)
717717

718718
def _test_te(self, learner_instance, tol, te_type="const"):
719719
if te_type not in ["const", "heterogeneous"]:

econml/tests/test_orf.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,20 @@ def test_effect_shape(self):
184184

185185
def test_nuisance_model_has_weights(self):
186186
"""Test whether the correct exception is being raised if model_final doesn't have weights."""
187+
188+
# Create a wrapper around Lasso that doesn't support weights
189+
# since Lasso does natively support them starting in sklearn 0.23
190+
class NoWeightModel:
191+
def __init__(self):
192+
self.model = Lasso()
193+
194+
def fit(self, X, y):
195+
self.model.fit(X, y)
196+
return self
197+
198+
def predict(self, X):
199+
return self.model.predict(X)
200+
187201
# Generate data with continuous treatments
188202
T = np.dot(TestOrthoForest.W[:, TestOrthoForest.support], TestOrthoForest.coefs_T) + \
189203
TestOrthoForest.eta_sample(TestOrthoForest.n)
@@ -192,14 +206,14 @@ def test_nuisance_model_has_weights(self):
192206
T * TE + TestOrthoForest.epsilon_sample(TestOrthoForest.n)
193207
# Instantiate model with most of the default parameters
194208
est = ContinuousTreatmentOrthoForest(n_jobs=4, n_trees=10,
195-
model_T=Lasso(),
196-
model_Y=Lasso())
209+
model_T=NoWeightModel(),
210+
model_Y=NoWeightModel())
197211
est.fit(Y=Y, T=T, X=TestOrthoForest.X, W=TestOrthoForest.W)
198212
weights_error_msg = (
199213
"Estimators of type {} do not accept weights. "
200214
"Consider using the class WeightedModelWrapper from econml.utilities to build a weighted model."
201215
)
202-
self.assertRaisesRegexp(TypeError, weights_error_msg.format("Lasso"),
216+
self.assertRaisesRegexp(TypeError, weights_error_msg.format("NoWeightModel"),
203217
est.effect, X=TestOrthoForest.X)
204218

205219
def _test_te(self, learner_instance, expected_te, tol, treatment_type='continuous'):

0 commit comments

Comments
 (0)