@@ -85,6 +85,8 @@ def test_basic_array(self):
85
85
# policy value should exceed always treating with any treatment
86
86
assert_less_close (np .array (list (always_trt .values ())), policy_val )
87
87
88
+ ind_pol = ca .individualized_policy (X , inds [idx ])
89
+
88
90
# global shape is (d_y, sum(d_t))
89
91
assert glo_point_est .shape == coh_point_est .shape == (1 , 5 )
90
92
assert loc_point_est .shape == (2 ,) + glo_point_est .shape
@@ -128,113 +130,121 @@ def test_basic_array(self):
128
130
129
131
def test_basic_pandas(self):
    """Smoke-test CausalAnalysis end-to-end on pandas inputs.

    Covers both regression and classification targets, and (new) both
    plain-dtype and pandas 'category'-dtype discrete columns, exercising
    global/cohort/local effects, dict serialization, policy and
    heterogeneity trees, individualized policies, what-if analysis, and
    bad-argument validation.
    """
    for classification in [False, True]:
        # NOTE(review): run the whole scenario twice, once with the discrete
        # columns converted to pandas categorical dtype, to confirm results
        # are dtype-independent.
        for category in [False, True]:
            y = pd.Series(np.random.choice([0, 1], size=(500,)))
            X = pd.DataFrame({'a': np.random.normal(size=500),
                              'b': np.random.normal(size=500),
                              'c': np.random.choice([0, 1], size=500),
                              'd': np.random.choice(['a', 'b', 'c'], size=500)})

            if category:
                X['c'] = X['c'].astype('category')
                X['d'] = X['d'].astype('category')

            # columns can be addressed by integer position or by name
            n_inds = [0, 1, 2, 3]
            t_inds = ['a', 'b', 'c', 'd']
            n_cats = [2, 3]
            t_cats = ['c', 'd']
            n_hinds = [0, 3]
            t_hinds = ['a', 'd']
            for (inds, cats, hinds) in [(n_inds, n_cats, n_hinds), (t_inds, t_cats, t_hinds)]:
                ca = CausalAnalysis(inds, cats, hinds, classification=classification)
                ca.fit(X, y)
                glo = ca.global_causal_effect()
                coh = ca.cohort_causal_effect(X[:2])
                loc = ca.local_causal_effect(X[:2])

                # global and cohort data should have exactly the same structure, but different values
                assert glo.index.equals(coh.index)

                # local index should have as many times entries as global as there were rows passed in
                assert len(loc.index) == 2 * len(glo.index)

                assert glo.index.names == ['feature', 'feature_value']
                assert loc.index.names == ['sample'] + glo.index.names

                # features; for categoricals they should appear #cats-1 times each
                fts = ['a', 'b', 'c', 'd', 'd']

                for i in range(len(fts)):
                    assert fts[i] == glo.index[i][0] == loc.index[i][1] == loc.index[len(fts) + i][1]

                glo_dict = ca._global_causal_effect_dict()
                glo_dict2 = ca._global_causal_effect_dict(row_wise=True)

                coh_dict = ca._cohort_causal_effect_dict(X[:2])
                coh_dict2 = ca._cohort_causal_effect_dict(X[:2], row_wise=True)

                loc_dict = ca._local_causal_effect_dict(X[:2])
                loc_dict2 = ca._local_causal_effect_dict(X[:2], row_wise=True)

                glo_point_est = np.array(glo_dict[_CausalInsightsConstants.PointEstimateKey])
                coh_point_est = np.array(coh_dict[_CausalInsightsConstants.PointEstimateKey])
                loc_point_est = np.array(loc_dict[_CausalInsightsConstants.PointEstimateKey])

                # global shape is (d_y, sum(d_t))
                assert glo_point_est.shape == coh_point_est.shape == (1, 5)
                assert loc_point_est.shape == (2,) + glo_point_est.shape

                # global and cohort row-wise dicts have d_y * d_t entries
                assert len(
                    glo_dict2[_CausalInsightsConstants.RowData]) == len(
                    coh_dict2[_CausalInsightsConstants.RowData]) == 5
                # local dictionary is flattened to n_rows * d_y * d_t
                assert len(loc_dict2[_CausalInsightsConstants.RowData]) == 10

                pto = ca._policy_tree_output(X, inds[1])
                ca._heterogeneity_tree_output(X, inds[1])
                ca._heterogeneity_tree_output(X, inds[3])

                # continuous treatments have typical treatment values equal to
                # the mean of the absolute value of non-zero entries
                np.testing.assert_allclose(ca.typical_treatment_value(inds[0]), np.mean(np.abs(X['a'])))
                np.testing.assert_allclose(ca.typical_treatment_value(inds[1]), np.mean(np.abs(X['b'])))
                # discrete treatments have typical treatment value 1
                assert ca.typical_treatment_value(inds[2]) == ca.typical_treatment_value(inds[3]) == 1

                # Make sure we handle continuous, binary, and multi-class treatments
                # For multiple discrete treatments, one "always treat" value per non-default treatment
                for (idx, length) in [(0, 1), (1, 1), (2, 1), (3, 2)]:
                    pto = ca._policy_tree_output(X, inds[idx])
                    policy_val = pto.policy_value
                    always_trt = pto.always_treat
                    assert isinstance(pto.control_name, str)
                    assert isinstance(always_trt, dict)
                    assert np.array(policy_val).shape == ()
                    assert len(always_trt) == length
                    for val in always_trt.values():
                        assert np.array(val).shape == ()

                    # policy value should exceed always treating with any treatment
                    assert_less_close(np.array(list(always_trt.values())), policy_val)

                    # smoke-test the individualized policy output for each treatment
                    ind_pol = ca.individualized_policy(X, inds[idx])

                if not classification:
                    # ExitStack can be used as a "do nothing" ContextManager
                    cm = ExitStack()
                else:
                    cm = self.assertRaises(Exception)
                with cm:
                    inf = ca.whatif(X[:2], np.ones(shape=(2,)), inds[1], y[:2])
                    assert np.shape(inf.point_estimate) == np.shape(y[:2])
                    inf = ca.whatif(X[:2], np.ones(shape=(2,)), inds[2], y[:2])
                    assert np.shape(inf.point_estimate) == np.shape(y[:2])

                    ca._whatif_dict(X[:2], np.ones(shape=(2,)), inds[1], y[:2])
                    ca._whatif_dict(X[:2], np.ones(shape=(2,)), inds[1], y[:2], row_wise=True)

            badargs = [
                (n_inds, n_cats, [4]),  # hinds out of range
                (n_inds, n_cats, ["test"])  # hinds out of range
            ]

            for args in badargs:
                with self.assertRaises(Exception):
                    ca = CausalAnalysis(*args)
                    ca.fit(X, y)
