@@ -34,18 +34,22 @@ class TestDML(unittest.TestCase):
34
34
35
35
def test_cate_api (self ):
36
36
"""Test that we correctly implement the CATE API."""
37
- n = 20
37
+ n_c = 20 # number of rows for continuous models
38
+ n_d = 30 # number of rows for discrete models
38
39
39
- def make_random (is_discrete , d ):
40
+ def make_random (n , is_discrete , d ):
40
41
if d is None :
41
42
return None
42
43
sz = (n , d ) if d >= 0 else (n ,)
43
44
if is_discrete :
44
45
while True :
45
46
arr = np .random .choice (['a' , 'b' , 'c' ], size = sz )
46
- # ensure that we've got at least two of every element
47
+ # ensure that we've got at least 6 of every element
48
+ # 2 outer splits, 3 inner splits when model_t is 'auto' and treatment is discrete
49
+ # NOTE: this number may need to change if the default number of folds in
50
+ # WeightedStratifiedKFold changes
47
51
_ , counts = np .unique (arr , return_counts = True )
48
- if len (counts ) == 3 and counts .min () > 1 :
52
+ if len (counts ) == 3 and counts .min () > 5 :
49
53
return arr
50
54
else :
51
55
return np .random .normal (size = sz )
@@ -55,7 +59,8 @@ def make_random(is_discrete, d):
55
59
for d_y in [3 , 1 , - 1 ]:
56
60
for d_x in [2 , None ]:
57
61
for d_w in [2 , None ]:
58
- W , X , Y , T = [make_random (is_discrete , d )
62
+ n = n_d if is_discrete else n_c
63
+ W , X , Y , T = [make_random (n , is_discrete , d )
59
64
for is_discrete , d in [(False , d_w ),
60
65
(False , d_x ),
61
66
(False , d_y ),
@@ -699,7 +704,7 @@ def test_can_custom_splitter(self):
699
704
def test_can_use_featurizer (self ):
700
705
"Test that we can use a featurizer, and that fit is only called during training"
701
706
dml = LinearDMLCateEstimator (LinearRegression (), LinearRegression (),
702
- fit_cate_intercept = False , featurizer = OneHotEncoder (n_values = 'auto' , sparse = False ))
707
+ fit_cate_intercept = False , featurizer = OneHotEncoder (sparse = False ))
703
708
704
709
T = np .tile ([1 , 2 , 3 ], 6 )
705
710
Y = np .array ([1 , 2 , 3 , 1 , 2 , 3 ])
0 commit comments