Skip to content

Commit c8cb6d6

Browse files
author
Guillaume Lemaitre
committed
Update the test for the specificity
1 parent 965c5a1 commit c8cb6d6

File tree

2 files changed

+149
-157
lines changed

2 files changed

+149
-157
lines changed

imblearn/metrics/classification.py

Lines changed: 69 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
# coding: utf-8
2-
32
"""Metrics to assess performance on classification task given class prediction
43
54
Functions named as ``*_score`` return a scalar value to maximize: the higher
@@ -20,12 +19,16 @@
2019
from sklearn.preprocessing import LabelEncoder
2120
from sklearn.utils.fixes import bincount
2221
from sklearn.utils.multiclass import unique_labels
22+
from sklearn.utils.sparsefuncs import count_nonzero
2323

2424
LOGGER = logging.getLogger(__name__)
2525

2626

27-
def sensitivity_specificity_support(y_true, y_pred, labels=None,
28-
pos_label=1, average=None,
27+
def sensitivity_specificity_support(y_true,
28+
y_pred,
29+
labels=None,
30+
pos_label=1,
31+
average=None,
2932
warn_for=('sensitivity', 'specificity'),
3033
sample_weight=None):
3134
"""Compute sensitivity, specificity, and support for each class
@@ -116,8 +119,7 @@ def sensitivity_specificity_support(y_true, y_pred, labels=None,
116119
"""
117120
average_options = (None, 'micro', 'macro', 'weighted', 'samples')
118121
if average not in average_options and average != 'binary':
119-
raise ValueError('average has to be one of ' +
120-
str(average_options))
122+
raise ValueError('average has to be one of ' + str(average_options))
121123

122124
y_type, y_true, y_pred = _check_targets(y_true, y_pred)
123125
present_labels = unique_labels(y_true, y_pred)
@@ -146,38 +148,14 @@ def sensitivity_specificity_support(y_true, y_pred, labels=None,
146148
n_labels = None
147149
else:
148150
n_labels = len(labels)
149-
labels = np.hstack([labels, np.setdiff1d(present_labels, labels,
150-
assume_unique=True)])
151+
labels = np.hstack(
152+
[labels, np.setdiff1d(
153+
present_labels, labels, assume_unique=True)])
151154

152155
# Calculate tp_sum, pred_sum, true_sum ###
153156

154157
if y_type.startswith('multilabel'):
155-
sum_axis = 1 if average == 'samples' else 0
156-
157-
# All labels are index integers for multilabel.
158-
# Select labels:
159-
if not np.all(labels == present_labels):
160-
if np.max(labels) > np.max(present_labels):
161-
raise ValueError('All labels must be in [0, n labels). '
162-
'Got %d > %d' %
163-
(np.max(labels), np.max(present_labels)))
164-
if np.min(labels) < 0:
165-
raise ValueError('All labels must be in [0, n labels). '
166-
'Got %d < 0' % np.min(labels))
167-
168-
y_true = y_true[:, labels[:n_labels]]
169-
y_pred = y_pred[:, labels[:n_labels]]
170-
171-
# calculate weighted counts
172-
true_and_pred = y_true.multiply(y_pred)
173-
tp_sum = count_nonzero(true_and_pred, axis=sum_axis,
174-
sample_weight=sample_weight)
175-
pred_sum = count_nonzero(y_pred, axis=sum_axis,
176-
sample_weight=sample_weight)
177-
true_sum = count_nonzero(y_true, axis=sum_axis,
178-
sample_weight=sample_weight)
179-
tn_sum = y_true.size - (pred_sum + true_sum - tp_sum)
180-
158+
raise ValueError('imblearn does not support multilabel')
181159
elif average == 'samples':
182160
raise ValueError("Sample-based precision, recall, fscore is "
183161
"not meaningful outside multilabel "
@@ -198,17 +176,17 @@ def sensitivity_specificity_support(y_true, y_pred, labels=None,
198176
tp_bins_weights = None
199177

200178
if len(tp_bins):
201-
tp_sum = bincount(tp_bins, weights=tp_bins_weights,
202-
minlength=len(labels))
179+
tp_sum = bincount(
180+
tp_bins, weights=tp_bins_weights, minlength=len(labels))
203181
else:
204182
# Pathological case
205183
true_sum = pred_sum = tp_sum = np.zeros(len(labels))
206184
if len(y_pred):
207-
pred_sum = bincount(y_pred, weights=sample_weight,
208-
minlength=len(labels))
185+
pred_sum = bincount(
186+
y_pred, weights=sample_weight, minlength=len(labels))
209187
if len(y_true):
210-
true_sum = bincount(y_true, weights=sample_weight,
211-
minlength=len(labels))
188+
true_sum = bincount(
189+
y_true, weights=sample_weight, minlength=len(labels))
212190

213191
# Compute the true negative
214192
tn_sum = y_true.size - (pred_sum + true_sum - tp_sum)
@@ -220,6 +198,11 @@ def sensitivity_specificity_support(y_true, y_pred, labels=None,
220198
pred_sum = pred_sum[indices]
221199
tn_sum = tn_sum[indices]
222200

201+
LOGGER.debug('tp: %s' % tp_sum)
202+
LOGGER.debug('tn: %s' % tn_sum)
203+
LOGGER.debug('pred_sum: %s' % pred_sum)
204+
LOGGER.debug('true_sum: %s' % true_sum)
205+
223206
if average == 'micro':
224207
tp_sum = np.array([tp_sum.sum()])
225208
pred_sum = np.array([pred_sum.sum()])
@@ -236,8 +219,8 @@ def sensitivity_specificity_support(y_true, y_pred, labels=None,
236219
specificity = _prf_divide(tn_sum, tn_sum + pred_sum - tp_sum,
237220
'specificity', 'predicted', average,
238221
warn_for)
239-
sensitivity = _prf_divide(tp_sum, true_sum,
240-
'sensitivity', 'true', average, warn_for)
222+
sensitivity = _prf_divide(tp_sum, true_sum, 'sensitivity', 'true',
223+
average, warn_for)
241224

242225
# Average the results
243226

@@ -250,6 +233,9 @@ def sensitivity_specificity_support(y_true, y_pred, labels=None,
250233
else:
251234
weights = None
252235

236+
LOGGER.debug(specificity)
237+
LOGGER.debug(weights)
238+
253239
if average is not None:
254240
assert average != 'binary' or len(specificity) == 1
255241
specificity = np.average(specificity, weights=weights)
@@ -259,8 +245,12 @@ def sensitivity_specificity_support(y_true, y_pred, labels=None,
259245
return sensitivity, specificity, true_sum
260246

261247

262-
def sensitivity_score(y_true, y_pred, labels=None, pos_label=1,
263-
average='binary', sample_weight=None):
248+
def sensitivity_score(y_true,
249+
y_pred,
250+
labels=None,
251+
pos_label=1,
252+
average='binary',
253+
sample_weight=None):
264254
"""Compute the sensitivity
265255
266256
The sensitivity is the ratio ``tp / (tp + fn)`` where ``tp`` is the number
@@ -326,18 +316,24 @@ def sensitivity_score(y_true, y_pred, labels=None, pos_label=1,
326316
shape (n_unique_labels, )
327317
328318
"""
329-
s, _, _ = sensitivity_specificity_support(y_true, y_pred,
330-
labels=labels,
331-
pos_label=pos_label,
332-
average=average,
333-
warn_for=('sensitivity',),
334-
sample_weight=sample_weight)
319+
s, _, _ = sensitivity_specificity_support(
320+
y_true,
321+
y_pred,
322+
labels=labels,
323+
pos_label=pos_label,
324+
average=average,
325+
warn_for=('sensitivity', ),
326+
sample_weight=sample_weight)
335327

336328
return s
337329

338330

339-
def specificity_score(y_true, y_pred, labels=None, pos_label=1,
340-
average='binary', sample_weight=None):
331+
def specificity_score(y_true,
332+
y_pred,
333+
labels=None,
334+
pos_label=1,
335+
average='binary',
336+
sample_weight=None):
341337
"""Compute the specificity
342338
343339
The specificity is the ratio ``tn / (tn + fp)`` where ``tn`` is the number
@@ -404,18 +400,24 @@ def specificity_score(y_true, y_pred, labels=None, pos_label=1,
404400
shape (n_unique_labels, )
405401
406402
"""
407-
_, s, _ = sensitivity_specificity_support(y_true, y_pred,
408-
labels=labels,
409-
pos_label=pos_label,
410-
average=average,
411-
warn_for=('specificity',),
412-
sample_weight=sample_weight)
403+
_, s, _ = sensitivity_specificity_support(
404+
y_true,
405+
y_pred,
406+
labels=labels,
407+
pos_label=pos_label,
408+
average=average,
409+
warn_for=('specificity', ),
410+
sample_weight=sample_weight)
413411

414412
return s
415413

416414

417-
def geometric_mean_score(y_true, y_pred, labels=None, pos_label=1,
418-
average='binary', sample_weight=None):
415+
def geometric_mean_score(y_true,
416+
y_pred,
417+
labels=None,
418+
pos_label=1,
419+
average='binary',
420+
sample_weight=None):
419421
"""Compute the geometric mean
420422
421423
The geometric mean is the square root of the product of the sensitivity
@@ -495,13 +497,14 @@ def geometric_mean_score(y_true, y_pred, labels=None, pos_label=1,
495497
36(3), (2003), pp 849-851.
496498
497499
"""
498-
sen, spe, _ = sensitivity_specificity_support(y_true, y_pred,
499-
labels=labels,
500-
pos_label=pos_label,
501-
average=average,
502-
warn_for=('specificity',
503-
'specificity'),
504-
sample_weight=sample_weight)
500+
sen, spe, _ = sensitivity_specificity_support(
501+
y_true,
502+
y_pred,
503+
labels=labels,
504+
pos_label=pos_label,
505+
average=average,
506+
warn_for=('sensitivity', 'specificity'),
507+
sample_weight=sample_weight)
505508

506509
LOGGER.debug('The sensitivity and specificity are : %s - %s' % (sen, spe))
507510

0 commit comments

Comments
 (0)