Skip to content

Commit ad71707

Browse files
authored
MAINT validate parameters for public functions (#956)
1 parent f8c27ae commit ad71707

File tree

10 files changed

+277
-109
lines changed

10 files changed

+277
-109
lines changed

imblearn/datasets/_imbalance.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,22 @@
66
# License: MIT
77

88
from collections import Counter
9+
from collections.abc import Mapping
910

1011
from ..under_sampling import RandomUnderSampler
1112
from ..utils import check_sampling_strategy
12-
13-
13+
from ..utils._param_validation import validate_params
14+
15+
16+
@validate_params(
17+
{
18+
"X": ["array-like"],
19+
"y": ["array-like"],
20+
"sampling_strategy": [Mapping, callable, None],
21+
"random_state": ["random_state"],
22+
"verbose": ["boolean"],
23+
}
24+
)
1425
def make_imbalance(
1526
X, y, *, sampling_strategy=None, random_state=None, verbose=False, **kwargs
1627
):
@@ -26,7 +37,7 @@ def make_imbalance(
2637
X : {array-like, dataframe} of shape (n_samples, n_features)
2738
Matrix containing the data to be imbalanced.
2839
29-
y : ndarray of shape (n_samples,)
40+
y : array-like of shape (n_samples,)
3041
Corresponding label for each sample in X.
3142
3243
sampling_strategy : dict or callable,
@@ -86,16 +97,10 @@ def make_imbalance(
8697
"""
8798
target_stats = Counter(y)
8899
# restrict ratio to be a dict or a callable
89-
if isinstance(sampling_strategy, dict) or callable(sampling_strategy):
100+
if isinstance(sampling_strategy, Mapping) or callable(sampling_strategy):
90101
sampling_strategy_ = check_sampling_strategy(
91102
sampling_strategy, y, "under-sampling", **kwargs
92103
)
93-
else:
94-
raise ValueError(
95-
f"'sampling_strategy' has to be a dictionary or a "
96-
f"function returning a dictionary. Got {type(sampling_strategy)} "
97-
f"instead."
98-
)
99104

100105
if verbose:
101106
print(f"The original target distribution in the dataset is: {target_stats}")

imblearn/datasets/_zenodo.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@
5454
from sklearn.datasets import get_data_home
5555
from sklearn.utils import Bunch, check_random_state
5656

57+
from ..utils._param_validation import validate_params
58+
5759
URL = "https://zenodo.org/record/61452/files/benchmark-imbalanced-learn.tar.gz"
5860
PRE_FILENAME = "x"
5961
POST_FILENAME = "data.npz"
@@ -95,6 +97,16 @@
9597
MAP_ID_NAME[v + 1] = k
9698

9799

100+
@validate_params(
101+
{
102+
"data_home": [None, str],
103+
"filter_data": [None, tuple],
104+
"download_if_missing": ["boolean"],
105+
"random_state": ["random_state"],
106+
"shuffle": ["boolean"],
107+
"verbose": ["boolean"],
108+
}
109+
)
98110
def fetch_datasets(
99111
*,
100112
data_home=None,

imblearn/datasets/tests/test_imbalance.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ def iris():
2222
[
2323
({0: -100, 1: 50, 2: 50}, "in a class cannot be negative"),
2424
({0: 10, 1: 70}, "should be less or equal to the original"),
25-
("random-string", "has to be a dictionary or a function"),
2625
],
2726
)
2827
def test_make_imbalance_error(iris, sampling_strategy, err_msg):

imblearn/metrics/_classification.py

Lines changed: 115 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
# License: MIT
1616

1717
import functools
18+
import numbers
1819
import warnings
1920
from inspect import signature
2021

@@ -26,7 +27,23 @@
2627
from sklearn.utils.multiclass import unique_labels
2728
from sklearn.utils.validation import check_consistent_length, column_or_1d
2829

30+
from ..utils._param_validation import Interval, StrOptions, validate_params
2931

32+
33+
@validate_params(
34+
{
35+
"y_true": ["array-like"],
36+
"y_pred": ["array-like"],
37+
"labels": ["array-like", None],
38+
"pos_label": [str, numbers.Integral, None],
39+
"average": [
40+
None,
41+
StrOptions({"binary", "micro", "macro", "weighted", "samples"}),
42+
],
43+
"warn_for": ["array-like"],
44+
"sample_weight": ["array-like", None],
45+
}
46+
)
3047
def sensitivity_specificity_support(
3148
y_true,
3249
y_pred,
@@ -57,13 +74,13 @@ def sensitivity_specificity_support(
5774
5875
Parameters
5976
----------
60-
y_true : ndarray of shape (n_samples,)
77+
y_true : array-like of shape (n_samples,)
6178
Ground truth (correct) target values.
6279
63-
y_pred : ndarray of shape (n_samples,)
80+
y_pred : array-like of shape (n_samples,)
6481
Estimated targets as returned by a classifier.
6582
66-
labels : list, default=None
83+
labels : array-like, default=None
6784
The set of labels to include when ``average != 'binary'``, and their
6885
order if ``average is None``. Labels present in the data can be
6986
excluded, for example to calculate a multiclass average ignoring a
@@ -72,8 +89,11 @@ def sensitivity_specificity_support(
7289
labels are column indices. By default, all labels in ``y_true`` and
7390
``y_pred`` are used in sorted order.
7491
75-
pos_label : str or int, default=1
92+
pos_label : str, int or None, default=1
7693
The class to report if ``average='binary'`` and the data is binary.
94+
If ``pos_label is None`` and in binary classification, this function
95+
returns the average sensitivity and specificity if ``average``
96+
is one of ``'weighted'``.
7797
If the data are multiclass, this will be ignored;
7898
setting ``labels=[pos_label]`` and ``average != 'binary'`` will report
7999
scores for that label only.
@@ -105,7 +125,7 @@ def sensitivity_specificity_support(
105125
This determines which warnings will be made in the case that this
106126
function is being used to return only one of its metrics.
107127
108-
sample_weight : ndarray of shape (n_samples,), default=None
128+
sample_weight : array-like of shape (n_samples,), default=None
109129
Sample weights.
110130
111131
Returns
@@ -274,6 +294,19 @@ def sensitivity_specificity_support(
274294
return sensitivity, specificity, true_sum
275295

276296

297+
@validate_params(
298+
{
299+
"y_true": ["array-like"],
300+
"y_pred": ["array-like"],
301+
"labels": ["array-like", None],
302+
"pos_label": [str, numbers.Integral, None],
303+
"average": [
304+
None,
305+
StrOptions({"binary", "micro", "macro", "weighted", "samples"}),
306+
],
307+
"sample_weight": ["array-like", None],
308+
}
309+
)
277310
def sensitivity_score(
278311
y_true,
279312
y_pred,
@@ -295,21 +328,23 @@ def sensitivity_score(
295328
296329
Parameters
297330
----------
298-
y_true : ndarray of shape (n_samples,)
331+
y_true : array-like of shape (n_samples,)
299332
Ground truth (correct) target values.
300333
301-
y_pred : ndarray of shape (n_samples,)
334+
y_pred : array-like of shape (n_samples,)
302335
Estimated targets as returned by a classifier.
303336
304-
labels : list, default=None
337+
labels : array-like, default=None
305338
The set of labels to include when ``average != 'binary'``, and their
306339
order if ``average is None``. Labels present in the data can be
307340
excluded, for example to calculate a multiclass average ignoring a
308341
majority negative class, while labels not present in the data will
309342
result in 0 components in a macro average.
310343
311-
pos_label : str or int, default=1
344+
pos_label : str, int or None, default=1
312345
The class to report if ``average='binary'`` and the data is binary.
346+
If ``pos_label is None`` and in binary classification, this function
347+
returns the average sensitivity if ``average`` is one of ``'weighted'``.
313348
If the data are multiclass, this will be ignored;
314349
setting ``labels=[pos_label]`` and ``average != 'binary'`` will report
315350
scores for that label only.
@@ -337,7 +372,7 @@ def sensitivity_score(
337372
meaningful for multilabel classification where this differs from
338373
:func:`accuracy_score`).
339374
340-
sample_weight : ndarray of shape (n_samples,), default=None
375+
sample_weight : array-like of shape (n_samples,), default=None
341376
Sample weights.
342377
343378
Returns
@@ -374,6 +409,19 @@ def sensitivity_score(
374409
return s
375410

376411

412+
@validate_params(
413+
{
414+
"y_true": ["array-like"],
415+
"y_pred": ["array-like"],
416+
"labels": ["array-like", None],
417+
"pos_label": [str, numbers.Integral, None],
418+
"average": [
419+
None,
420+
StrOptions({"binary", "micro", "macro", "weighted", "samples"}),
421+
],
422+
"sample_weight": ["array-like", None],
423+
}
424+
)
377425
def specificity_score(
378426
y_true,
379427
y_pred,
@@ -395,21 +443,23 @@ def specificity_score(
395443
396444
Parameters
397445
----------
398-
y_true : ndarray of shape (n_samples,)
446+
y_true : array-like of shape (n_samples,)
399447
Ground truth (correct) target values.
400448
401-
y_pred : ndarray of shape (n_samples,)
449+
y_pred : array-like of shape (n_samples,)
402450
Estimated targets as returned by a classifier.
403451
404-
labels : list, default=None
452+
labels : array-like, default=None
405453
The set of labels to include when ``average != 'binary'``, and their
406454
order if ``average is None``. Labels present in the data can be
407455
excluded, for example to calculate a multiclass average ignoring a
408456
majority negative class, while labels not present in the data will
409457
result in 0 components in a macro average.
410458
411-
pos_label : str or int, default=1
459+
pos_label : str, int or None, default=1
412460
The class to report if ``average='binary'`` and the data is binary.
461+
If ``pos_label is None`` and in binary classification, this function
462+
returns the average specificity if ``average`` is one of ``'weighted'``.
413463
If the data are multiclass, this will be ignored;
414464
setting ``labels=[pos_label]`` and ``average != 'binary'`` will report
415465
scores for that label only.
@@ -437,7 +487,7 @@ def specificity_score(
437487
meaningful for multilabel classification where this differs from
438488
:func:`accuracy_score`).
439489
440-
sample_weight : ndarray of shape (n_samples,), default=None
490+
sample_weight : array-like of shape (n_samples,), default=None
441491
Sample weights.
442492
443493
Returns
@@ -474,6 +524,22 @@ def specificity_score(
474524
return s
475525

476526

527+
@validate_params(
528+
{
529+
"y_true": ["array-like"],
530+
"y_pred": ["array-like"],
531+
"labels": ["array-like", None],
532+
"pos_label": [str, numbers.Integral, None],
533+
"average": [
534+
None,
535+
StrOptions(
536+
{"binary", "micro", "macro", "weighted", "samples", "multiclass"}
537+
),
538+
],
539+
"sample_weight": ["array-like", None],
540+
"correction": [Interval(numbers.Real, 0, None, closed="left")],
541+
}
542+
)
477543
def geometric_mean_score(
478544
y_true,
479545
y_pred,
@@ -507,21 +573,24 @@ class is unrecognized by the classifier, G-mean resolves to zero. To
507573
508574
Parameters
509575
----------
510-
y_true : ndarray of shape (n_samples,)
576+
y_true : array-like of shape (n_samples,)
511577
Ground truth (correct) target values.
512578
513-
y_pred : ndarray of shape (n_samples,)
579+
y_pred : array-like of shape (n_samples,)
514580
Estimated targets as returned by a classifier.
515581
516-
labels : list, default=None
582+
labels : array-like, default=None
517583
The set of labels to include when ``average != 'binary'``, and their
518584
order if ``average is None``. Labels present in the data can be
519585
excluded, for example to calculate a multiclass average ignoring a
520586
majority negative class, while labels not present in the data will
521587
result in 0 components in a macro average.
522588
523-
pos_label : str or int, default=1
589+
pos_label : str, int or None, default=1
524590
The class to report if ``average='binary'`` and the data is binary.
591+
If ``pos_label is None`` and in binary classification, this function
592+
returns the average geometric mean if ``average`` is one of
593+
``'weighted'``.
525594
If the data are multiclass, this will be ignored;
526595
setting ``labels=[pos_label]`` and ``average != 'binary'`` will report
527596
scores for that label only.
@@ -539,6 +608,8 @@ class is unrecognized by the classifier, G-mean resolves to zero. To
539608
``'macro'``:
540609
Calculate metrics for each label, and find their unweighted
541610
mean. This does not take label imbalance into account.
611+
``'multiclass'``:
612+
No average is taken.
542613
``'weighted'``:
543614
Calculate metrics for each label, and find their average, weighted
544615
by support (the number of true instances for each label). This
@@ -549,7 +620,7 @@ class is unrecognized by the classifier, G-mean resolves to zero. To
549620
meaningful for multilabel classification where this differs from
550621
:func:`accuracy_score`).
551622
552-
sample_weight : ndarray of shape (n_samples,), default=None
623+
sample_weight : array-like of shape (n_samples,), default=None
553624
Sample weights.
554625
555626
correction : float, default=0.0
@@ -658,6 +729,7 @@ class is unrecognized by the classifier, G-mean resolves to zero. To
658729
return gmean
659730

660731

732+
@validate_params({"alpha": [numbers.Real], "squared": ["boolean"]})
661733
def make_index_balanced_accuracy(*, alpha=0.1, squared=True):
662734
"""Balance any scoring function using the index balanced accuracy.
663735
@@ -763,6 +835,22 @@ def compute_score(*args, **kwargs):
763835
return decorate
764836

765837

838+
@validate_params(
839+
{
840+
"y_true": ["array-like"],
841+
"y_pred": ["array-like"],
842+
"labels": ["array-like", None],
843+
"target_names": ["array-like", None],
844+
"sample_weight": ["array-like", None],
845+
"digits": [Interval(numbers.Integral, 0, None, closed="left")],
846+
"alpha": [numbers.Real],
847+
"output_dict": ["boolean"],
848+
"zero_division": [
849+
StrOptions({"warn"}),
850+
Interval(numbers.Integral, 0, 1, closed="both"),
851+
],
852+
}
853+
)
766854
def classification_report_imbalanced(
767855
y_true,
768856
y_pred,
@@ -970,6 +1058,13 @@ class 2 1.00 0.67 1.00 0.80 0.82 0.64\
9701058
return report
9711059

9721060

1061+
@validate_params(
1062+
{
1063+
"y_true": ["array-like"],
1064+
"y_pred": ["array-like"],
1065+
"sample_weight": ["array-like", None],
1066+
}
1067+
)
9731068
def macro_averaged_mean_absolute_error(y_true, y_pred, *, sample_weight=None):
9741069
"""Compute Macro-Averaged MAE for imbalanced ordinal classification.
9751070

0 commit comments

Comments (0)