Skip to content

Commit 436ee08

Browse files
hayesallglemaitre
andauthored
MAINT Replace stats.mode calls with fixes._mode (#938)
Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
1 parent 2779327 commit 436ee08

File tree

5 files changed

+29
-7
lines changed

5 files changed

+29
-7
lines changed

build_tools/azure/install.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ elif [[ "$DISTRIB" == "conda-pip-latest" ]]; then
6868
python -m pip install -U pip
6969

7070
python -m pip install pandas matplotlib
71-
python -m pip install --pre scikit-learn
71+
python -m pip install scikit-learn
7272

7373
elif [[ "$DISTRIB" == "conda-pip-latest-tensorflow" ]]; then
7474
make_conda "python=$PYTHON_VERSION"

imblearn/over_sampling/_smote/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212

1313
import numpy as np
1414
from scipy import sparse
15-
from scipy import stats
1615

1716
from sklearn.preprocessing import OneHotEncoder, OrdinalEncoder
1817
from sklearn.utils import check_random_state
@@ -29,6 +28,7 @@
2928
from ...utils._docstring import _n_jobs_docstring
3029
from ...utils._docstring import _random_state_docstring
3130
from ...utils._validation import _deprecate_positional_args
31+
from ...utils.fixes import _mode
3232

3333

3434
class BaseSMOTE(BaseOverSampler):
@@ -786,7 +786,7 @@ def _make_samples(self, X_class, klass, y_dtype, nn_indices, n_samples):
786786
# where for each feature individually, each category generated is the
787787
# most common category
788788
X_new = np.squeeze(
789-
stats.mode(X_class[nn_indices[samples_indices]], axis=1).mode, axis=1
789+
_mode(X_class[nn_indices[samples_indices]], axis=1).mode, axis=1
790790
)
791791
y_new = np.full(n_samples, fill_value=klass, dtype=y_dtype)
792792
return X_new, y_new

imblearn/under_sampling/_prototype_selection/_edited_nearest_neighbours.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from collections import Counter
1010

1111
import numpy as np
12-
from scipy.stats import mode
1312

1413
from sklearn.utils import _safe_indexing
1514

@@ -18,6 +17,8 @@
1817
from ...utils import Substitution
1918
from ...utils._docstring import _n_jobs_docstring
2019
from ...utils._validation import _deprecate_positional_args
20+
from ...utils.fixes import _mode
21+
2122

2223
SEL_KIND = ("all", "mode")
2324

@@ -155,7 +156,7 @@ def _fit_resample(self, X, y):
155156
nnhood_idx = self.nn_.kneighbors(X_class, return_distance=False)[:, 1:]
156157
nnhood_label = y[nnhood_idx]
157158
if self.kind_sel == "mode":
158-
nnhood_label, _ = mode(nnhood_label, axis=1)
159+
nnhood_label, _ = _mode(nnhood_label, axis=1)
159160
nnhood_bool = np.ravel(nnhood_label) == y_class
160161
elif self.kind_sel == "all":
161162
nnhood_label = nnhood_label == target_class

imblearn/under_sampling/_prototype_selection/_neighbourhood_cleaning_rule.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from collections import Counter
88

99
import numpy as np
10-
from scipy.stats import mode
1110

1211
from sklearn.utils import _safe_indexing
1312

@@ -17,6 +16,8 @@
1716
from ...utils import Substitution
1817
from ...utils._docstring import _n_jobs_docstring
1918
from ...utils._validation import _deprecate_positional_args
19+
from ...utils.fixes import _mode
20+
2021

2122
SEL_KIND = ("all", "mode")
2223

@@ -182,7 +183,7 @@ def _fit_resample(self, X, y):
182183
nnhood_idx = self.nn_.kneighbors(X_class, return_distance=False)[:, 1:]
183184
nnhood_label = y[nnhood_idx]
184185
if self.kind_sel == "mode":
185-
nnhood_label_majority, _ = mode(nnhood_label, axis=1)
186+
nnhood_label_majority, _ = _mode(nnhood_label, axis=1)
186187
nnhood_bool = np.ravel(nnhood_label_majority) == y_class
187188
elif self.kind_sel == "all":
188189
nnhood_label_majority = nnhood_label == class_minority

imblearn/utils/fixes.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
"""Compatibility fixes for older version of python, numpy, scipy, and
2+
scikit-learn.
3+
4+
If you add content to this file, please give the version of the package at
5+
which the fix is no longer needed.
6+
"""
7+
8+
import scipy
9+
import scipy.stats
10+
11+
from sklearn.utils.fixes import parse_version
12+
13+
sp_version = parse_version(scipy.__version__)
14+
15+
16+
# TODO: Remove when SciPy 1.9 is the minimum supported version
17+
def _mode(a, axis=0):
18+
if sp_version >= parse_version("1.9.0"):
19+
return scipy.stats.mode(a, axis=axis, keepdims=True)
20+
return scipy.stats.mode(a, axis=axis)

0 commit comments

Comments
 (0)