 from __future__ import division, print_function
 
+from functools import partial
+
 import numpy as np
 
 from numpy.testing import (assert_array_almost_equal, assert_array_equal,
 from sklearn import datasets
 from sklearn import svm
 
+from sklearn.preprocessing import label_binarize
+from sklearn.utils.testing import assert_not_equal
 from sklearn.utils.validation import check_random_state
 
 from imblearn.metrics import sensitivity_specificity_support
@@ -94,113 +98,93 @@ def test_sensitivity_specificity_support_binary():
     assert_array_almost_equal(spec, 0.88, 2)
 
 
-# def test_precision_recall_f_binary_single_class():
-#     # Test precision, recall and F1 score behave with a single positive or
-#     # negative class
-#     # Such a case may occur with non-stratified cross-validation
-#     assert_equal(1., precision_score([1, 1], [1, 1]))
-#     assert_equal(1., recall_score([1, 1], [1, 1]))
-#     assert_equal(1., f1_score([1, 1], [1, 1]))
-
-#     assert_equal(0., precision_score([-1, -1], [-1, -1]))
-#     assert_equal(0., recall_score([-1, -1], [-1, -1]))
-#     assert_equal(0., f1_score([-1, -1], [-1, -1]))
-
-
-# @ignore_warnings
-# def test_precision_recall_f_extra_labels():
-#     # Test handling of explicit additional (not in input) labels to PRF
-#     y_true = [1, 3, 3, 2]
-#     y_pred = [1, 1, 3, 2]
-#     y_true_bin = label_binarize(y_true, classes=np.arange(5))
-#     y_pred_bin = label_binarize(y_pred, classes=np.arange(5))
-#     data = [(y_true, y_pred),
-#             (y_true_bin, y_pred_bin)]
-
-#     for i, (y_true, y_pred) in enumerate(data):
-#         # No average: zeros in array
-#         actual = recall_score(y_true, y_pred, labels=[0, 1, 2, 3, 4],
-#                               average=None)
-#         assert_array_almost_equal([0., 1., 1., .5, 0.], actual)
-
-#         # Macro average is changed
-#         actual = recall_score(y_true, y_pred, labels=[0, 1, 2, 3, 4],
-#                               average='macro')
-#         assert_array_almost_equal(np.mean([0., 1., 1., .5, 0.]), actual)
-
-#         # No effect otheriwse
-#         for average in ['micro', 'weighted', 'samples']:
-#             if average == 'samples' and i == 0:
-#                 continue
-#             assert_almost_equal(recall_score(y_true, y_pred,
-#                                              labels=[0, 1, 2, 3, 4],
-#                                              average=average),
-#                                 recall_score(y_true, y_pred, labels=None,
-#                                              average=average))
-
-#     # Error when introducing invalid label in multilabel case
-#     # (although it would only affect performance if average='macro'/None)
-#     for average in [None, 'macro', 'micro', 'samples']:
-#         assert_raises(ValueError, recall_score, y_true_bin, y_pred_bin,
-#                       labels=np.arange(6), average=average)
-#         assert_raises(ValueError, recall_score, y_true_bin, y_pred_bin,
-#                       labels=np.arange(-1, 4), average=average)
-
-
-# @ignore_warnings
-# def test_precision_recall_f_ignored_labels():
-#     # Test a subset of labels may be requested for PRF
-#     y_true = [1, 1, 2, 3]
-#     y_pred = [1, 3, 3, 3]
-#     y_true_bin = label_binarize(y_true, classes=np.arange(5))
-#     y_pred_bin = label_binarize(y_pred, classes=np.arange(5))
-#     data = [(y_true, y_pred),
-#             (y_true_bin, y_pred_bin)]
-
-#     for i, (y_true, y_pred) in enumerate(data):
-#         recall_13 = partial(recall_score, y_true, y_pred, labels=[1, 3])
-#         recall_all = partial(recall_score, y_true, y_pred, labels=None)
-
-#         assert_array_almost_equal([.5, 1.], recall_13(average=None))
-#         assert_almost_equal((.5 + 1.) / 2, recall_13(average='macro'))
-#         assert_almost_equal((.5 * 2 + 1. * 1) / 3,
-#                             recall_13(average='weighted'))
-#         assert_almost_equal(2. / 3, recall_13(average='micro'))
-
-#         # ensure the above were meaningful tests:
-#         for average in ['macro', 'weighted', 'micro']:
-#             assert_not_equal(recall_13(average=average),
-#                              recall_all(average=average))
-
-
-# @ignore_warnings
-# def test_precision_recall_fscore_support_errors():
-#     y_true, y_pred, _ = make_prediction(binary=True)
-
-#     # Bad beta
-#     assert_raises(ValueError, precision_recall_fscore_support,
-#                   y_true, y_pred, beta=0.0)
-
-#     # Bad pos_label
-#     assert_raises(ValueError, precision_recall_fscore_support,
-#                   y_true, y_pred, pos_label=2, average='binary')
-
-#     # Bad average option
-#     assert_raises(ValueError, precision_recall_fscore_support,
-#                   [0, 1, 2], [1, 2, 0], average='mega')
-
-
-# def test_precision_recall_f_unused_pos_label():
-#     # Check warning that pos_label unused when set to non-default value
-#     # but average != 'binary'; even if data is binary.
-#     assert_warns_message(UserWarning,
-#                          "Note that pos_label (set to 2) is "
-#                          "ignored when average != 'binary' (got 'macro'). You "
-#                          "may use labels=[pos_label] to specify a single "
-#                          "positive class.", precision_recall_fscore_support,
-#                          [1, 2, 1], [1, 2, 2], pos_label=2, average='macro')
-
-def test_precision_recall_f1_score_multiclass():
+def test_sensitivity_specificity_binary_single_class():
+    # Test that sensitivity and specificity behave with a single positive or
+    # negative class
+    # Such a case may occur with non-stratified cross-validation
+    assert_equal(1., sensitivity_score([1, 1], [1, 1]))
+    assert_equal(0., specificity_score([1, 1], [1, 1]))
+
+    assert_equal(0., sensitivity_score([-1, -1], [-1, -1]))
+    assert_equal(0., specificity_score([-1, -1], [-1, -1]))
+
+
+def test_sensitivity_specificity_error_multilabels():
+    # Test that an error is raised when the input is multilabel
+    y_true = [1, 3, 3, 2]
+    y_pred = [1, 1, 3, 2]
+    y_true_bin = label_binarize(y_true, classes=np.arange(5))
+    y_pred_bin = label_binarize(y_pred, classes=np.arange(5))
+
+    assert_raises(ValueError, sensitivity_score, y_true_bin, y_pred_bin)
+
+@ignore_warnings
+def test_sensitivity_specificity_extra_labels():
+    # Test handling of explicit additional (not in input) labels to SS
+    y_true = [1, 3, 3, 2]
+    y_pred = [1, 1, 3, 2]
+
+    actual = sensitivity_score(y_true, y_pred, labels=[0, 1, 2, 3, 4],
+                               average=None)
+    assert_array_almost_equal([0., 1., 1., .5, 0.], actual)
+
+    # Macro average is changed
+    actual = sensitivity_score(y_true, y_pred, labels=[0, 1, 2, 3, 4],
+                               average='macro')
+    assert_array_almost_equal(np.mean([0., 1., 1., .5, 0.]), actual)
+
+    # Weighted average is not changed
+    assert_almost_equal(sensitivity_score(y_true, y_pred,
+                                          labels=[0, 1, 2, 3, 4],
+                                          average='weighted'),
+                        sensitivity_score(y_true, y_pred, labels=None,
+                                          average='weighted'))
+
+@ignore_warnings
+def test_sensitivity_specificity_ignored_labels():
+    # Test that a subset of labels may be requested for SS
+    y_true = [1, 1, 2, 3]
+    y_pred = [1, 3, 3, 3]
+
+    sensitivity_13 = partial(sensitivity_score, y_true, y_pred, labels=[1, 3])
+    sensitivity_all = partial(sensitivity_score, y_true, y_pred, labels=None)
+
+    assert_array_almost_equal([.5, 1.], sensitivity_13(average=None))
+    assert_almost_equal((.5 + 1.) / 2, sensitivity_13(average='macro'))
+    assert_almost_equal((.5 * 2 + 1. * 1) / 3,
+                        sensitivity_13(average='weighted'))
+
+    # ensure the above were meaningful tests:
+    for average in ['macro', 'weighted']:
+        assert_not_equal(sensitivity_13(average=average),
+                         sensitivity_all(average=average))
+
+
+@ignore_warnings
+def test_sensitivity_specificity_support_errors():
+    y_true, y_pred, _ = make_prediction(binary=True)
+
+    # Bad pos_label
+    assert_raises(ValueError, sensitivity_specificity_support,
+                  y_true, y_pred, pos_label=2, average='binary')
+
+    # Bad average option
+    assert_raises(ValueError, sensitivity_specificity_support,
+                  [0, 1, 2], [1, 2, 0], average='mega')
+
+
+def test_sensitivity_specificity_unused_pos_label():
+    # Check warning that pos_label is unused when set to a non-default value
+    # but average != 'binary'; even if data is binary.
+    assert_warns_message(UserWarning,
+                         "Note that pos_label (set to 2) is "
+                         "ignored when average != 'binary' (got 'macro'). You "
+                         "may use labels=[pos_label] to specify a single "
+                         "positive class.", sensitivity_specificity_support,
+                         [1, 2, 1], [1, 2, 2], pos_label=2, average='macro')
+
+
+def test_sensitivity_specificity_multiclass():
     # Test Precision Recall and F1 Score for multiclass classification task
     y_true, y_pred, _ = make_prediction(binary=False)
 
@@ -212,15 +196,6 @@ def test_precision_recall_f1_score_multiclass():
     assert_array_equal(supp, [24, 31, 20])
 
     # averaging tests
-    spec = specificity_score(y_true, y_pred, pos_label=1, average='micro')
-    assert_array_almost_equal(spec, 0.77, 2)
-
-    sens = sensitivity_score(y_true, y_pred, average='micro')
-    assert_array_almost_equal(sens, 0.53, 2)
-
-    spec = specificity_score(y_true, y_pred, average='macro')
-    assert_array_almost_equal(spec, 0.77, 2)
-
     sens = sensitivity_score(y_true, y_pred, average='macro')
     assert_array_almost_equal(sens, 0.60, 2)
 
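For context (not part of the diff above): a minimal sketch of how the sensitivity/specificity metrics exercised by these tests are typically called, assuming the default binary averaging with pos_label=1 that the tests rely on; the toy labels below are invented for illustration.

    from imblearn.metrics import sensitivity_score, specificity_score

    y_true = [0, 0, 0, 1, 1, 1]
    y_pred = [0, 0, 1, 1, 1, 1]

    # Sensitivity is the recall of the positive class; specificity is the
    # recall of the negative class.
    sensitivity_score(y_true, y_pred)  # 1.0  -- all three positives recovered
    specificity_score(y_true, y_pred)  # ~0.67 -- 2 of 3 negatives correct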