def test_automl_first_stage (self ):
240
250
d_y = (1 ,)
@@ -294,6 +304,8 @@ def test_automl_first_stage(self):
294
304
# policy value should exceed always treating with any treatment
295
305
assert_less_close (np .array (list (always_trt .values ())), policy_val )
296
306
307
+ ind_pol = ca .individualized_policy (X , inds [idx ])
308
+
297
309
# global shape is (d_y, sum(d_t))
298
310
assert glo_point_est .shape == coh_point_est .shape == (1 , 5 )
299
311
assert loc_point_est .shape == (2 ,) + glo_point_est .shape
@@ -436,6 +448,8 @@ def test_final_models(self):
436
448
# policy value should exceed always treating with any treatment
437
449
assert_less_close (np .array (list (always_trt .values ())), policy_val )
438
450
451
+ ind_pol = ca .individualized_policy (X , inds [idx ])
452
+
439
453
if not classification :
440
454
# ExitStack can be used as a "do nothing" ContextManager
441
455
cm = ExitStack ()
@@ -526,6 +540,8 @@ def test_forest_with_pandas(self):
526
540
# policy value should exceed always treating with any treatment
527
541
assert_less_close (np .array (list (always_trt .values ())), policy_val )
528
542
543
+ ind_pol = ca .individualized_policy (X , inds [idx ])
544
+
529
545
def test_warm_start (self ):
530
546
for classification in [True , False ]:
531
547
# dgp
0 commit comments