Commit 2085360

Author: Guillaume Lemaitre
Finish the non-failure test
1 parent d28f450 commit 2085360

3 files changed (+269, -54 lines)


imblearn/metrics/__init__.py

Lines changed: 5 additions & 1 deletion
@@ -4,7 +4,11 @@
 """

 from .classification import sensitivity_specificity_support
+from .classification import sensitivity_score
+from .classification import specificity_score

 __all__ = [
-    'sensitivity_specificity_support'
+    'sensitivity_specificity_support',
+    'sensitivity_score',
+    'specificity_score'
 ]
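
With these exports in place, both scorers become part of the public `imblearn.metrics` namespace. A minimal sketch of the resulting import path (the assertion only restates the names added to ``__all__`` above; nothing here is taken from the commit beyond those names):

# The three names listed in __all__ should now resolve directly from the
# metrics subpackage.
from imblearn.metrics import (sensitivity_specificity_support,
                              sensitivity_score, specificity_score)

import imblearn.metrics as metrics
assert {'sensitivity_specificity_support',
        'sensitivity_score',
        'specificity_score'} <= set(metrics.__all__)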

imblearn/metrics/classification.py

Lines changed: 210 additions & 36 deletions
@@ -14,7 +14,7 @@

 import numpy as np

-from sklearn.metrics.classification import (_check_targets, _prf_divide)
+from sklearn.metrics.classification import _check_targets, _prf_divide
 from sklearn.preprocessing import LabelEncoder
 from sklearn.utils.fixes import bincount
 from sklearn.utils.multiclass import unique_labels
@@ -44,10 +44,10 @@ def sensitivity_specificity_support(y_true, y_pred, labels=None,

     Parameters
     ----------
-    y_true : 1d array-like, or label indicator array / sparse matrix
+    y_true : ndarray, shape (n_samples, )
         Ground truth (correct) target values.

-    y_pred : 1d array-like, or label indicator array / sparse matrix
+    y_pred : ndarray, shape (n_samples, )
         Estimated targets as returned by a classifier.

     labels : list, optional
@@ -59,13 +59,13 @@ def sensitivity_specificity_support(y_true, y_pred, labels=None,
         labels are column indices. By default, all labels in ``y_true`` and
         ``y_pred`` are used in sorted order.

-    pos_label : str or int, 1 by default
+    pos_label : str or int, optional (default=1)
         The class to report if ``average='binary'`` and the data is binary.
         If the data are multiclass or multilabel, this will be ignored;
         setting ``labels=[pos_label]`` and ``average != 'binary'`` will report
         scores for that label only.

-    average : string, [None (default), 'binary', 'macro', 'weighted']
+    average : str or None, optional (default=None)
         If ``None``, the scores for each class are returned. Otherwise, this
         determines the type of averaging performed on the data:

@@ -84,16 +84,19 @@ def sensitivity_specificity_support(y_true, y_pred, labels=None,
         This determines which warnings will be made in the case that this
         function is being used to return only one of its metrics.

+    sample_weight : ndarray, shape (n_samples, )
+        Sample weights.
+
     Returns
     -------
     sensitivity : float (if ``average`` = None) or ndarray, \
-        shape(n_unique_labels,)
+        shape (n_unique_labels, )

     specificity : float (if ``average`` = None) or ndarray, \
-        shape(n_unique_labels,)
+        shape (n_unique_labels, )

     support : int (if ``average`` = None) or ndarray, \
-        shape(n_unique_labels,)
+        shape (n_unique_labels, )
         The number of occurrences of each label in ``y_true``.

     References
@@ -151,6 +154,12 @@ def sensitivity_specificity_support(y_true, y_pred, labels=None,
     y_pred = le.transform(y_pred)
     sorted_labels = le.classes_

+    LOGGER.debug(y_true)
+    LOGGER.debug(y_pred)
+    LOGGER.debug(sorted_labels)
+
+    LOGGER.debug('The number of labels is %s' % n_labels)
+
     # In a leave out strategy and for each label, compute:
     # TP, TN, FP, FN
     # These list contain an array in which each sample is labeled as
@@ -159,53 +168,75 @@ def sensitivity_specificity_support(y_true, y_pred, labels=None,
                for label in sorted_labels]
     list_tn = [np.bitwise_and((y_true != label), (y_pred != label))
                for label in sorted_labels]
-    list_fp = [np.bitwise_and((y_true == label), (y_pred != label))
+    list_fp = [np.bitwise_and((y_true != label), (y_pred == label))
                for label in sorted_labels]
-    list_fn = [np.bitwise_and((y_true != label), (y_pred == label))
+    list_fn = [np.bitwise_and((y_true == label), (y_pred != label))
                for label in sorted_labels]

-    LOGGER.debug(list_tp)
-    LOGGER.debug(list_tn)
-    LOGGER.debug(list_fn)
-    LOGGER.debug(list_fn)
-
     # Compute the sum for each type
-    tp_sum = [bincount(tp, weights=sample_weight, minlength=len(labels))
-              for tp in list_tp]
-    tn_sum = [bincount(tn, weights=sample_weight, minlength=len(labels))
-              for tn in list_tn]
-    fp_sum = [bincount(fp, weights=sample_weight, minlength=len(labels))
-              for fp in list_fp]
-    fn_sum = [bincount(fn, weights=sample_weight, minlength=len(labels))
-              for fn in list_fn]
-
-    LOGGER.debug(tp_sum)
-    LOGGER.debug(tn_sum)
-    LOGGER.debug(fp_sum)
-    LOGGER.debug(fn_sum)
+    # We keep only the counting corresponding to True values
+    # We are using bincount since it allows to weight the samples
+    tp_sum = np.array([bincount(tp, weights=sample_weight,
+                                minlength=2)[-1]
+                       for tp in list_tp])
+    tn_sum = np.array([bincount(tn, weights=sample_weight,
+                                minlength=2)[-1]
+                       for tn in list_tn])
+    fp_sum = np.array([bincount(fp, weights=sample_weight,
+                                minlength=2)[-1]
+                       for fp in list_fp])
+    fn_sum = np.array([bincount(fn, weights=sample_weight,
+                                minlength=2)[-1]
+                       for fn in list_fn])

     # Retain only selected labels
     indices = np.searchsorted(sorted_labels, labels[:n_labels])
-    tp_sum = [tp[indices] for tp in tp_sum]
-    tn_sum = [tn[indices] for tn in tn_sum]
-    fp_sum = [fp[indices] for fp in fp_sum]
-    fn_sum = [fn[indices] for fn in fn_sum]
+    # For support, we can count the number of occurrences of each label
+    support = np.array(bincount(y_true, weights=sample_weight,
+                                minlength=len(labels)))
+    # Sort the support
+    support = support[indices]
+
+
+    LOGGER.debug('The indices which are retained are %s' % indices)
+
+    LOGGER.debug('TP: %s' % tp_sum)
+    LOGGER.debug('TN: %s' % tn_sum)
+    LOGGER.debug('FP: %s' % fp_sum)
+    LOGGER.debug('FN: %s' % fn_sum)

-    LOGGER.debug('Computed for each label the stats')
+    tp_sum = tp_sum[indices]
+    tn_sum = tn_sum[indices]
+    fp_sum = fp_sum[indices]
+    fn_sum = fn_sum[indices]
+
+    if average == 'micro':
+        tp_sum = np.array([tp_sum.sum()])
+        tn_sum = np.array([tn_sum.sum()])
+        fp_sum = np.array([fp_sum.sum()])
+        fn_sum = np.array([fn_sum.sum()])
+
+    LOGGER.debug('Did we do the average micro %s' % tp_sum)
+
+    LOGGER.debug('Computed the necessary stats for the sensitivity and'
+                 ' specificity')

     # Compute the sensitivity and specificity
     sensitivity = [_prf_divide(tp, tp + fn, 'sensitivity', 'tp + fn', average,
                                warn_for) for tp, fn in zip(tp_sum, fn_sum)]
     specificity = [_prf_divide(tn, tn + fp, 'specificity', 'tn + fp', average,
                                warn_for) for tn, fp in zip(tn_sum, fp_sum)]

+    LOGGER.debug('Sensitivity = %s - Specificity = %s' % (sensitivity,
+                                                          specificity))
+
     LOGGER.debug('Computed the sensitivity and specificity for each class')
     LOGGER.debug('The lengths of those two metrics are: %s - %s',
                  len(sensitivity), len(specificity))

     # If we need to weight the results
     if average == 'weighted':
-        weights = tp_sum
+        weights = support
         if weights.sum() == 0:
             return 0, 0, None
         else:
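
To make the counting scheme in the hunk above concrete: each label is treated one-vs-rest with boolean masks, and a length-2 ``bincount`` keeps only the count of ``True`` entries (the last bin), which is what allows sample weights to be applied. The following standalone sketch uses invented labels and plain ``np.bincount`` rather than the code from this commit; the mask is cast to integers before counting:

import numpy as np

# Toy multiclass labels, invented purely to illustrate the counting scheme.
y_true = np.array([0, 0, 1, 2, 2, 2])
y_pred = np.array([0, 1, 1, 2, 2, 0])

for label in np.unique(y_true):
    # One-vs-rest boolean masks, mirroring the list_tp/list_tn/list_fp/list_fn
    # comprehensions in the hunk above.
    tp = np.bitwise_and(y_true == label, y_pred == label)
    tn = np.bitwise_and(y_true != label, y_pred != label)
    fp = np.bitwise_and(y_true != label, y_pred == label)
    fn = np.bitwise_and(y_true == label, y_pred != label)
    # bincount(..., minlength=2) yields [n_False, n_True]; taking [-1]
    # keeps only the count of True values.
    masks = {'tp': tp, 'tn': tn, 'fp': fp, 'fn': fn}
    counts = {name: int(np.bincount(mask.astype(np.int64), minlength=2)[-1])
              for name, mask in masks.items()}
    print(label, counts)
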
@@ -215,6 +246,149 @@ def sensitivity_specificity_support(y_true, y_pred, labels=None,
         assert average != 'binary' or len(sensitivity) == 1
         sensitivity = np.average(sensitivity, weights=weights)
         specificity = np.average(specificity, weights=weights)
-        tp_sum = None
+        support = None
+
+    return sensitivity, specificity, support
+
+
+def sensitivity_score(y_true, y_pred, labels=None, pos_label=1,
+                      average='binary', sample_weight=None):
+    """Compute the sensitivity
+
+    The sensitivity is the ratio ``tp / (tp + fn)`` where ``tp`` is the number
+    of true positives and ``fn`` the number of false negatives. The
+    sensitivity quantifies the ability to avoid false negatives.
+
+    The best value is 1 and the worst value is 0.
+
+    Parameters
+    ----------
+    y_true : ndarray, shape (n_samples, )
+        Ground truth (correct) target values.
+
+    y_pred : ndarray, shape (n_samples, )
+        Estimated targets as returned by a classifier.
+
+    labels : list, optional
+        The set of labels to include when ``average != 'binary'``, and their
+        order if ``average is None``. Labels present in the data can be
+        excluded, for example to calculate a multiclass average ignoring a
+        majority negative class, while labels not present in the data will
+        result in 0 components in a macro average. For multilabel targets,
+        labels are column indices. By default, all labels in ``y_true`` and
+        ``y_pred`` are used in sorted order.
+
+    pos_label : str or int, optional (default=1)
+        The class to report if ``average='binary'`` and the data is binary.
+        If the data are multiclass or multilabel, this will be ignored;
+        setting ``labels=[pos_label]`` and ``average != 'binary'`` will report
+        scores for that label only.
+
+    average : str or None, optional (default='binary')
+        If ``None``, the scores for each class are returned. Otherwise, this
+        determines the type of averaging performed on the data:
+
+        ``'binary'``:
+            Only report results for the class specified by ``pos_label``.
+            This is applicable only if targets (``y_{true,pred}``) are binary.
+        ``'macro'``:
+            Calculate metrics for each label, and find their unweighted
+            mean. This does not take label imbalance into account.
+        ``'weighted'``:
+            Calculate metrics for each label, and find their average, weighted
+            by support (the number of true instances for each label). This
+            alters 'macro' to account for label imbalance.
+
+    warn_for : tuple or set, for internal use
+        This determines which warnings will be made in the case that this
+        function is being used to return only one of its metrics.

-    return sensitivity, specificity, tp_sum
+    sample_weight : ndarray, shape (n_samples, )
+        Sample weights.
+
+    Returns
+    -------
+    sensitivity : float (if ``average`` = None) or ndarray, \
+        shape (n_unique_labels, )
+
+    """
+    s, _, _ = sensitivity_specificity_support(y_true, y_pred,
+                                              labels=labels,
+                                              pos_label=pos_label,
+                                              average=average,
+                                              warn_for=('sensitivity',),
+                                              sample_weight=sample_weight)
+
+    return s
+
+
+def specificity_score(y_true, y_pred, labels=None, pos_label=1,
+                      average='binary', sample_weight=None):
+    """Compute the specificity
+
+    The specificity is the ratio ``tn / (tn + fp)`` where ``tn`` is the number
+    of true negatives and ``fp`` the number of false positives. The
+    specificity is intuitively the ability of the classifier to find all the
+    negative samples.
+
+    The best value is 1 and the worst value is 0.
+
+    Parameters
+    ----------
+    y_true : ndarray, shape (n_samples, )
+        Ground truth (correct) target values.
+
+    y_pred : ndarray, shape (n_samples, )
+        Estimated targets as returned by a classifier.
+
+    labels : list, optional
+        The set of labels to include when ``average != 'binary'``, and their
+        order if ``average is None``. Labels present in the data can be
+        excluded, for example to calculate a multiclass average ignoring a
+        majority negative class, while labels not present in the data will
+        result in 0 components in a macro average. For multilabel targets,
+        labels are column indices. By default, all labels in ``y_true`` and
+        ``y_pred`` are used in sorted order.
+
+    pos_label : str or int, optional (default=1)
+        The class to report if ``average='binary'`` and the data is binary.
+        If the data are multiclass or multilabel, this will be ignored;
+        setting ``labels=[pos_label]`` and ``average != 'binary'`` will report
+        scores for that label only.
+
+    average : str or None, optional (default='binary')
+        If ``None``, the scores for each class are returned. Otherwise, this
+        determines the type of averaging performed on the data:
+
+        ``'binary'``:
+            Only report results for the class specified by ``pos_label``.
+            This is applicable only if targets (``y_{true,pred}``) are binary.
+        ``'macro'``:
+            Calculate metrics for each label, and find their unweighted
+            mean. This does not take label imbalance into account.
+        ``'weighted'``:
+            Calculate metrics for each label, and find their average, weighted
+            by support (the number of true instances for each label). This
+            alters 'macro' to account for label imbalance.
+
+    warn_for : tuple or set, for internal use
+        This determines which warnings will be made in the case that this
+        function is being used to return only one of its metrics.
+
+    sample_weight : ndarray, shape (n_samples, )
+        Sample weights.
+
+    Returns
+    -------
+    specificity : float (if ``average`` = None) or ndarray, \
+        shape (n_unique_labels, )
+
+    """
+    _, s, _ = sensitivity_specificity_support(y_true, y_pred,
+                                              labels=labels,
+                                              pos_label=pos_label,
+                                              average=average,
+                                              warn_for=('specificity',),
+                                              sample_weight=sample_weight)
+
+    return s
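
For completeness, a hedged usage sketch of the two new public scorers. The labels are invented, and the values given in the comments are only what the ``tp / (tp + fn)`` and ``tn / (tn + fp)`` definitions imply for them, not outputs verified against this commit:

import numpy as np

from imblearn.metrics import sensitivity_score, specificity_score

# Invented, imbalanced binary toy problem; class 1 is the positive class.
y_true = np.array([0, 0, 0, 0, 1, 1])
y_pred = np.array([0, 0, 0, 1, 1, 0])

# With the default average='binary' only the positive class is reported.
# From the definitions: sensitivity = tp / (tp + fn) = 1 / 2 and
# specificity = tn / (tn + fp) = 3 / 4.
print(sensitivity_score(y_true, y_pred, pos_label=1))
print(specificity_score(y_true, y_pred, pos_label=1))

# average=None is documented to return one value per label in sorted order.
print(sensitivity_score(y_true, y_pred, average=None))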
