# License: MIT

import functools
+ import numbers
import warnings
from inspect import signature

from sklearn.utils.multiclass import unique_labels
from sklearn.utils.validation import check_consistent_length, column_or_1d

+ from ..utils._param_validation import Interval, StrOptions, validate_params

+
+ @validate_params(
+     {
+         "y_true": ["array-like"],
+         "y_pred": ["array-like"],
+         "labels": ["array-like", None],
+         "pos_label": [str, numbers.Integral, None],
+         "average": [
+             None,
+             StrOptions({"binary", "micro", "macro", "weighted", "samples"}),
+         ],
+         "warn_for": ["array-like"],
+         "sample_weight": ["array-like", None],
+     }
+ )
def sensitivity_specificity_support(
    y_true,
    y_pred,
@@ -57,13 +74,13 @@ def sensitivity_specificity_support(

    Parameters
    ----------
-     y_true : ndarray of shape (n_samples,)
+     y_true : array-like of shape (n_samples,)
        Ground truth (correct) target values.

-     y_pred : ndarray of shape (n_samples,)
+     y_pred : array-like of shape (n_samples,)
        Estimated targets as returned by a classifier.

-     labels : list, default=None
+     labels : array-like, default=None
        The set of labels to include when ``average != 'binary'``, and their
        order if ``average is None``. Labels present in the data can be
        excluded, for example to calculate a multiclass average ignoring a
@@ -72,8 +89,11 @@ def sensitivity_specificity_support(
        labels are column indices. By default, all labels in ``y_true`` and
        ``y_pred`` are used in sorted order.

-     pos_label : str or int, default=1
+     pos_label : str, int or None, default=1
        The class to report if ``average='binary'`` and the data is binary.
+         If ``pos_label is None`` and the data is binary, this function
+         returns the average sensitivity and specificity if ``average``
+         is ``'weighted'``.
        If the data are multiclass, this will be ignored;
        setting ``labels=[pos_label]`` and ``average != 'binary'`` will report
        scores for that label only.
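
As a hedged illustration (not part of this commit), the decorator added above makes invalid arguments fail fast at call time. The sketch below assumes imbalanced-learn is installed and that the metric is importable from ``imblearn.metrics``, as in current releases; the data values are made up.

from imblearn.metrics import sensitivity_specificity_support

y_true = [0, 1, 1, 0, 1, 1]
y_pred = [0, 1, 0, 0, 1, 1]

# Per-class sensitivity, specificity and support.
sens, spec, support = sensitivity_specificity_support(y_true, y_pred, average=None)

# With @validate_params in place, an unsupported ``average`` value is rejected
# with an InvalidParameterError (a ValueError subclass) before any computation.
try:
    sensitivity_specificity_support(y_true, y_pred, average="not-an-option")
except ValueError as exc:
    print(type(exc).__name__, exc)
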
@@ -105,7 +125,7 @@ def sensitivity_specificity_support(
        This determines which warnings will be made in the case that this
        function is being used to return only one of its metrics.

-     sample_weight : ndarray of shape (n_samples,), default=None
+     sample_weight : array-like of shape (n_samples,), default=None
        Sample weights.

    Returns
@@ -274,6 +294,19 @@ def sensitivity_specificity_support(
    return sensitivity, specificity, true_sum


+ @validate_params(
+     {
+         "y_true": ["array-like"],
+         "y_pred": ["array-like"],
+         "labels": ["array-like", None],
+         "pos_label": [str, numbers.Integral, None],
+         "average": [
+             None,
+             StrOptions({"binary", "micro", "macro", "weighted", "samples"}),
+         ],
+         "sample_weight": ["array-like", None],
+     }
+ )
def sensitivity_score(
    y_true,
    y_pred,
@@ -295,21 +328,23 @@ def sensitivity_score(

    Parameters
    ----------
-     y_true : ndarray of shape (n_samples,)
+     y_true : array-like of shape (n_samples,)
        Ground truth (correct) target values.

-     y_pred : ndarray of shape (n_samples,)
+     y_pred : array-like of shape (n_samples,)
        Estimated targets as returned by a classifier.

-     labels : list, default=None
+     labels : array-like, default=None
        The set of labels to include when ``average != 'binary'``, and their
        order if ``average is None``. Labels present in the data can be
        excluded, for example to calculate a multiclass average ignoring a
        majority negative class, while labels not present in the data will
        result in 0 components in a macro average.

-     pos_label : str or int, default=1
+     pos_label : str, int or None, default=1
        The class to report if ``average='binary'`` and the data is binary.
+         If ``pos_label is None`` and the data is binary, this function
+         returns the average sensitivity if ``average`` is ``'weighted'``.
        If the data are multiclass, this will be ignored;
        setting ``labels=[pos_label]`` and ``average != 'binary'`` will report
        scores for that label only.
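
A minimal sketch of the ``pos_label=None`` behaviour documented above, assuming the ``imblearn.metrics`` import path; the toy labels are made up.

from imblearn.metrics import sensitivity_score

y_true = [0, 0, 1, 1, 1]
y_pred = [0, 1, 1, 1, 0]

# Default binary behaviour: sensitivity (recall) of the positive class (1).
print(sensitivity_score(y_true, y_pred))

# pos_label=None with average='weighted': both classes contribute,
# weighted by their support.
print(sensitivity_score(y_true, y_pred, pos_label=None, average="weighted"))
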
@@ -337,7 +372,7 @@ def sensitivity_score(
        meaningful for multilabel classification where this differs from
        :func:`accuracy_score`).

-     sample_weight : ndarray of shape (n_samples,), default=None
+     sample_weight : array-like of shape (n_samples,), default=None
        Sample weights.

    Returns
@@ -374,6 +409,19 @@ def sensitivity_score(
    return s


+ @validate_params(
+     {
+         "y_true": ["array-like"],
+         "y_pred": ["array-like"],
+         "labels": ["array-like", None],
+         "pos_label": [str, numbers.Integral, None],
+         "average": [
+             None,
+             StrOptions({"binary", "micro", "macro", "weighted", "samples"}),
+         ],
+         "sample_weight": ["array-like", None],
+     }
+ )
def specificity_score(
    y_true,
    y_pred,
@@ -395,21 +443,23 @@ def specificity_score(

    Parameters
    ----------
-     y_true : ndarray of shape (n_samples,)
+     y_true : array-like of shape (n_samples,)
        Ground truth (correct) target values.

-     y_pred : ndarray of shape (n_samples,)
+     y_pred : array-like of shape (n_samples,)
        Estimated targets as returned by a classifier.

-     labels : list, default=None
+     labels : array-like, default=None
        The set of labels to include when ``average != 'binary'``, and their
        order if ``average is None``. Labels present in the data can be
        excluded, for example to calculate a multiclass average ignoring a
        majority negative class, while labels not present in the data will
        result in 0 components in a macro average.

-     pos_label : str or int, default=1
+     pos_label : str, int or None, default=1
        The class to report if ``average='binary'`` and the data is binary.
+         If ``pos_label is None`` and the data is binary, this function
+         returns the average specificity if ``average`` is ``'weighted'``.
        If the data are multiclass, this will be ignored;
        setting ``labels=[pos_label]`` and ``average != 'binary'`` will report
        scores for that label only.
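
A short sketch for the multiclass case, again assuming the ``imblearn.metrics`` import path and made-up labels.

from imblearn.metrics import specificity_score

y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 1, 1, 2, 2, 2]

# Per-class specificity, reported in the order given by ``labels``.
print(specificity_score(y_true, y_pred, labels=[0, 1, 2], average=None))

# Unweighted mean over the same labels.
print(specificity_score(y_true, y_pred, labels=[0, 1, 2], average="macro"))
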
@@ -437,7 +487,7 @@ def specificity_score(
        meaningful for multilabel classification where this differs from
        :func:`accuracy_score`).

-     sample_weight : ndarray of shape (n_samples,), default=None
+     sample_weight : array-like of shape (n_samples,), default=None
        Sample weights.

    Returns
@@ -474,6 +524,22 @@ def specificity_score(
    return s


+ @validate_params(
+     {
+         "y_true": ["array-like"],
+         "y_pred": ["array-like"],
+         "labels": ["array-like", None],
+         "pos_label": [str, numbers.Integral, None],
+         "average": [
+             None,
+             StrOptions(
+                 {"binary", "micro", "macro", "weighted", "samples", "multiclass"}
+             ),
+         ],
+         "sample_weight": ["array-like", None],
+         "correction": [Interval(numbers.Real, 0, None, closed="left")],
+     }
+ )
def geometric_mean_score(
    y_true,
    y_pred,
@@ -507,21 +573,24 @@ class is unrecognized by the classifier, G-mean resolves to zero. To

    Parameters
    ----------
-     y_true : ndarray of shape (n_samples,)
+     y_true : array-like of shape (n_samples,)
        Ground truth (correct) target values.

-     y_pred : ndarray of shape (n_samples,)
+     y_pred : array-like of shape (n_samples,)
        Estimated targets as returned by a classifier.

-     labels : list, default=None
+     labels : array-like, default=None
        The set of labels to include when ``average != 'binary'``, and their
        order if ``average is None``. Labels present in the data can be
        excluded, for example to calculate a multiclass average ignoring a
        majority negative class, while labels not present in the data will
        result in 0 components in a macro average.

-     pos_label : str or int, default=1
+     pos_label : str, int or None, default=1
        The class to report if ``average='binary'`` and the data is binary.
+         If ``pos_label is None`` and the data is binary, this function
+         returns the average geometric mean if ``average`` is
+         ``'weighted'``.
        If the data are multiclass, this will be ignored;
        setting ``labels=[pos_label]`` and ``average != 'binary'`` will report
        scores for that label only.
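
A hedged sketch contrasting the ``'multiclass'`` and ``'binary'`` reporting modes, assuming the ``imblearn.metrics`` import path; the labels are made up.

from imblearn.metrics import geometric_mean_score

y_true = [0, 0, 0, 0, 1, 1]
y_pred = [0, 0, 0, 1, 1, 0]

# 'multiclass' average: geometric mean of the per-class recalls (the G-mean).
print(geometric_mean_score(y_true, y_pred, average="multiclass"))

# Binary report for the positive class: sqrt(sensitivity * specificity).
print(geometric_mean_score(y_true, y_pred, average="binary", pos_label=1))
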
@@ -539,6 +608,8 @@ class is unrecognized by the classifier, G-mean resolves to zero. To
        ``'macro'``:
            Calculate metrics for each label, and find their unweighted
            mean. This does not take label imbalance into account.
+         ``'multiclass'``:
+             No average is taken.
        ``'weighted'``:
            Calculate metrics for each label, and find their average, weighted
            by support (the number of true instances for each label). This
@@ -549,7 +620,7 @@ class is unrecognized by the classifier, G-mean resolves to zero. To
        meaningful for multilabel classification where this differs from
        :func:`accuracy_score`).

-     sample_weight : ndarray of shape (n_samples,), default=None
+     sample_weight : array-like of shape (n_samples,), default=None
        Sample weights.

    correction : float, default=0.0
@@ -658,6 +729,7 @@ class is unrecognized by the classifier, G-mean resolves to zero. To
    return gmean


+ @validate_params({"alpha": [numbers.Real], "squared": ["boolean"]})
def make_index_balanced_accuracy(*, alpha=0.1, squared=True):
    """Balance any scoring function using the index balanced accuracy.
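
The factory defined here is typically applied to one of the scorers above. A minimal usage sketch, assuming the ``imblearn.metrics`` import path and made-up labels:

from imblearn.metrics import geometric_mean_score, make_index_balanced_accuracy

y_true = [0, 0, 0, 0, 1, 1]
y_pred = [0, 0, 0, 1, 1, 0]

# Build a scorer that returns the index balanced accuracy (IBA) of the G-mean
# with the chosen alpha, then call it like the original metric.
iba_gmean = make_index_balanced_accuracy(alpha=0.1, squared=True)(geometric_mean_score)
print(iba_gmean(y_true, y_pred))
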
@@ -763,6 +835,22 @@ def compute_score(*args, **kwargs):
    return decorate


+ @validate_params(
+     {
+         "y_true": ["array-like"],
+         "y_pred": ["array-like"],
+         "labels": ["array-like", None],
+         "target_names": ["array-like", None],
+         "sample_weight": ["array-like", None],
+         "digits": [Interval(numbers.Integral, 0, None, closed="left")],
+         "alpha": [numbers.Real],
+         "output_dict": ["boolean"],
+         "zero_division": [
+             StrOptions({"warn"}),
+             Interval(numbers.Integral, 0, 1, closed="both"),
+         ],
+     }
+ )
def classification_report_imbalanced(
    y_true,
    y_pred,
@@ -970,6 +1058,13 @@ class 2 1.00 0.67 1.00 0.80 0.82 0.64\
    return report


+ @validate_params(
+     {
+         "y_true": ["array-like"],
+         "y_pred": ["array-like"],
+         "sample_weight": ["array-like", None],
+     }
+ )
def macro_averaged_mean_absolute_error(y_true, y_pred, *, sample_weight=None):
    """Compute Macro-Averaged MAE for imbalanced ordinal classification.
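
A closing sketch for the last two decorated functions, assuming the ``imblearn.metrics`` import path; the ordinal labels are made up.

from imblearn.metrics import (
    classification_report_imbalanced,
    macro_averaged_mean_absolute_error,
)

y_true = [0, 0, 0, 1, 1, 2, 2, 2]
y_pred = [0, 0, 1, 1, 2, 2, 2, 1]

# Text report with per-class precision, recall, specificity, f1, geometric
# mean and IBA; zero_division=0 silences warnings for undefined metrics.
print(classification_report_imbalanced(y_true, y_pred, digits=3, zero_division=0))

# Mean absolute error averaged per class, so minority classes are not
# drowned out by the majority class.
print(macro_averaged_mean_absolute_error(y_true, y_pred))
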