
Commit 0eddca6

Author: Guillaume Lemaitre (committed)

Finish sensitivity and specificity

1 parent 2085360 commit 0eddca6

File tree: 2 files changed, +112 -152 lines changed


imblearn/metrics/classification.py (21 additions, 36 deletions)
@@ -40,7 +40,7 @@ def sensitivity_specificity_support(y_true, y_pred, labels=None,
 
     If ``pos_label is None`` and in binary classification, this function
     returns the average sensitivity and specificity if ``average``
-    is one of ``'micro'`` or 'weighted'``.
+    is one of ``'weighted'``.
 
     Parameters
     ----------
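
Note: the two quantities documented here follow the usual per-class definitions, sensitivity = TP / (TP + FN) (recall of the positive class) and specificity = TN / (TN + FP) (recall of the negative class). A minimal hand-rolled sketch for the binary case; the toy labels are illustrative and the snippet is not part of the package:

    import numpy as np

    y_true = np.array([1, 1, 1, 0, 0])
    y_pred = np.array([1, 1, 0, 0, 1])

    tp = np.sum((y_true == 1) & (y_pred == 1))  # 2
    fn = np.sum((y_true == 1) & (y_pred == 0))  # 1
    tn = np.sum((y_true == 0) & (y_pred == 0))  # 1
    fp = np.sum((y_true == 0) & (y_pred == 1))  # 1

    sensitivity = tp / (tp + fn)  # 2 / 3, recall on the positive class
    specificity = tn / (tn + fp)  # 1 / 2, recall on the negative class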
@@ -105,8 +105,7 @@ def sensitivity_specificity_support(y_true, y_pred, labels=None,
            <https://en.wikipedia.org/wiki/Sensitivity_and_specificity>`_
 
     """
-
-    average_options = (None, 'micro', 'macro', 'weighted')
+    average_options = (None, 'macro', 'weighted')
     if average not in average_options and average != 'binary':
         raise ValueError('average has to be one of ' +
                          str(average_options))
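
Note: with 'micro' dropped from average_options, requesting it now falls through to the ValueError branch above, like any other unrecognised option. A hedged usage sketch, assuming the sensitivity_specificity_support signature exposed by imblearn.metrics at this commit; the toy labels are made up:

    import numpy as np
    from imblearn.metrics import sensitivity_specificity_support

    y_true = np.array([0, 0, 1, 1, 2, 2])
    y_pred = np.array([0, 1, 1, 1, 2, 0])

    # The remaining averaged variants.
    sens_macro, spec_macro, _ = sensitivity_specificity_support(
        y_true, y_pred, average='macro')
    sens_w, spec_w, _ = sensitivity_specificity_support(
        y_true, y_pred, average='weighted')

    # 'micro' is no longer accepted.
    try:
        sensitivity_specificity_support(y_true, y_pred, average='micro')
    except ValueError as exc:
        print(exc)  # average has to be one of (None, 'macro', 'weighted')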
@@ -154,24 +153,20 @@ def sensitivity_specificity_support(y_true, y_pred, labels=None,
         y_pred = le.transform(y_pred)
         sorted_labels = le.classes_
 
-    LOGGER.debug(y_true)
-    LOGGER.debug(y_pred)
-    LOGGER.debug(sorted_labels)
-
     LOGGER.debug('The number of labels is %s' % n_labels)
 
     # In a leave out strategy and for each label, compute:
     # TP, TN, FP, FN
     # These list contain an array in which each sample is labeled as
     # TP, TN, FP, FN
     list_tp = [np.bitwise_and((y_true == label), (y_pred == label))
-               for label in sorted_labels]
+               for label in range(sorted_labels.size)]
     list_tn = [np.bitwise_and((y_true != label), (y_pred != label))
-               for label in sorted_labels]
+               for label in range(sorted_labels.size)]
     list_fp = [np.bitwise_and((y_true != label), (y_pred == label))
-               for label in sorted_labels]
+               for label in range(sorted_labels.size)]
     list_fn = [np.bitwise_and((y_true == label), (y_pred != label))
-               for label in sorted_labels]
+               for label in range(sorted_labels.size)]
 
     # Compute the sum for each type
     # We keep only the counting corresponding to True values
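
Note: the switch from iterating over sorted_labels to range(sorted_labels.size) matters because y_true and y_pred have just been re-encoded by LabelEncoder into the integers 0..n_classes-1, so the original class values may no longer appear in the arrays at all. A standalone sketch of the difference; the toy labels are made up:

    import numpy as np
    from sklearn.preprocessing import LabelEncoder

    y_true = np.array([2, 2, 5, 7])
    y_pred = np.array([2, 5, 5, 7])

    le = LabelEncoder()
    y_true_enc = le.fit_transform(y_true)  # [0, 0, 1, 2]
    y_pred_enc = le.transform(y_pred)      # [0, 1, 1, 2]
    sorted_labels = le.classes_            # [2, 5, 7]

    # Comparing the encoded arrays against the original label values misses
    # every class whose value is not itself a valid code (here 5 and 7).
    tp_wrong = [np.bitwise_and(y_true_enc == label, y_pred_enc == label)
                for label in sorted_labels]

    # Iterating over the encoded indices 0..n_classes-1 counts one true
    # positive per class, as expected.
    tp_fixed = [np.bitwise_and(y_true_enc == label, y_pred_enc == label)
                for label in range(sorted_labels.size)]

    print([int(tp.sum()) for tp in tp_wrong])  # [1, 0, 0]
    print([int(tp.sum()) for tp in tp_fixed])  # [1, 1, 1]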
@@ -197,42 +192,32 @@ def sensitivity_specificity_support(y_true, y_pred, labels=None,
     # Sort the support
     support = support[indices]
 
-
     LOGGER.debug('The indices which are retained are %s' % indices)
 
-    LOGGER.debug('TP: %s' % tp_sum)
-    LOGGER.debug('TN: %s' % tn_sum)
-    LOGGER.debug('FP: %s' % fp_sum)
-    LOGGER.debug('FN: %s' % fn_sum)
-
     tp_sum = tp_sum[indices]
     tn_sum = tn_sum[indices]
     fp_sum = fp_sum[indices]
     fn_sum = fn_sum[indices]
 
-    if average == 'micro':
-        tp_sum = np.array([tp_sum.sum()])
-        tn_sum = np.array([tn_sum.sum()])
-        fp_sum = np.array([fp_sum.sum()])
-        fn_sum = np.array([fn_sum.sum()])
-
-        LOGGER.debug('Did we do the average micro %s' % tp_sum)
-
     LOGGER.debug('Computed the necessary stats for the sensitivity and'
                  ' specificity')
 
-    # Compute the sensitivity and specificity
-    sensitivity = [_prf_divide(tp, tp + fn, 'sensitivity', 'tp + fn', average,
-                               warn_for) for tp, fn in zip(tp_sum, fn_sum)]
-    specificity = [_prf_divide(tn, tn + fp, 'specificity', 'tn + fp', average,
-                               warn_for) for tn, fp in zip(tn_sum, fp_sum)]
-
-    LOGGER.debug('Sensitivity = %s - Specificity = %s' % (sensitivity,
-                                                          specificity))
+    LOGGER.debug(tp_sum)
+    LOGGER.debug(tn_sum)
+    LOGGER.debug(fp_sum)
+    LOGGER.debug(fn_sum)
 
-    LOGGER.debug('Computed the sensitivity and specificity for each class')
-    LOGGER.debug('The lengths of those two metrics are: %s - %s',
-                 len(sensitivity), len(specificity))
+    # Compute the sensitivity and specificity
+    with np.errstate(divide='ignore', invalid='ignore'):
+        sensitivity = _prf_divide(tp_sum, tp_sum + fn_sum, 'sensitivity',
+                                  'tp + fn', average, warn_for)
+        specificity = _prf_divide(tn_sum, tn_sum + fp_sum, 'specificity',
+                                  'tn + fp', average, warn_for)
+
+    # sensitivity = [_prf_divide(tp, tp + fn, 'sensitivity', 'tp + fn', average,
+    #                            warn_for) for tp, fn in zip(tp_sum, fn_sum)]
+    # specificity = [_prf_divide(tn, tn + fp, 'specificity', 'tn + fp', average,
+    #                            warn_for) for tn, fp in zip(tn_sum, fp_sum)]
 
     # If we need to weight the results
     if average == 'weighted':

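Note: the with np.errstate(divide='ignore', invalid='ignore') block silences NumPy's warnings when a class has no positive (tp + fn == 0) or no negative (tn + fp == 0) samples, so the division helper can handle those cases itself. _prf_divide is scikit-learn's private helper; the safe_divide below is only an illustrative stand-in for the same vectorised pattern, not its actual implementation:

    import numpy as np

    def safe_divide(numerator, denominator):
        # Element-wise division that returns 0.0 wherever the denominator is 0.
        with np.errstate(divide='ignore', invalid='ignore'):
            result = numerator / denominator
        result[denominator == 0] = 0.0
        return result

    tp_sum = np.array([3, 0, 2], dtype=float)
    fn_sum = np.array([1, 0, 2], dtype=float)

    # Vectorised over all classes at once instead of a per-class Python loop.
    sensitivity = safe_divide(tp_sum, tp_sum + fn_sum)  # [0.75, 0.0, 0.5]
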
imblearn/metrics/tests/test_classification.py (91 additions, 116 deletions)
@@ -2,6 +2,8 @@
 
 from __future__ import division, print_function
 
+from functools import partial
+
 import numpy as np
 
 from numpy.testing import (assert_array_almost_equal, assert_array_equal,
@@ -12,6 +14,8 @@
 from sklearn import datasets
 from sklearn import svm
 
+from sklearn.preprocessing import label_binarize
+from sklearn.utils.testing import assert_not_equal
 from sklearn.utils.validation import check_random_state
 
 from imblearn.metrics import sensitivity_specificity_support
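
Note on the new imports: label_binarize turns a label vector into the binary indicator matrix that test_sensitivity_specificity_error_multilabels feeds to sensitivity_score to check that multilabel input is rejected, and assert_not_equal is used to make sure the restricted-label averages are genuinely different from the all-labels ones. A small sketch of the indicator format (output shown as comments):

    import numpy as np
    from sklearn.preprocessing import label_binarize

    y_true = [1, 3, 3, 2]
    y_true_bin = label_binarize(y_true, classes=np.arange(5))
    print(y_true_bin)
    # [[0 1 0 0 0]
    #  [0 0 0 1 0]
    #  [0 0 0 1 0]
    #  [0 0 1 0 0]]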
@@ -94,113 +98,93 @@ def test_sensitivity_specificity_support_binary():
9498
assert_array_almost_equal(spec, 0.88, 2)
9599

96100

97-
# def test_precision_recall_f_binary_single_class():
98-
# # Test precision, recall and F1 score behave with a single positive or
99-
# # negative class
100-
# # Such a case may occur with non-stratified cross-validation
101-
# assert_equal(1., precision_score([1, 1], [1, 1]))
102-
# assert_equal(1., recall_score([1, 1], [1, 1]))
103-
# assert_equal(1., f1_score([1, 1], [1, 1]))
104-
105-
# assert_equal(0., precision_score([-1, -1], [-1, -1]))
106-
# assert_equal(0., recall_score([-1, -1], [-1, -1]))
107-
# assert_equal(0., f1_score([-1, -1], [-1, -1]))
108-
109-
110-
# @ignore_warnings
111-
# def test_precision_recall_f_extra_labels():
112-
# # Test handling of explicit additional (not in input) labels to PRF
113-
# y_true = [1, 3, 3, 2]
114-
# y_pred = [1, 1, 3, 2]
115-
# y_true_bin = label_binarize(y_true, classes=np.arange(5))
116-
# y_pred_bin = label_binarize(y_pred, classes=np.arange(5))
117-
# data = [(y_true, y_pred),
118-
# (y_true_bin, y_pred_bin)]
119-
120-
# for i, (y_true, y_pred) in enumerate(data):
121-
# # No average: zeros in array
122-
# actual = recall_score(y_true, y_pred, labels=[0, 1, 2, 3, 4],
123-
# average=None)
124-
# assert_array_almost_equal([0., 1., 1., .5, 0.], actual)
125-
126-
# # Macro average is changed
127-
# actual = recall_score(y_true, y_pred, labels=[0, 1, 2, 3, 4],
128-
# average='macro')
129-
# assert_array_almost_equal(np.mean([0., 1., 1., .5, 0.]), actual)
130-
131-
# # No effect otheriwse
132-
# for average in ['micro', 'weighted', 'samples']:
133-
# if average == 'samples' and i == 0:
134-
# continue
135-
# assert_almost_equal(recall_score(y_true, y_pred,
136-
# labels=[0, 1, 2, 3, 4],
137-
# average=average),
138-
# recall_score(y_true, y_pred, labels=None,
139-
# average=average))
140-
141-
# # Error when introducing invalid label in multilabel case
142-
# # (although it would only affect performance if average='macro'/None)
143-
# for average in [None, 'macro', 'micro', 'samples']:
144-
# assert_raises(ValueError, recall_score, y_true_bin, y_pred_bin,
145-
# labels=np.arange(6), average=average)
146-
# assert_raises(ValueError, recall_score, y_true_bin, y_pred_bin,
147-
# labels=np.arange(-1, 4), average=average)
148-
149-
150-
# @ignore_warnings
151-
# def test_precision_recall_f_ignored_labels():
152-
# # Test a subset of labels may be requested for PRF
153-
# y_true = [1, 1, 2, 3]
154-
# y_pred = [1, 3, 3, 3]
155-
# y_true_bin = label_binarize(y_true, classes=np.arange(5))
156-
# y_pred_bin = label_binarize(y_pred, classes=np.arange(5))
157-
# data = [(y_true, y_pred),
158-
# (y_true_bin, y_pred_bin)]
159-
160-
# for i, (y_true, y_pred) in enumerate(data):
161-
# recall_13 = partial(recall_score, y_true, y_pred, labels=[1, 3])
162-
# recall_all = partial(recall_score, y_true, y_pred, labels=None)
163-
164-
# assert_array_almost_equal([.5, 1.], recall_13(average=None))
165-
# assert_almost_equal((.5 + 1.) / 2, recall_13(average='macro'))
166-
# assert_almost_equal((.5 * 2 + 1. * 1) / 3,
167-
# recall_13(average='weighted'))
168-
# assert_almost_equal(2. / 3, recall_13(average='micro'))
169-
170-
# # ensure the above were meaningful tests:
171-
# for average in ['macro', 'weighted', 'micro']:
172-
# assert_not_equal(recall_13(average=average),
173-
# recall_all(average=average))
174-
175-
176-
# @ignore_warnings
177-
# def test_precision_recall_fscore_support_errors():
178-
# y_true, y_pred, _ = make_prediction(binary=True)
179-
180-
# # Bad beta
181-
# assert_raises(ValueError, precision_recall_fscore_support,
182-
# y_true, y_pred, beta=0.0)
183-
184-
# # Bad pos_label
185-
# assert_raises(ValueError, precision_recall_fscore_support,
186-
# y_true, y_pred, pos_label=2, average='binary')
187-
188-
# # Bad average option
189-
# assert_raises(ValueError, precision_recall_fscore_support,
190-
# [0, 1, 2], [1, 2, 0], average='mega')
191-
192-
193-
# def test_precision_recall_f_unused_pos_label():
194-
# # Check warning that pos_label unused when set to non-default value
195-
# # but average != 'binary'; even if data is binary.
196-
# assert_warns_message(UserWarning,
197-
# "Note that pos_label (set to 2) is "
198-
# "ignored when average != 'binary' (got 'macro'). You "
199-
# "may use labels=[pos_label] to specify a single "
200-
# "positive class.", precision_recall_fscore_support,
201-
# [1, 2, 1], [1, 2, 2], pos_label=2, average='macro')
202-
203-
def test_precision_recall_f1_score_multiclass():
101+
def test_sensitivity_specificity_binary_single_class():
102+
# Test sensitivity and specificity score behave with a single positive or
103+
# negative class
104+
# Such a case may occur with non-stratified cross-validation
105+
assert_equal(1., sensitivity_score([1, 1], [1, 1]))
106+
assert_equal(0., specificity_score([1, 1], [1, 1]))
107+
108+
assert_equal(0., sensitivity_score([-1, -1], [-1, -1]))
109+
assert_equal(0., specificity_score([-1, -1], [-1, -1]))
110+
111+
112+
def test_sensitivity_specificity_error_multilabels():
113+
# Test either if an error is raised when the input are multilabels
114+
y_true = [1, 3, 3, 2]
115+
y_pred = [1, 1, 3, 2]
116+
y_true_bin = label_binarize(y_true, classes=np.arange(5))
117+
y_pred_bin = label_binarize(y_pred, classes=np.arange(5))
118+
119+
assert_raises(ValueError, sensitivity_score, y_true_bin, y_pred_bin)
120+
121+
@ignore_warnings
122+
def test_sensitivity_specifiicity_extra_labels():
123+
# Test handling of explicit additional (not in input) labels to SS
124+
y_true = [1, 3, 3, 2]
125+
y_pred = [1, 1, 3, 2]
126+
127+
actual = sensitivity_score(y_true, y_pred, labels=[0, 1, 2, 3, 4],
128+
average=None)
129+
assert_array_almost_equal([0., 1., 1., .5, 0.], actual)
130+
131+
# Macro average is changed
132+
actual = sensitivity_score(y_true, y_pred, labels=[0, 1, 2, 3, 4],
133+
average='macro')
134+
assert_array_almost_equal(np.mean([0., 1., 1., .5, 0.]), actual)
135+
136+
# Weighted average is changed
137+
assert_almost_equal(sensitivity_score(y_true, y_pred,
138+
labels=[0, 1, 2, 3, 4],
139+
average='weighted'),
140+
sensitivity_score(y_true, y_pred, labels=None,
141+
average='weighted'))
142+
143+
@ignore_warnings
144+
def test_sensitivity_specificity_f_ignored_labels():
145+
# Test a subset of labels may be requested for SS
146+
y_true = [1, 1, 2, 3]
147+
y_pred = [1, 3, 3, 3]
148+
149+
sensitivity_13 = partial(sensitivity_score, y_true, y_pred, labels=[1, 3])
150+
sensitivity_all = partial(sensitivity_score, y_true, y_pred, labels=None)
151+
152+
assert_array_almost_equal([.5, 1.], sensitivity_13(average=None))
153+
assert_almost_equal((.5 + 1.) / 2, sensitivity_13(average='macro'))
154+
assert_almost_equal((.5 * 2 + 1. * 1) / 3,
155+
sensitivity_13(average='weighted'))
156+
157+
# ensure the above were meaningful tests:
158+
for average in ['macro', 'weighted']:
159+
assert_not_equal(sensitivity_13(average=average),
160+
sensitivity_all(average=average))
161+
162+
163+
@ignore_warnings
164+
def test_sensitivity_specificity_support_errors():
165+
y_true, y_pred, _ = make_prediction(binary=True)
166+
167+
# Bad pos_label
168+
assert_raises(ValueError, sensitivity_specificity_support,
169+
y_true, y_pred, pos_label=2, average='binary')
170+
171+
# Bad average option
172+
assert_raises(ValueError, sensitivity_specificity_support,
173+
[0, 1, 2], [1, 2, 0], average='mega')
174+
175+
176+
def test_sensitivity_specificity_unused_pos_label():
177+
# Check warning that pos_label unused when set to non-default value
178+
# but average != 'binary'; even if data is binary.
179+
assert_warns_message(UserWarning,
180+
"Note that pos_label (set to 2) is "
181+
"ignored when average != 'binary' (got 'macro'). You "
182+
"may use labels=[pos_label] to specify a single "
183+
"positive class.", sensitivity_specificity_support,
184+
[1, 2, 1], [1, 2, 2], pos_label=2, average='macro')
185+
186+
187+
def test_sensitivity_specificity_multiclass():
204188
# Test Precision Recall and F1 Score for multiclass classification task
205189
y_true, y_pred, _ = make_prediction(binary=False)
206190

@@ -212,15 +196,6 @@ def test_precision_recall_f1_score_multiclass():
212196
assert_array_equal(supp, [24, 31, 20])
213197

214198
# averaging tests
215-
spec = specificity_score(y_true, y_pred, pos_label=1, average='micro')
216-
assert_array_almost_equal(spec, 0.77, 2)
217-
218-
sens = sensitivity_score(y_true, y_pred, average='micro')
219-
assert_array_almost_equal(sens, 0.53, 2)
220-
221-
spec = specificity_score(y_true, y_pred, average='macro')
222-
assert_array_almost_equal(spec, 0.77, 2)
223-
224199
sens = sensitivity_score(y_true, y_pred, average='macro')
225200
assert_array_almost_equal(sens, 0.60, 2)
226201

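Note: a quick check of the arithmetic asserted in test_sensitivity_specificity_f_ignored_labels. With y_true = [1, 1, 2, 3] and y_pred = [1, 3, 3, 3], class 1 recovers one of its two samples (sensitivity 0.5) and class 3 recovers its single sample (sensitivity 1.0), so the macro average is (0.5 + 1.0) / 2 = 0.75 while the support-weighted average is (0.5 * 2 + 1.0 * 1) / 3, roughly 0.833. The same numbers can be reproduced with plain NumPy; this is only an illustrative check, not part of the test file:

    import numpy as np

    per_class_sensitivity = np.array([0.5, 1.0])  # classes 1 and 3
    support = np.array([2, 1])                    # true samples per class

    macro = per_class_sensitivity.mean()                           # 0.75
    weighted = np.average(per_class_sensitivity, weights=support)  # ~0.833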