Commit ac12f54
Handle pandas categorical types for categorical columns in _causal_analysis.py (#602)
If a treatment column is explicitly given a categorical dtype, the `CausalAnalysis` class fails:

```
~\AppData\Local\Continuum\miniconda3\envs\nhs-hips\lib\site-packages\econml\solutions\causal_analysis\_causal_analysis.py in individualized_policy(self, Xtest, feature_index, n_rows, treatment_costs, alpha)
   1714         all_costs = np.array([0] + [treatment_costs] * (len(treatment_arr) - 1))
   1715         # construct index of current treatment
-> 1716         current_ind = (current_treatment.reshape(-1, 1) ==
   1717                        treatment_arr.reshape(1, -1)) @ np.arange(len(treatment_arr))
   1718         current_cost = all_costs[current_ind]

~\AppData\Local\Continuum\miniconda3\envs\nhs-hips\lib\site-packages\pandas\core\ops\common.py in new_method(self, other)
     67     other = item_from_zerodim(other)
     68
---> 69     return method(self, other)
     70
     71     return new_method

~\AppData\Local\Continuum\miniconda3\envs\nhs-hips\lib\site-packages\pandas\core\arrays\categorical.py in func(self, other)
    131     if is_list_like(other) and len(other) != len(self) and not hashable:
    132         # in hashable case we may have a tuple that is itself a category
--> 133         raise ValueError("Lengths must match.")
    134
    135     if not self.ordered:
```

The solution is to check whether the treatment column is of type `pd.core.arrays.categorical.Categorical` and, if so, extract the underlying numpy array with its `to_numpy()` method.
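To make the failure concrete, here is a minimal sketch of the broken pattern and the workaround (hypothetical data; only the public pandas and NumPy APIs are assumed):

```python
import numpy as np
import pandas as pd

# Hypothetical current-treatment column stored with an explicit categorical dtype;
# .values then yields a pandas Categorical rather than an np.ndarray.
current = pd.Series(['a', 'b', 'a', 'c'], dtype='category').values

treatments = np.array(['a', 'b', 'c'])

# Comparing a Categorical against a list-like of a different length raises
# ValueError("Lengths must match.") instead of broadcasting like NumPy does.
try:
    current == treatments
except ValueError as e:
    print(e)  # Lengths must match.

# After extracting the underlying ndarray, NumPy broadcasting applies, so the
# one-hot trick used by individualized_policy works as intended:
cur = current.to_numpy()
onehot = cur.reshape(-1, 1) == treatments.reshape(1, -1)  # boolean, shape (4, 3)
current_ind = onehot @ np.arange(len(treatments))         # array([0, 1, 0, 2])
```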
1 parent: 5cf6920

2 files changed: +122 −104 lines
econml/solutions/causal_analysis/_causal_analysis.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -1701,6 +1701,8 @@ def individualized_policy(self, Xtest, feature_index, *, n_rows=None, treatment_
         effect = result.estimator.effect_inference(Xtest, T0=orig_df['Current treatment'], T1=rec)
         # we now need to construct the delta in the cost between the two treatments and translate the effect
         current_treatment = orig_df['Current treatment'].values
+        if isinstance(current_treatment, pd.core.arrays.categorical.Categorical):
+            current_treatment = current_treatment.to_numpy()
         if np.ndim(treatment_costs) >= 2:
             # remove third dimension potentially added
             if multi_y:  # y was an array, not a vector
```
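For reference, `pd.Categorical` is the public alias of `pd.core.arrays.categorical.Categorical`, and `Series.values` returns such an object for category-dtype columns. A minimal sketch (hypothetical column values) of what the new guard normalizes:

```python
import numpy as np
import pandas as pd

df = pd.DataFrame({'Current treatment': pd.Series(['a', 'b', 'a'], dtype='category')})

values = df['Current treatment'].values
# For a category-dtype column, .values is a pandas Categorical, not an
# np.ndarray, which is what trips up the downstream one-hot comparison.
assert isinstance(values, pd.core.arrays.categorical.Categorical)

# The added guard converts it to a plain ndarray before any NumPy arithmetic.
if isinstance(values, pd.core.arrays.categorical.Categorical):
    values = values.to_numpy()
assert isinstance(values, np.ndarray)
```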

econml/tests/test_causal_analysis.py

Lines changed: 120 additions & 104 deletions
```diff
@@ -85,6 +85,8 @@ def test_basic_array(self):
                     # policy value should exceed always treating with any treatment
                     assert_less_close(np.array(list(always_trt.values())), policy_val)
 
+                    ind_pol = ca.individualized_policy(X, inds[idx])
+
                 # global shape is (d_y, sum(d_t))
                 assert glo_point_est.shape == coh_point_est.shape == (1, 5)
                 assert loc_point_est.shape == (2,) + glo_point_est.shape
@@ -128,113 +130,121 @@ def test_basic_array(self):
 
     def test_basic_pandas(self):
         for classification in [False, True]:
-            y = pd.Series(np.random.choice([0, 1], size=(500,)))
-            X = pd.DataFrame({'a': np.random.normal(size=500),
-                              'b': np.random.normal(size=500),
-                              'c': np.random.choice([0, 1], size=500),
-                              'd': np.random.choice(['a', 'b', 'c'], size=500)})
-            n_inds = [0, 1, 2, 3]
-            t_inds = ['a', 'b', 'c', 'd']
-            n_cats = [2, 3]
-            t_cats = ['c', 'd']
-            n_hinds = [0, 3]
-            t_hinds = ['a', 'd']
-            for (inds, cats, hinds) in [(n_inds, n_cats, n_hinds), (t_inds, t_cats, t_hinds)]:
-                ca = CausalAnalysis(inds, cats, hinds, classification=classification)
-                ca.fit(X, y)
-                glo = ca.global_causal_effect()
-                coh = ca.cohort_causal_effect(X[:2])
-                loc = ca.local_causal_effect(X[:2])
-
-                # global and cohort data should have exactly the same structure, but different values
-                assert glo.index.equals(coh.index)
-
-                # local index should have as many times entries as global as there were rows passed in
-                assert len(loc.index) == 2 * len(glo.index)
-
-                assert glo.index.names == ['feature', 'feature_value']
-                assert loc.index.names == ['sample'] + glo.index.names
-
-                # features; for categoricals they should appear #cats-1 times each
-                fts = ['a', 'b', 'c', 'd', 'd']
-
-                for i in range(len(fts)):
-                    assert fts[i] == glo.index[i][0] == loc.index[i][1] == loc.index[len(fts) + i][1]
-
-                glo_dict = ca._global_causal_effect_dict()
-                glo_dict2 = ca._global_causal_effect_dict(row_wise=True)
-
-                coh_dict = ca._cohort_causal_effect_dict(X[:2])
-                coh_dict2 = ca._cohort_causal_effect_dict(X[:2], row_wise=True)
-
-                loc_dict = ca._local_causal_effect_dict(X[:2])
-                loc_dict2 = ca._local_causal_effect_dict(X[:2], row_wise=True)
-
-                glo_point_est = np.array(glo_dict[_CausalInsightsConstants.PointEstimateKey])
-                coh_point_est = np.array(coh_dict[_CausalInsightsConstants.PointEstimateKey])
-                loc_point_est = np.array(loc_dict[_CausalInsightsConstants.PointEstimateKey])
-
-                # global shape is (d_y, sum(d_t))
-                assert glo_point_est.shape == coh_point_est.shape == (1, 5)
-                assert loc_point_est.shape == (2,) + glo_point_est.shape
-
-                # global and cohort row-wise dicts have d_y * d_t entries
-                assert len(
-                    glo_dict2[_CausalInsightsConstants.RowData]) == len(
-                    coh_dict2[_CausalInsightsConstants.RowData]) == 5
-                # local dictionary is flattened to n_rows * d_y * d_t
-                assert len(loc_dict2[_CausalInsightsConstants.RowData]) == 10
-
-                pto = ca._policy_tree_output(X, inds[1])
-                ca._heterogeneity_tree_output(X, inds[1])
-                ca._heterogeneity_tree_output(X, inds[3])
-
-                # continuous treatments have typical treatment values equal to
-                # the mean of the absolute value of non-zero entries
-                np.testing.assert_allclose(ca.typical_treatment_value(inds[0]), np.mean(np.abs(X['a'])))
-                np.testing.assert_allclose(ca.typical_treatment_value(inds[1]), np.mean(np.abs(X['b'])))
-                # discrete treatments have typical treatment value 1
-                assert ca.typical_treatment_value(inds[2]) == ca.typical_treatment_value(inds[3]) == 1
-
-                # Make sure we handle continuous, binary, and multi-class treatments
-                # For multiple discrete treatments, one "always treat" value per non-default treatment
-                for (idx, length) in [(0, 1), (1, 1), (2, 1), (3, 2)]:
-                    pto = ca._policy_tree_output(X, inds[idx])
-                    policy_val = pto.policy_value
-                    always_trt = pto.always_treat
-                    assert isinstance(pto.control_name, str)
-                    assert isinstance(always_trt, dict)
-                    assert np.array(policy_val).shape == ()
-                    assert len(always_trt) == length
-                    for val in always_trt.values():
-                        assert np.array(val).shape == ()
-
-                    # policy value should exceed always treating with any treatment
-                    assert_less_close(np.array(list(always_trt.values())), policy_val)
-
-                if not classification:
-                    # ExitStack can be used as a "do nothing" ContextManager
-                    cm = ExitStack()
-                else:
-                    cm = self.assertRaises(Exception)
-                with cm:
-                    inf = ca.whatif(X[:2], np.ones(shape=(2,)), inds[1], y[:2])
-                    assert np.shape(inf.point_estimate) == np.shape(y[:2])
-                    inf = ca.whatif(X[:2], np.ones(shape=(2,)), inds[2], y[:2])
-                    assert np.shape(inf.point_estimate) == np.shape(y[:2])
+            for category in [False, True]:
+                y = pd.Series(np.random.choice([0, 1], size=(500,)))
+                X = pd.DataFrame({'a': np.random.normal(size=500),
+                                  'b': np.random.normal(size=500),
+                                  'c': np.random.choice([0, 1], size=500),
+                                  'd': np.random.choice(['a', 'b', 'c'], size=500)})
+
+                if category:
+                    X['c'] = X['c'].astype('category')
+                    X['d'] = X['d'].astype('category')
+
+                n_inds = [0, 1, 2, 3]
+                t_inds = ['a', 'b', 'c', 'd']
+                n_cats = [2, 3]
+                t_cats = ['c', 'd']
+                n_hinds = [0, 3]
+                t_hinds = ['a', 'd']
+                for (inds, cats, hinds) in [(n_inds, n_cats, n_hinds), (t_inds, t_cats, t_hinds)]:
+                    ca = CausalAnalysis(inds, cats, hinds, classification=classification)
+                    ca.fit(X, y)
+                    glo = ca.global_causal_effect()
+                    coh = ca.cohort_causal_effect(X[:2])
+                    loc = ca.local_causal_effect(X[:2])
+
+                    # global and cohort data should have exactly the same structure, but different values
+                    assert glo.index.equals(coh.index)
+
+                    # local index should have as many times entries as global as there were rows passed in
+                    assert len(loc.index) == 2 * len(glo.index)
+
+                    assert glo.index.names == ['feature', 'feature_value']
+                    assert loc.index.names == ['sample'] + glo.index.names
+
+                    # features; for categoricals they should appear #cats-1 times each
+                    fts = ['a', 'b', 'c', 'd', 'd']
+
+                    for i in range(len(fts)):
+                        assert fts[i] == glo.index[i][0] == loc.index[i][1] == loc.index[len(fts) + i][1]
+
+                    glo_dict = ca._global_causal_effect_dict()
+                    glo_dict2 = ca._global_causal_effect_dict(row_wise=True)
+
+                    coh_dict = ca._cohort_causal_effect_dict(X[:2])
+                    coh_dict2 = ca._cohort_causal_effect_dict(X[:2], row_wise=True)
+
+                    loc_dict = ca._local_causal_effect_dict(X[:2])
+                    loc_dict2 = ca._local_causal_effect_dict(X[:2], row_wise=True)
+
+                    glo_point_est = np.array(glo_dict[_CausalInsightsConstants.PointEstimateKey])
+                    coh_point_est = np.array(coh_dict[_CausalInsightsConstants.PointEstimateKey])
+                    loc_point_est = np.array(loc_dict[_CausalInsightsConstants.PointEstimateKey])
+
+                    # global shape is (d_y, sum(d_t))
+                    assert glo_point_est.shape == coh_point_est.shape == (1, 5)
+                    assert loc_point_est.shape == (2,) + glo_point_est.shape
+
+                    # global and cohort row-wise dicts have d_y * d_t entries
+                    assert len(
+                        glo_dict2[_CausalInsightsConstants.RowData]) == len(
+                        coh_dict2[_CausalInsightsConstants.RowData]) == 5
+                    # local dictionary is flattened to n_rows * d_y * d_t
+                    assert len(loc_dict2[_CausalInsightsConstants.RowData]) == 10
+
+                    pto = ca._policy_tree_output(X, inds[1])
+                    ca._heterogeneity_tree_output(X, inds[1])
+                    ca._heterogeneity_tree_output(X, inds[3])
+
+                    # continuous treatments have typical treatment values equal to
+                    # the mean of the absolute value of non-zero entries
+                    np.testing.assert_allclose(ca.typical_treatment_value(inds[0]), np.mean(np.abs(X['a'])))
+                    np.testing.assert_allclose(ca.typical_treatment_value(inds[1]), np.mean(np.abs(X['b'])))
+                    # discrete treatments have typical treatment value 1
+                    assert ca.typical_treatment_value(inds[2]) == ca.typical_treatment_value(inds[3]) == 1
+
+                    # Make sure we handle continuous, binary, and multi-class treatments
+                    # For multiple discrete treatments, one "always treat" value per non-default treatment
+                    for (idx, length) in [(0, 1), (1, 1), (2, 1), (3, 2)]:
+                        pto = ca._policy_tree_output(X, inds[idx])
+                        policy_val = pto.policy_value
+                        always_trt = pto.always_treat
+                        assert isinstance(pto.control_name, str)
+                        assert isinstance(always_trt, dict)
+                        assert np.array(policy_val).shape == ()
+                        assert len(always_trt) == length
+                        for val in always_trt.values():
+                            assert np.array(val).shape == ()
+
+                        # policy value should exceed always treating with any treatment
+                        assert_less_close(np.array(list(always_trt.values())), policy_val)
+
+                        ind_pol = ca.individualized_policy(X, inds[idx])
+
+                    if not classification:
+                        # ExitStack can be used as a "do nothing" ContextManager
+                        cm = ExitStack()
+                    else:
+                        cm = self.assertRaises(Exception)
+                    with cm:
+                        inf = ca.whatif(X[:2], np.ones(shape=(2,)), inds[1], y[:2])
+                        assert np.shape(inf.point_estimate) == np.shape(y[:2])
+                        inf = ca.whatif(X[:2], np.ones(shape=(2,)), inds[2], y[:2])
+                        assert np.shape(inf.point_estimate) == np.shape(y[:2])
 
-                    ca._whatif_dict(X[:2], np.ones(shape=(2,)), inds[1], y[:2])
-                    ca._whatif_dict(X[:2], np.ones(shape=(2,)), inds[1], y[:2], row_wise=True)
+                        ca._whatif_dict(X[:2], np.ones(shape=(2,)), inds[1], y[:2])
+                        ca._whatif_dict(X[:2], np.ones(shape=(2,)), inds[1], y[:2], row_wise=True)
 
-            badargs = [
-                (n_inds, n_cats, [4]),  # hinds out of range
-                (n_inds, n_cats, ["test"])  # hinds out of range
-            ]
+                badargs = [
+                    (n_inds, n_cats, [4]),  # hinds out of range
+                    (n_inds, n_cats, ["test"])  # hinds out of range
+                ]
 
-            for args in badargs:
-                with self.assertRaises(Exception):
-                    ca = CausalAnalysis(*args)
-                    ca.fit(X, y)
+                for args in badargs:
+                    with self.assertRaises(Exception):
+                        ca = CausalAnalysis(*args)
+                        ca.fit(X, y)
 
     def test_automl_first_stage(self):
         d_y = (1,)
@@ -294,6 +304,8 @@ def test_automl_first_stage(self):
                 # policy value should exceed always treating with any treatment
                 assert_less_close(np.array(list(always_trt.values())), policy_val)
 
+                ind_pol = ca.individualized_policy(X, inds[idx])
+
             # global shape is (d_y, sum(d_t))
             assert glo_point_est.shape == coh_point_est.shape == (1, 5)
             assert loc_point_est.shape == (2,) + glo_point_est.shape
@@ -436,6 +448,8 @@ def test_final_models(self):
                     # policy value should exceed always treating with any treatment
                     assert_less_close(np.array(list(always_trt.values())), policy_val)
 
+                    ind_pol = ca.individualized_policy(X, inds[idx])
+
                 if not classification:
                     # ExitStack can be used as a "do nothing" ContextManager
                     cm = ExitStack()
@@ -526,6 +540,8 @@ def test_forest_with_pandas(self):
                 # policy value should exceed always treating with any treatment
                 assert_less_close(np.array(list(always_trt.values())), policy_val)
 
+                ind_pol = ca.individualized_policy(X, inds[idx])
+
     def test_warm_start(self):
         for classification in [True, False]:
             # dgp
```
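Distilled from the updated tests, a minimal end-to-end sketch of the scenario this commit fixes; it is not a verbatim excerpt and assumes an econml build that includes this change:

```python
import numpy as np
import pandas as pd
from econml.solutions.causal_analysis import CausalAnalysis

# Treatment columns with an explicitly categorical dtype, as in test_basic_pandas.
y = pd.Series(np.random.choice([0, 1], size=(500,)))
X = pd.DataFrame({'a': np.random.normal(size=500),
                  'b': np.random.normal(size=500),
                  'c': np.random.choice([0, 1], size=500),
                  'd': np.random.choice(['a', 'b', 'c'], size=500)})
X['c'] = X['c'].astype('category')
X['d'] = X['d'].astype('category')

# feature columns, categorical columns, heterogeneity columns
ca = CausalAnalysis(['a', 'b', 'c', 'd'], ['c', 'd'], ['a', 'd'], classification=False)
ca.fit(X, y)

# Previously raised ValueError("Lengths must match.") for categorical columns.
ind_pol = ca.individualized_policy(X, 'd')
```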
