Skip to content

Commit 704d106

Browse files
authored
ENH support array-like of str for categorical_features in SMOTENC (#1008)
1 parent b468f7f commit 704d106

File tree

4 files changed

+52
-21
lines changed

4 files changed

+52
-21
lines changed

doc/over_sampling.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -192,8 +192,8 @@ which categorical data are treated differently::
192192

193193
In this data set, the first and last features are considered as categorical
194194
features. One needs to provide this information to :class:`SMOTENC` via the
195-
parameters ``categorical_features`` either by passing the indices of these
196-
features or a boolean mask marking these features::
195+
parameter ``categorical_features`` either by passing the indices, the feature
196+
names when `X` is a pandas DataFrame, or a boolean mask marking these features::
197197

198198
>>> from imblearn.over_sampling import SMOTENC
199199
>>> smote_nc = SMOTENC(categorical_features=[0, 2], random_state=0)

doc/whats_new/v0.11.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,3 +57,7 @@ Enhancements
5757
:class:`~imblearn.over_sampling.RandomOverSampler` (when `shrinkage is not
5858
None`) now accept any data types and will not attempt any data conversion.
5959
:pr:`1004` by :user:`Guillaume Lemaitre <glemaitre>`.
60+
61+
- :class:`~imblearn.over_sampling.SMOTENC` now supports an array-like of `str`
62+
for the `categorical_features` parameter.
63+
:pr:`1008` by :user:`Guillaume Lemaitre <glemaitre>`.

imblearn/over_sampling/_smote/base.py

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,12 @@
1616
from sklearn.base import clone
1717
from sklearn.exceptions import DataConversionWarning
1818
from sklearn.preprocessing import OneHotEncoder, OrdinalEncoder
19-
from sklearn.utils import _safe_indexing, check_array, check_random_state
19+
from sklearn.utils import (
20+
_get_column_indices,
21+
_safe_indexing,
22+
check_array,
23+
check_random_state,
24+
)
2025
from sklearn.utils.sparsefuncs_fast import (
2126
csc_mean_variance_axis0,
2227
csr_mean_variance_axis0,
@@ -390,10 +395,14 @@ class SMOTENC(SMOTE):
390395
391396
Parameters
392397
----------
393-
categorical_features : array-like of shape (n_cat_features,) or (n_features,)
398+
categorical_features : array-like of shape (n_cat_features,) or (n_features,), \
399+
dtype={{bool, int, str}}
394400
Specifies which features are categorical. Can either be:
395401
396-
- array of indices specifying the categorical features;
402+
- array of `int` corresponding to the indices specifying the categorical
403+
features;
404+
- array of `str` corresponding to the feature names. `X` should be a
405+
:class:`pandas.DataFrame` in this case.
397406
- mask array of shape (n_features, ) and ``bool`` dtype for which
398407
``True`` indicates the categorical features.
399408
@@ -565,24 +574,16 @@ def _check_X_y(self, X, y):
565574
self._check_feature_names(X, reset=True)
566575
return X, y, binarize_y
567576

568-
def _validate_estimator(self):
569-
super()._validate_estimator()
570-
categorical_features = np.asarray(self.categorical_features)
571-
if categorical_features.dtype.name == "bool":
572-
self.categorical_features_ = np.flatnonzero(categorical_features)
573-
else:
574-
if any(
575-
[cat not in np.arange(self.n_features_) for cat in categorical_features]
576-
):
577-
raise ValueError(
578-
f"Some of the categorical indices are out of range. Indices"
579-
f" should be between 0 and {self.n_features_ - 1}"
580-
)
581-
self.categorical_features_ = categorical_features
577+
def _validate_column_types(self, X):
578+
self.categorical_features_ = np.array(
579+
_get_column_indices(X, self.categorical_features)
580+
)
582581
self.continuous_features_ = np.setdiff1d(
583582
np.arange(self.n_features_), self.categorical_features_
584583
)
585584

585+
def _validate_estimator(self):
586+
super()._validate_estimator()
586587
if self.categorical_features_.size == self.n_features_in_:
587588
raise ValueError(
588589
"SMOTE-NC is not designed to work only with categorical "
@@ -600,6 +601,7 @@ def _fit_resample(self, X, y):
600601
)
601602

602603
self.n_features_ = _num_features(X)
604+
self._validate_column_types(X)
603605
self._validate_estimator()
604606

605607
# compute the median of the standard deviation of the minority class

imblearn/over_sampling/_smote/tests/test_smote_nc.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def data_heterogneous_masked():
6363
X[:, 3] = rng.randint(3, size=30)
6464
y = np.array([0] * 10 + [1] * 20)
6565
# return the categories
66-
return X, y, [True, False, True]
66+
return X, y, [True, False, False, True]
6767

6868

6969
def data_heterogneous_unordered_multiclass():
@@ -98,7 +98,7 @@ def test_smotenc_error():
9898
X, y, _ = data_heterogneous_unordered()
9999
categorical_features = [0, 10]
100100
smote = SMOTENC(random_state=0, categorical_features=categorical_features)
101-
with pytest.raises(ValueError, match="indices are out of range"):
101+
with pytest.raises(ValueError, match="all features must be in"):
102102
smote.fit_resample(X, y)
103103

104104

@@ -324,3 +324,28 @@ def test_smotenc_bool_categorical():
324324
X_res, y_res = smote.fit_resample(X, y)
325325
pd.testing.assert_series_equal(X_res.dtypes, X.dtypes)
326326
assert len(X_res) == len(y_res)
327+
328+
329+
def test_smotenc_categorical_features_str():
330+
"""Check that we support array-like of strings for `categorical_features` using
331+
pandas dataframe.
332+
"""
333+
pd = pytest.importorskip("pandas")
334+
335+
X = pd.DataFrame(
336+
{
337+
"A": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
338+
"B": ["a", "b"] * 5,
339+
"C": ["a", "b", "c"] * 3 + ["a"],
340+
}
341+
)
342+
X = pd.concat([X] * 10, ignore_index=True)
343+
y = np.array([0] * 70 + [1] * 30)
344+
smote = SMOTENC(categorical_features=["B", "C"], random_state=0)
345+
X_res, y_res = smote.fit_resample(X, y)
346+
assert X_res["B"].isin(["a", "b"]).all()
347+
assert X_res["C"].isin(["a", "b", "c"]).all()
348+
counter = Counter(y_res)
349+
assert counter[0] == counter[1] == 70
350+
assert_array_equal(smote.categorical_features_, [1, 2])
351+
assert_array_equal(smote.continuous_features_, [0])

0 commit comments

Comments
 (0)