Skip to content

Commit 4714869

Browse files
authored
MAINT fix remaining failures with scikit-learn 1.2 (#947)
1 parent e3f5075 commit 4714869

File tree

3 files changed

+16
-2
lines changed

3 files changed

+16
-2
lines changed

imblearn/metrics/pairwise.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,3 +205,8 @@ def pairwise(self, X, Y=None):
205205
distance_matrix(proba_feature_X, proba_feature_Y, p=self.k) ** self.r
206206
)
207207
return distance
208+
209+
def _more_tags(self):
210+
return {
211+
"requires_positive_X": True, # X should be encoded with OrdinalEncoder
212+
}

imblearn/tests/test_docstring_parameters.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,20 @@
1717
from sklearn.utils._testing import _get_func_name
1818
from sklearn.utils._testing import ignore_warnings
1919
from sklearn.utils.estimator_checks import _enforce_estimator_tags_y
20-
from sklearn.utils.estimator_checks import _enforce_estimator_tags_x
20+
21+
try:
22+
from sklearn.utils.estimator_checks import _enforce_estimator_tags_x
23+
except ImportError:
24+
# scikit-learn >= 1.2
25+
from sklearn.utils.estimator_checks import (
26+
_enforce_estimator_tags_X as _enforce_estimator_tags_x,
27+
)
2128
from sklearn.utils.estimator_checks import _construct_instance
2229
from sklearn.utils.deprecation import _is_deprecated
2330

2431
import imblearn
2532
from imblearn.base import is_sampler
33+
from imblearn.utils.estimator_checks import _set_checking_parameters
2634
from imblearn.utils.testing import all_estimators
2735

2836

@@ -183,6 +191,7 @@ def test_fit_docstring_attributes(name, Estimator):
183191
est = _construct_compose_pipeline_instance(Estimator)
184192
else:
185193
est = _construct_instance(Estimator)
194+
_set_checking_parameters(est)
186195

187196
X, y = make_classification(
188197
n_samples=20,

imblearn/utils/estimator_checks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def _set_checking_parameters(estimator):
4545
if name == "ClusterCentroids":
4646
estimator.set_params(
4747
voting="soft",
48-
estimator=KMeans(random_state=0, algorithm="full", n_init=1),
48+
estimator=KMeans(random_state=0, algorithm="lloyd", n_init=1),
4949
)
5050
if name == "KMeansSMOTE":
5151
estimator.set_params(kmeans_estimator=12)

0 commit comments

Comments
 (0)