14 | 14 |
15 | 15 | from sklearn.utils.validation import check_random_state
16 | 16 |
| 17 | +from imblearn.metrics import sensitivity_specificity_support
| 18 | +
17 | 19 | RND_SEED = 42
18 | 20 |
19 | 21 | ###############################################################################
@@ -67,138 +69,138 @@ def make_prediction(dataset=None, binary=False):
67 | 69 | ###############################################################################
68 | 70 | # Tests
69 | 71 |
70 | | -def test_precision_recall_f1_score_binary():
71 | | -    # Test Precision Recall and F1 Score for binary classification task
| 72 | +def test_sensitivity_specificity_support_binary():
| 73 | +    """Test sensitivity and specificity for a binary classification task."""
72 | 74 |     y_true, y_pred, _ = make_prediction(binary=True)
73 | 75 |
74 | 76 |     # detailed measures for each class
75 | | -    p, r, f, s = precision_recall_fscore_support(y_true, y_pred, average=None)
76 | | -    assert_array_almost_equal(p, [0.73, 0.85], 2)
77 | | -    assert_array_almost_equal(r, [0.88, 0.68], 2)
78 | | -    assert_array_almost_equal(f, [0.80, 0.76], 2)
79 | | -    assert_array_equal(s, [25, 25])
80 | | -
81 | | -    # individual scoring function that can be used for grid search: in the
82 | | -    # binary class case the score is the value of the measure for the positive
83 | | -    # class (e.g. label == 1). This is deprecated for average != 'binary'.
84 | | -    for kwargs, my_assert in [({}, assert_no_warnings),
85 | | -                              ({'average': 'binary'}, assert_no_warnings)]:
86 | | -        ps = my_assert(precision_score, y_true, y_pred, **kwargs)
87 | | -        assert_array_almost_equal(ps, 0.85, 2)
88 | | -
89 | | -        rs = my_assert(recall_score, y_true, y_pred, **kwargs)
90 | | -        assert_array_almost_equal(rs, 0.68, 2)
91 | | -
92 | | -        fs = my_assert(f1_score, y_true, y_pred, **kwargs)
93 | | -        assert_array_almost_equal(fs, 0.76, 2)
94 | | -
95 | | -        assert_almost_equal(my_assert(fbeta_score, y_true, y_pred, beta=2,
96 | | -                                      **kwargs),
97 | | -                            (1 + 2 ** 2) * ps * rs / (2 ** 2 * ps + rs), 2)
98 | | -
99 | | -
100 | | -def test_precision_recall_f_binary_single_class():
101 | | -    # Test precision, recall and F1 score behave with a single positive or
102 | | -    # negative class
103 | | -    # Such a case may occur with non-stratified cross-validation
104 | | -    assert_equal(1., precision_score([1, 1], [1, 1]))
105 | | -    assert_equal(1., recall_score([1, 1], [1, 1]))
106 | | -    assert_equal(1., f1_score([1, 1], [1, 1]))
107 | | -
108 | | -    assert_equal(0., precision_score([-1, -1], [-1, -1]))
109 | | -    assert_equal(0., recall_score([-1, -1], [-1, -1]))
110 | | -    assert_equal(0., f1_score([-1, -1], [-1, -1]))
111 | | -
112 | | -
113 | | -@ignore_warnings
114 | | -def test_precision_recall_f_extra_labels():
115 | | -    # Test handling of explicit additional (not in input) labels to PRF
116 | | -    y_true = [1, 3, 3, 2]
117 | | -    y_pred = [1, 1, 3, 2]
118 | | -    y_true_bin = label_binarize(y_true, classes=np.arange(5))
119 | | -    y_pred_bin = label_binarize(y_pred, classes=np.arange(5))
120 | | -    data = [(y_true, y_pred),
121 | | -            (y_true_bin, y_pred_bin)]
122 | | -
123 | | -    for i, (y_true, y_pred) in enumerate(data):
124 | | -        # No average: zeros in array
125 | | -        actual = recall_score(y_true, y_pred, labels=[0, 1, 2, 3, 4],
126 | | -                              average=None)
127 | | -        assert_array_almost_equal([0., 1., 1., .5, 0.], actual)
128 | | -
129 | | -        # Macro average is changed
130 | | -        actual = recall_score(y_true, y_pred, labels=[0, 1, 2, 3, 4],
131 | | -                              average='macro')
132 | | -        assert_array_almost_equal(np.mean([0., 1., 1., .5, 0.]), actual)
133 | | -
134 | | -        # No effect otherwise
135 | | -        for average in ['micro', 'weighted', 'samples']:
136 | | -            if average == 'samples' and i == 0:
137 | | -                continue
138 | | -            assert_almost_equal(recall_score(y_true, y_pred,
139 | | -                                             labels=[0, 1, 2, 3, 4],
140 | | -                                             average=average),
141 | | -                                recall_score(y_true, y_pred, labels=None,
142 | | -                                             average=average))
143 | | -
144 | | -    # Error when introducing invalid label in multilabel case
145 | | -    # (although it would only affect performance if average='macro'/None)
146 | | -    for average in [None, 'macro', 'micro', 'samples']:
147 | | -        assert_raises(ValueError, recall_score, y_true_bin, y_pred_bin,
148 | | -                      labels=np.arange(6), average=average)
149 | | -        assert_raises(ValueError, recall_score, y_true_bin, y_pred_bin,
150 | | -                      labels=np.arange(-1, 4), average=average)
151 | | -
152 | | -
153 | | -@ignore_warnings
154 | | -def test_precision_recall_f_ignored_labels():
155 | | -    # Test a subset of labels may be requested for PRF
156 | | -    y_true = [1, 1, 2, 3]
157 | | -    y_pred = [1, 3, 3, 3]
158 | | -    y_true_bin = label_binarize(y_true, classes=np.arange(5))
159 | | -    y_pred_bin = label_binarize(y_pred, classes=np.arange(5))
160 | | -    data = [(y_true, y_pred),
161 | | -            (y_true_bin, y_pred_bin)]
162 | | -
163 | | -    for i, (y_true, y_pred) in enumerate(data):
164 | | -        recall_13 = partial(recall_score, y_true, y_pred, labels=[1, 3])
165 | | -        recall_all = partial(recall_score, y_true, y_pred, labels=None)
166 | | -
167 | | -        assert_array_almost_equal([.5, 1.], recall_13(average=None))
168 | | -        assert_almost_equal((.5 + 1.) / 2, recall_13(average='macro'))
169 | | -        assert_almost_equal((.5 * 2 + 1. * 1) / 3,
170 | | -                            recall_13(average='weighted'))
171 | | -        assert_almost_equal(2. / 3, recall_13(average='micro'))
172 | | -
173 | | -        # ensure the above were meaningful tests:
174 | | -        for average in ['macro', 'weighted', 'micro']:
175 | | -            assert_not_equal(recall_13(average=average),
176 | | -                             recall_all(average=average))
177 | | -
178 | | -
179 | | -@ignore_warnings
180 | | -def test_precision_recall_fscore_support_errors():
181 | | -    y_true, y_pred, _ = make_prediction(binary=True)
182 | | -
183 | | -    # Bad beta
184 | | -    assert_raises(ValueError, precision_recall_fscore_support,
185 | | -                  y_true, y_pred, beta=0.0)
186 | | -
187 | | -    # Bad pos_label
188 | | -    assert_raises(ValueError, precision_recall_fscore_support,
189 | | -                  y_true, y_pred, pos_label=2, average='binary')
190 | | -
191 | | -    # Bad average option
192 | | -    assert_raises(ValueError, precision_recall_fscore_support,
193 | | -                  [0, 1, 2], [1, 2, 0], average='mega')
194 | | -
195 | | -
196 | | -def test_precision_recall_f_unused_pos_label():
197 | | -    # Check warning that pos_label unused when set to non-default value
198 | | -    # but average != 'binary'; even if data is binary.
199 | | -    assert_warns_message(UserWarning,
200 | | -                         "Note that pos_label (set to 2) is "
201 | | -                         "ignored when average != 'binary' (got 'macro'). You "
202 | | -                         "may use labels=[pos_label] to specify a single "
203 | | -                         "positive class.", precision_recall_fscore_support,
204 | | -                         [1, 2, 1], [1, 2, 2], pos_label=2, average='macro')
| 77 | +    sens, spec, supp = sensitivity_specificity_support(y_true, y_pred,
| 78 | +                                                        average=None)
| 79 | +    assert_array_almost_equal(sens, [0.88, 0.68], 2)
| 80 | +    assert_array_almost_equal(spec, [0.68, 0.88], 2)
| 81 | +    assert_array_equal(supp, [25, 25])
| 82 | +
| 83 | +    # # individual scoring function that can be used for grid search: in the
| 84 | +    # # binary class case the score is the value of the measure for the positive
| 85 | +    # # class (e.g. label == 1). This is deprecated for average != 'binary'.
| 86 | +    # for kwargs, my_assert in [({}, assert_no_warnings),
| 87 | +    #                           ({'average': 'binary'}, assert_no_warnings)]:
| 88 | +    #     ps = my_assert(precision_score, y_true, y_pred, **kwargs)
| 89 | +    #     assert_array_almost_equal(ps, 0.85, 2)
| 90 | +
| 91 | +    #     rs = my_assert(recall_score, y_true, y_pred, **kwargs)
| 92 | +    #     assert_array_almost_equal(rs, 0.68, 2)
| 93 | +
| 94 | +    #     fs = my_assert(f1_score, y_true, y_pred, **kwargs)
| 95 | +    #     assert_array_almost_equal(fs, 0.76, 2)
| 96 | +
| 97 | +    #     assert_almost_equal(my_assert(fbeta_score, y_true, y_pred, beta=2,
| 98 | +    #                                   **kwargs),
| 99 | +    #                         (1 + 2 ** 2) * ps * rs / (2 ** 2 * ps + rs), 2)
| 100 | +
| 101 | +
| 102 | +# def test_precision_recall_f_binary_single_class():
| 103 | +#     # Test precision, recall and F1 score behave with a single positive or
| 104 | +#     # negative class
| 105 | +#     # Such a case may occur with non-stratified cross-validation
| 106 | +#     assert_equal(1., precision_score([1, 1], [1, 1]))
| 107 | +#     assert_equal(1., recall_score([1, 1], [1, 1]))
| 108 | +#     assert_equal(1., f1_score([1, 1], [1, 1]))
| 109 | +
| 110 | +#     assert_equal(0., precision_score([-1, -1], [-1, -1]))
| 111 | +#     assert_equal(0., recall_score([-1, -1], [-1, -1]))
| 112 | +#     assert_equal(0., f1_score([-1, -1], [-1, -1]))
| 113 | +
| 114 | +
| 115 | +# @ignore_warnings
| 116 | +# def test_precision_recall_f_extra_labels():
| 117 | +#     # Test handling of explicit additional (not in input) labels to PRF
| 118 | +#     y_true = [1, 3, 3, 2]
| 119 | +#     y_pred = [1, 1, 3, 2]
| 120 | +#     y_true_bin = label_binarize(y_true, classes=np.arange(5))
| 121 | +#     y_pred_bin = label_binarize(y_pred, classes=np.arange(5))
| 122 | +#     data = [(y_true, y_pred),
| 123 | +#             (y_true_bin, y_pred_bin)]
| 124 | +
| 125 | +#     for i, (y_true, y_pred) in enumerate(data):
| 126 | +#         # No average: zeros in array
| 127 | +#         actual = recall_score(y_true, y_pred, labels=[0, 1, 2, 3, 4],
| 128 | +#                               average=None)
| 129 | +#         assert_array_almost_equal([0., 1., 1., .5, 0.], actual)
| 130 | +
| 131 | +#         # Macro average is changed
| 132 | +#         actual = recall_score(y_true, y_pred, labels=[0, 1, 2, 3, 4],
| 133 | +#                               average='macro')
| 134 | +#         assert_array_almost_equal(np.mean([0., 1., 1., .5, 0.]), actual)
| 135 | +
| 136 | +#         # No effect otherwise
| 137 | +#         for average in ['micro', 'weighted', 'samples']:
| 138 | +#             if average == 'samples' and i == 0:
| 139 | +#                 continue
| 140 | +#             assert_almost_equal(recall_score(y_true, y_pred,
| 141 | +#                                              labels=[0, 1, 2, 3, 4],
| 142 | +#                                              average=average),
| 143 | +#                                 recall_score(y_true, y_pred, labels=None,
| 144 | +#                                              average=average))
| 145 | +
| 146 | +#     # Error when introducing invalid label in multilabel case
| 147 | +#     # (although it would only affect performance if average='macro'/None)
| 148 | +#     for average in [None, 'macro', 'micro', 'samples']:
| 149 | +#         assert_raises(ValueError, recall_score, y_true_bin, y_pred_bin,
| 150 | +#                       labels=np.arange(6), average=average)
| 151 | +#         assert_raises(ValueError, recall_score, y_true_bin, y_pred_bin,
| 152 | +#                       labels=np.arange(-1, 4), average=average)
| 153 | +
| 154 | +
| 155 | +# @ignore_warnings
| 156 | +# def test_precision_recall_f_ignored_labels():
| 157 | +#     # Test a subset of labels may be requested for PRF
| 158 | +#     y_true = [1, 1, 2, 3]
| 159 | +#     y_pred = [1, 3, 3, 3]
| 160 | +#     y_true_bin = label_binarize(y_true, classes=np.arange(5))
| 161 | +#     y_pred_bin = label_binarize(y_pred, classes=np.arange(5))
| 162 | +#     data = [(y_true, y_pred),
| 163 | +#             (y_true_bin, y_pred_bin)]
| 164 | +
| 165 | +#     for i, (y_true, y_pred) in enumerate(data):
| 166 | +#         recall_13 = partial(recall_score, y_true, y_pred, labels=[1, 3])
| 167 | +#         recall_all = partial(recall_score, y_true, y_pred, labels=None)
| 168 | +
| 169 | +#         assert_array_almost_equal([.5, 1.], recall_13(average=None))
| 170 | +#         assert_almost_equal((.5 + 1.) / 2, recall_13(average='macro'))
| 171 | +#         assert_almost_equal((.5 * 2 + 1. * 1) / 3,
| 172 | +#                             recall_13(average='weighted'))
| 173 | +#         assert_almost_equal(2. / 3, recall_13(average='micro'))
| 174 | +
| 175 | +#         # ensure the above were meaningful tests:
| 176 | +#         for average in ['macro', 'weighted', 'micro']:
| 177 | +#             assert_not_equal(recall_13(average=average),
| 178 | +#                              recall_all(average=average))
| 179 | +
| 180 | +
| 181 | +# @ignore_warnings
| 182 | +# def test_precision_recall_fscore_support_errors():
| 183 | +#     y_true, y_pred, _ = make_prediction(binary=True)
| 184 | +
| 185 | +#     # Bad beta
| 186 | +#     assert_raises(ValueError, precision_recall_fscore_support,
| 187 | +#                   y_true, y_pred, beta=0.0)
| 188 | +
| 189 | +#     # Bad pos_label
| 190 | +#     assert_raises(ValueError, precision_recall_fscore_support,
| 191 | +#                   y_true, y_pred, pos_label=2, average='binary')
| 192 | +
| 193 | +#     # Bad average option
| 194 | +#     assert_raises(ValueError, precision_recall_fscore_support,
| 195 | +#                   [0, 1, 2], [1, 2, 0], average='mega')
| 196 | +
| 197 | +
| 198 | +# def test_precision_recall_f_unused_pos_label():
| 199 | +#     # Check warning that pos_label unused when set to non-default value
| 200 | +#     # but average != 'binary'; even if data is binary.
| 201 | +#     assert_warns_message(UserWarning,
| 202 | +#                          "Note that pos_label (set to 2) is "
| 203 | +#                          "ignored when average != 'binary' (got 'macro'). You "
| 204 | +#                          "may use labels=[pos_label] to specify a single "
| 205 | +#                          "positive class.", precision_recall_fscore_support,
| 206 | +#                          [1, 2, 1], [1, 2, 2], pos_label=2, average='macro')
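A note on the expected values in the new test: for a binary problem, the per-class sensitivity returned with `average=None` is simply the recall of each class, while the per-class specificity is the recall of the *opposite* class, so the two arrays are mirror images of each other (here `[0.88, 0.68]` and `[0.68, 0.88]`). The sketch below is illustrative only, not part of the commit: it uses toy labels rather than the `make_prediction` fixture, and it assumes scikit-learn and imbalanced-learn are installed. It checks the identity with plain confusion-matrix arithmetic and cross-checks it against `sensitivity_specificity_support`.

```python
import numpy as np
from sklearn.metrics import confusion_matrix

from imblearn.metrics import sensitivity_specificity_support

# Toy binary labels, illustrative only (not the make_prediction fixture).
y_true = np.array([0, 0, 0, 0, 1, 1, 1, 1, 1, 1])
y_pred = np.array([0, 0, 0, 1, 1, 1, 1, 1, 0, 0])

# For binary problems, confusion_matrix(...).ravel() yields tn, fp, fn, tp.
tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()

# Per-class view, matching average=None above:
#   sensitivity of a class = recall of that class
#   specificity of a class = recall of the other class
sens = np.array([tn / (tn + fp), tp / (tp + fn)])  # [recall_0, recall_1]
spec = np.array([tp / (tp + fn), tn / (tn + fp)])  # [recall_1, recall_0]
assert np.allclose(sens[::-1], spec)  # mirror-image identity in the binary case

# Cross-check against the function under test (same call as in the diff).
s, sp, supp = sensitivity_specificity_support(y_true, y_pred, average=None)
assert np.allclose(s, sens)
assert np.allclose(sp, spec)
assert supp.tolist() == [4, 6]  # support = class counts in y_true
```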