|
| 1 | +# coding: utf-8 |
| 2 | + |
1 | 3 | """Metrics to assess performance on classification task given class prediction
|
2 | 4 |
|
3 | 5 | Functions named as ``*_score`` return a scalar value to maximize: the higher
|
@@ -61,7 +63,7 @@ def sensitivity_specificity_support(y_true, y_pred, labels=None,
|
61 | 63 |
|
62 | 64 | pos_label : str or int, optional (default=1)
|
63 | 65 | The class to report if ``average='binary'`` and the data is binary.
|
64 |
| - If the data are multiclass or multilabel, this will be ignored; |
| 66 | + If the data are multiclass, this will be ignored; |
65 | 67 | setting ``labels=[pos_label]`` and ``average != 'binary'`` will report
|
66 | 68 | scores for that label only.
|
67 | 69 |
|
@@ -202,23 +204,13 @@ def sensitivity_specificity_support(y_true, y_pred, labels=None,
|
202 | 204 | LOGGER.debug('Computed the necessary stats for the sensitivity and'
|
203 | 205 | ' specificity')
|
204 | 206 |
|
205 |
| - LOGGER.debug(tp_sum) |
206 |
| - LOGGER.debug(tn_sum) |
207 |
| - LOGGER.debug(fp_sum) |
208 |
| - LOGGER.debug(fn_sum) |
209 |
| - |
210 | 207 | # Compute the sensitivity and specificity
|
211 | 208 | with np.errstate(divide='ignore', invalid='ignore'):
|
212 | 209 | sensitivity = _prf_divide(tp_sum, tp_sum + fn_sum, 'sensitivity',
|
213 | 210 | 'tp + fn', average, warn_for)
|
214 | 211 | specificity = _prf_divide(tn_sum, tn_sum + fp_sum, 'specificity',
|
215 | 212 | 'tn + fp', average, warn_for)
|
216 | 213 |
|
217 |
| - # sensitivity = [_prf_divide(tp, tp + fn, 'sensitivity', 'tp + fn', average, |
218 |
| - # warn_for) for tp, fn in zip(tp_sum, fn_sum)] |
219 |
| - # specificity = [_prf_divide(tn, tn + fp, 'specificity', 'tn + fp', average, |
220 |
| - # warn_for) for tn, fp in zip(tn_sum, fp_sum)] |
221 |
| - |
222 | 214 | # If we need to weight the results
|
223 | 215 | if average == 'weighted':
|
224 | 216 | weights = support
|
@@ -259,13 +251,11 @@ def sensitivity_score(y_true, y_pred, labels=None, pos_label=1,
|
259 | 251 | order if ``average is None``. Labels present in the data can be
|
260 | 252 | excluded, for example to calculate a multiclass average ignoring a
|
261 | 253 | majority negative class, while labels not present in the data will
|
262 |
| - result in 0 components in a macro average. For multilabel targets, |
263 |
| - labels are column indices. By default, all labels in ``y_true`` and |
264 |
| - ``y_pred`` are used in sorted order. |
| 254 | + result in 0 components in a macro average. |
265 | 255 |
|
266 | 256 | pos_label : str or int, optional (default=1)
|
267 | 257 | The class to report if ``average='binary'`` and the data is binary.
|
268 |
| - If the data are multiclass or multilabel, this will be ignored; |
| 258 | + If the data are multiclass, this will be ignored; |
269 | 259 | setting ``labels=[pos_label]`` and ``average != 'binary'`` will report
|
270 | 260 | scores for that label only.
|
271 | 261 |
|
@@ -331,13 +321,11 @@ def specificity_score(y_true, y_pred, labels=None, pos_label=1,
|
331 | 321 | order if ``average is None``. Labels present in the data can be
|
332 | 322 | excluded, for example to calculate a multiclass average ignoring a
|
333 | 323 | majority negative class, while labels not present in the data will
|
334 |
| - result in 0 components in a macro average. For multilabel targets, |
335 |
| - labels are column indices. By default, all labels in ``y_true`` and |
336 |
| - ``y_pred`` are used in sorted order. |
| 324 | + result in 0 components in a macro average. |
337 | 325 |
|
338 | 326 | pos_label : str or int, optional (default=1)
|
339 | 327 | The class to report if ``average='binary'`` and the data is binary.
|
340 |
| - If the data are multiclass or multilabel, this will be ignored; |
| 328 | + If the data are multiclass, this will be ignored; |
341 | 329 | setting ``labels=[pos_label]`` and ``average != 'binary'`` will report
|
342 | 330 | scores for that label only.
|
343 | 331 |
|
@@ -377,3 +365,87 @@ def specificity_score(y_true, y_pred, labels=None, pos_label=1,
|
377 | 365 | sample_weight=sample_weight)
|
378 | 366 |
|
379 | 367 | return s
|
| 368 | + |
| 369 | + |
def geometric_mean_score(y_true, y_pred, labels=None, pos_label=1,
                         average='binary', sample_weight=None):
    """Compute the geometric mean.

    The geometric mean is the square root of the product of the sensitivity
    and specificity. This measure tries to maximize the accuracy on each
    of the two classes while keeping these accuracies balanced.

    The sensitivity is the ratio ``tp / (tp + fn)`` where ``tp`` is the
    number of true positives and ``fn`` the number of false negatives. The
    sensitivity is intuitively the ability of the classifier to find all the
    positive samples. The specificity is the ratio ``tn / (tn + fp)`` where
    ``tn`` is the number of true negatives and ``fp`` the number of false
    positives. The specificity is intuitively the ability of the classifier
    to find all the negative samples.

    The best value is 1 and the worst value is 0.

    Parameters
    ----------
    y_true : ndarray, shape (n_samples, )
        Ground truth (correct) target values.

    y_pred : ndarray, shape (n_samples, )
        Estimated targets as returned by a classifier.

    labels : list, optional
        The set of labels to include when ``average != 'binary'``, and their
        order if ``average is None``. Labels present in the data can be
        excluded, for example to calculate a multiclass average ignoring a
        majority negative class, while labels not present in the data will
        result in 0 components in a macro average.

    pos_label : str or int, optional (default=1)
        The class to report if ``average='binary'`` and the data is binary.
        If the data are multiclass, this will be ignored;
        setting ``labels=[pos_label]`` and ``average != 'binary'`` will report
        scores for that label only.

    average : str or None, optional (default='binary')
        If ``None``, the scores for each class are returned. Otherwise, this
        determines the type of averaging performed on the data:

        ``'binary'``:
            Only report results for the class specified by ``pos_label``.
            This is applicable only if targets (``y_{true,pred}``) are binary.
        ``'macro'``:
            Calculate metrics for each label, and find their unweighted
            mean. This does not take label imbalance into account.
        ``'weighted'``:
            Calculate metrics for each label, and find their average, weighted
            by support (the number of true instances for each label). This
            alters 'macro' to account for label imbalance.

    sample_weight : ndarray, shape (n_samples, )
        Sample weights.

    Returns
    -------
    geometric_mean : float (if ``average`` = None) or ndarray, \
        shape (n_unique_labels, )

    References
    ----------
    .. [1] Kubat, M. and Matwin, S. "Addressing the curse of
       imbalanced training sets: one-sided selection" ICML (1997)

    .. [2] Barandela, R., Sánchez, J. S., Garcıa, V., & Rangel, E. "Strategies
       for learning in class imbalance problems", Pattern Recognition,
       36(3), (2003), pp 849-851.

    """
    # Delegate counting to the shared helper; warn for BOTH metrics —
    # the original passed ('specificity', 'specificity'), a duplicate that
    # silently suppressed ill-defined-sensitivity warnings.
    sen, spe, _ = sensitivity_specificity_support(
        y_true, y_pred,
        labels=labels,
        pos_label=pos_label,
        average=average,
        warn_for=('sensitivity', 'specificity'),
        sample_weight=sample_weight)

    # G-mean = sqrt(sensitivity * specificity); element-wise when
    # average is None (per-class ndarray).
    return np.sqrt(sen * spe)
0 commit comments