Commit ac3d0de
Author: Guillaume Lemaitre
Parent: 0eddca6

Added geometric mean

1 file changed: imblearn/metrics/classification.py (91 additions, 19 deletions)
@@ -1,3 +1,5 @@
+# coding: utf-8
+
 """Metrics to assess performance on classification task given class prediction
 
 Functions named as ``*_score`` return a scalar value to maximize: the higher
@@ -61,7 +63,7 @@ def sensitivity_specificity_support(y_true, y_pred, labels=None,
6163
6264
pos_label : str or int, optional (default=1)
6365
The class to report if ``average='binary'`` and the data is binary.
64-
If the data are multiclass or multilabel, this will be ignored;
66+
If the data are multiclass, this will be ignored;
6567
setting ``labels=[pos_label]`` and ``average != 'binary'`` will report
6668
scores for that label only.
6769
@@ -202,23 +204,13 @@ def sensitivity_specificity_support(y_true, y_pred, labels=None,
     LOGGER.debug('Computed the necessary stats for the sensitivity and'
                  ' specificity')
 
-    LOGGER.debug(tp_sum)
-    LOGGER.debug(tn_sum)
-    LOGGER.debug(fp_sum)
-    LOGGER.debug(fn_sum)
-
     # Compute the sensitivity and specificity
     with np.errstate(divide='ignore', invalid='ignore'):
         sensitivity = _prf_divide(tp_sum, tp_sum + fn_sum, 'sensitivity',
                                   'tp + fn', average, warn_for)
         specificity = _prf_divide(tn_sum, tn_sum + fp_sum, 'specificity',
                                   'tn + fp', average, warn_for)
 
-    # sensitivity = [_prf_divide(tp, tp + fn, 'sensitivity', 'tp + fn', average,
-    #                warn_for) for tp, fn in zip(tp_sum, fn_sum)]
-    # specificity = [_prf_divide(tn, tn + fp, 'specificity', 'tn + fp', average,
-    #                warn_for) for tn, fp in zip(tn_sum, fp_sum)]
-
     # If we need to weight the results
     if average == 'weighted':
         weights = support
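A note on the hunk above: ``_prf_divide`` performs the per-class divisions while controlling warnings for empty denominators. A minimal NumPy sketch of the same arithmetic, with made-up confusion-matrix sums (``np.nan_to_num`` stands in roughly for the zero-fill that ``_prf_divide`` applies; this is not the library's internal code):

import numpy as np

# Hypothetical per-class confusion-matrix sums for a 3-class problem.
tp_sum = np.array([50, 30, 10])  # true positives per class
fn_sum = np.array([5, 10, 2])    # false negatives per class
tn_sum = np.array([40, 55, 80])  # true negatives per class
fp_sum = np.array([2, 2, 5])     # false positives per class

# Same guard as in the diff: silence divide-by-zero warnings, then map
# the NaNs/infs produced by empty denominators to 0.0.
with np.errstate(divide='ignore', invalid='ignore'):
    sensitivity = tp_sum / (tp_sum + fn_sum)  # tp / (tp + fn)
    specificity = tn_sum / (tn_sum + fp_sum)  # tn / (tn + fp)
sensitivity = np.nan_to_num(sensitivity)
specificity = np.nan_to_num(specificity)

print(sensitivity)  # [0.909... 0.75     0.833...]
print(specificity)  # [0.952... 0.964... 0.941...]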
@@ -259,13 +251,11 @@ def sensitivity_score(y_true, y_pred, labels=None, pos_label=1,
         order if ``average is None``. Labels present in the data can be
         excluded, for example to calculate a multiclass average ignoring a
         majority negative class, while labels not present in the data will
-        result in 0 components in a macro average. For multilabel targets,
-        labels are column indices. By default, all labels in ``y_true`` and
-        ``y_pred`` are used in sorted order.
+        result in 0 components in a macro average.
 
     pos_label : str or int, optional (default=1)
         The class to report if ``average='binary'`` and the data is binary.
-        If the data are multiclass or multilabel, this will be ignored;
+        If the data are multiclass, this will be ignored;
         setting ``labels=[pos_label]`` and ``average != 'binary'`` will report
         scores for that label only.
 
@@ -331,13 +321,11 @@ def specificity_score(y_true, y_pred, labels=None, pos_label=1,
         order if ``average is None``. Labels present in the data can be
         excluded, for example to calculate a multiclass average ignoring a
         majority negative class, while labels not present in the data will
-        result in 0 components in a macro average. For multilabel targets,
-        labels are column indices. By default, all labels in ``y_true`` and
-        ``y_pred`` are used in sorted order.
+        result in 0 components in a macro average.
 
     pos_label : str or int, optional (default=1)
         The class to report if ``average='binary'`` and the data is binary.
-        If the data are multiclass or multilabel, this will be ignored;
+        If the data are multiclass, this will be ignored;
         setting ``labels=[pos_label]`` and ``average != 'binary'`` will report
         scores for that label only.
 
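As a usage sketch of the ``labels`` semantics described in the two docstring hunks above (assuming a development install where ``imblearn.metrics`` exposes ``sensitivity_score``; the data is made up):

import numpy as np
from imblearn.metrics import sensitivity_score

y_true = np.array([0, 0, 0, 0, 0, 0, 1, 1, 2, 2])
y_pred = np.array([0, 0, 0, 0, 0, 1, 1, 1, 2, 0])

# Macro average over every class, majority class 0 included.
print(sensitivity_score(y_true, y_pred, average='macro'))
# (5/6 + 2/2 + 1/2) / 3 ~= 0.78

# Exclude the majority class 0 from the average via ``labels``.
print(sensitivity_score(y_true, y_pred, labels=[1, 2], average='macro'))
# (2/2 + 1/2) / 2 = 0.75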
@@ -377,3 +365,86 @@ def specificity_score(y_true, y_pred, labels=None, pos_label=1,
                                               sample_weight=sample_weight)
 
     return s
+
+
+def geometric_mean_score(y_true, y_pred, labels=None, pos_label=1,
+                         average='binary', sample_weight=None):
+    """Compute the geometric mean
+
+    The geometric mean is the square root of the product of the sensitivity
+    and specificity. This measure tries to maximize the accuracy on each
+    of the two classes while keeping these accuracies balanced.
+
+    The sensitivity is the ratio ``tp / (tp + fn)`` where ``tp`` is the
+    number of true positives and ``fn`` the number of false negatives. The
+    sensitivity is intuitively the ability of the classifier to find all the
+    positive samples. The specificity is the ratio ``tn / (tn + fp)`` where
+    ``tn`` is the number of true negatives and ``fp`` the number of false
+    positives. The specificity is intuitively the ability of the classifier
+    to find all the negative samples.
+
+    The best value is 1 and the worst value is 0.
+
+    Parameters
+    ----------
+    y_true : ndarray, shape (n_samples, )
+        Ground truth (correct) target values.
+
+    y_pred : ndarray, shape (n_samples, )
+        Estimated targets as returned by a classifier.
+
+    labels : list, optional
+        The set of labels to include when ``average != 'binary'``, and their
+        order if ``average is None``. Labels present in the data can be
+        excluded, for example to calculate a multiclass average ignoring a
+        majority negative class, while labels not present in the data will
+        result in 0 components in a macro average.
+
+    pos_label : str or int, optional (default=1)
+        The class to report if ``average='binary'`` and the data is binary.
+        If the data are multiclass, this will be ignored;
+        setting ``labels=[pos_label]`` and ``average != 'binary'`` will report
+        scores for that label only.
+
+    average : str or None, optional (default='binary')
+        If ``None``, the scores for each class are returned. Otherwise, this
+        determines the type of averaging performed on the data:
+
+        ``'binary'``:
+            Only report results for the class specified by ``pos_label``.
+            This is applicable only if targets (``y_{true,pred}``) are binary.
+        ``'macro'``:
+            Calculate metrics for each label, and find their unweighted
+            mean. This does not take label imbalance into account.
+        ``'weighted'``:
+            Calculate metrics for each label, and find their average, weighted
+            by support (the number of true instances for each label). This
+            alters 'macro' to account for label imbalance.
+
+    sample_weight : ndarray, shape (n_samples, )
+        Sample weights.
+
+    Returns
+    -------
+    geometric_mean : float (if ``average`` is not None) or ndarray, \
+        shape (n_unique_labels, )
+
+    References
+    ----------
+    .. [1] Kubat, M. and Matwin, S. "Addressing the curse of
+       imbalanced training sets: one-sided selection" ICML (1997)
+
+    .. [2] Barandela, R., Sánchez, J. S., García, V., & Rangel, E. "Strategies
+       for learning in class imbalance problems", Pattern Recognition,
+       36(3), (2003), pp 849-851.
+
+    """
+    sen, spe, _ = sensitivity_specificity_support(y_true, y_pred,
+                                                  labels=labels,
+                                                  pos_label=pos_label,
+                                                  average=average,
+                                                  warn_for=('sensitivity',
+                                                            'specificity'),
+                                                  sample_weight=sample_weight)
+
+    return np.sqrt(sen * spe)
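A quick sanity check of the new metric on binary data, verifying it against sensitivity and specificity computed separately (a sketch assuming this development version is installed and re-exported from ``imblearn.metrics``; expected values worked out by hand):

import numpy as np
from imblearn.metrics import (geometric_mean_score, sensitivity_score,
                              specificity_score)

y_true = np.array([0, 0, 0, 0, 1, 1])
y_pred = np.array([0, 0, 1, 1, 1, 1])

sen = sensitivity_score(y_true, y_pred)  # tp / (tp + fn) = 2/2 = 1.0
spe = specificity_score(y_true, y_pred)  # tn / (tn + fp) = 2/4 = 0.5
print(np.sqrt(sen * spe))                    # 0.7071...
print(geometric_mean_score(y_true, y_pred))  # should match: sqrt(1.0 * 0.5)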
