Commit d28f450

Author: Guillaume Lemaitre
Commit message: advance the testing
Parent: 3a91d4d

3 files changed: 158 additions and 138 deletions

imblearn/metrics/__init__.py

Lines changed: 4 additions & 1 deletion
@@ -3,5 +3,8 @@
 metrics and pairwise metrics and distance computations.
 """

-import numpy as np
+from .classification import sensitivity_specificity_support

+__all__ = [
+    'sensitivity_specificity_support'
+]
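
A minimal usage sketch of the metric that this commit re-exports from imblearn.metrics. The toy labels below are illustrative only (not part of the commit); the call with average=None mirrors the one exercised in the new test file further down and returns one value per class.

# Illustrative usage sketch, assuming this commit is installed.
from imblearn.metrics import sensitivity_specificity_support

y_true = [0, 0, 0, 0, 1, 1]
y_pred = [0, 0, 0, 1, 1, 1]

# With average=None the function returns per-class sensitivity,
# specificity and support, as in the new test below.
sens, spec, supp = sensitivity_specificity_support(y_true, y_pred, average=None)
print(sens, spec, supp)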

imblearn/metrics/classification.py

Lines changed: 20 additions & 5 deletions
@@ -10,6 +10,7 @@
 from __future__ import division

 import warnings
+import logging

 import numpy as np

@@ -148,21 +149,26 @@ def sensitivity_specificity_support(y_true, y_pred, labels=None,
     le.fit(labels)
     y_true = le.transform(y_true)
     y_pred = le.transform(y_pred)
-    sorted_labels = le.classes_n
+    sorted_labels = le.classes_

     # In a leave out strategy and for each label, compute:
     # TP, TN, FP, FN
     # These list contain an array in which each sample is labeled as
     # TP, TN, FP, FN
-    list_tp = [(y_true == label) == (y_pred == label)
+    list_tp = [np.bitwise_and((y_true == label), (y_pred == label))
                for label in sorted_labels]
-    list_tn = [(y_true != label) == (y_pred != label)
+    list_tn = [np.bitwise_and((y_true != label), (y_pred != label))
                for label in sorted_labels]
-    list_fp = [(y_true == label) == (y_pred != label)
+    list_fp = [np.bitwise_and((y_true == label), (y_pred != label))
                for label in sorted_labels]
-    list_fn = [(y_true != label) == (y_pred == label)
+    list_fn = [np.bitwise_and((y_true != label), (y_pred == label))
                for label in sorted_labels]

+    LOGGER.debug(list_tp)
+    LOGGER.debug(list_tn)
+    LOGGER.debug(list_fn)
+    LOGGER.debug(list_fn)
+
     # Compute the sum for each type
     tp_sum = [bincount(tp, weights=sample_weight, minlength=len(labels))
               for tp in list_tp]
@@ -173,6 +179,11 @@ def sensitivity_specificity_support(y_true, y_pred, labels=None,
     fn_sum = [bincount(fn, weights=sample_weight, minlength=len(labels))
               for fn in list_fn]

+    LOGGER.debug(tp_sum)
+    LOGGER.debug(tn_sum)
+    LOGGER.debug(fp_sum)
+    LOGGER.debug(fn_sum)
+
     # Retain only selected labels
     indices = np.searchsorted(sorted_labels, labels[:n_labels])
     tp_sum = [tp[indices] for tp in tp_sum]
@@ -188,6 +199,10 @@ def sensitivity_specificity_support(y_true, y_pred, labels=None,
     specificity = [_prf_divide(tn, tn + fp, 'specificity', 'tn + fp', average,
                                warn_for) for tn, fp in zip(tn_sum, fp_sum)]

+    LOGGER.debug('Computed the sensitivity and specificity for each class')
+    LOGGER.debug('The lengths of those two metrics are: %s - %s',
+                 len(sensitivity), len(specificity))
+
     # If we need to weight the results
     if average == 'weighted':
         weights = tp_sum
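
The hunks above replace an elementwise `==` between two boolean masks with an explicit AND and add debug logging of the intermediate lists and sums. A standalone sketch of why that matters, with illustrative names and data that are not part of the commit: comparing two boolean masks with `==` is also True wherever both masks are False, so it does not isolate the samples that satisfy both conditions, whereas np.bitwise_and does. The bincount call that follows in the module then turns each mask into per-value counts.

# Standalone sketch (illustrative data, not the module's internals).
import numpy as np

y_true = np.array([0, 0, 1, 1, 2])
y_pred = np.array([0, 1, 1, 1, 0])
label = 1

# New form: True only where the sample is truly `label` and predicted `label`.
tp_mask = np.bitwise_and(y_true == label, y_pred == label)   # [F, F, T, T, F]

# Old form: also True where both masks are False (samples unrelated to `label`).
agree_mask = (y_true == label) == (y_pred == label)           # [T, F, T, T, T]

# Counting a mask with bincount gives [n_False, n_True]; the int cast here is
# only for explicitness in this sketch.
print(np.bincount(tp_mask.astype(int), minlength=2))     # [3 2] -> 2 true positives
print(np.bincount(agree_mask.astype(int), minlength=2))  # [1 4] -> overcounts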

imblearn/metrics/tests/test_classification.py

Lines changed: 134 additions & 132 deletions
@@ -14,6 +14,8 @@

 from sklearn.utils.validation import check_random_state

+from imblearn.metrics import sensitivity_specificity_support
+
 RND_SEED = 42

 ###############################################################################
@@ -67,138 +69,138 @@ def make_prediction(dataset=None, binary=False):
 ###############################################################################
 # Tests

-def test_precision_recall_f1_score_binary():
-    # Test Precision Recall and F1 Score for binary classification task
+def test_sensitivity_specificity_support_binary():
+    """Test the sensitivity specificity for binary classification task"""
     y_true, y_pred, _ = make_prediction(binary=True)

     # detailed measures for each class
-    p, r, f, s = precision_recall_fscore_support(y_true, y_pred, average=None)
-    assert_array_almost_equal(p, [0.73, 0.85], 2)
-    assert_array_almost_equal(r, [0.88, 0.68], 2)
-    assert_array_almost_equal(f, [0.80, 0.76], 2)
-    assert_array_equal(s, [25, 25])
-
-    # individual scoring function that can be used for grid search: in the
-    # binary class case the score is the value of the measure for the positive
-    # class (e.g. label == 1). This is deprecated for average != 'binary'.
-    for kwargs, my_assert in [({}, assert_no_warnings),
-                              ({'average': 'binary'}, assert_no_warnings)]:
-        ps = my_assert(precision_score, y_true, y_pred, **kwargs)
-        assert_array_almost_equal(ps, 0.85, 2)
-
-        rs = my_assert(recall_score, y_true, y_pred, **kwargs)
-        assert_array_almost_equal(rs, 0.68, 2)
-
-        fs = my_assert(f1_score, y_true, y_pred, **kwargs)
-        assert_array_almost_equal(fs, 0.76, 2)
-
-        assert_almost_equal(my_assert(fbeta_score, y_true, y_pred, beta=2,
-                                      **kwargs),
-                            (1 + 2 ** 2) * ps * rs / (2 ** 2 * ps + rs), 2)
-
-
-def test_precision_recall_f_binary_single_class():
-    # Test precision, recall and F1 score behave with a single positive or
-    # negative class
-    # Such a case may occur with non-stratified cross-validation
-    assert_equal(1., precision_score([1, 1], [1, 1]))
-    assert_equal(1., recall_score([1, 1], [1, 1]))
-    assert_equal(1., f1_score([1, 1], [1, 1]))
-
-    assert_equal(0., precision_score([-1, -1], [-1, -1]))
-    assert_equal(0., recall_score([-1, -1], [-1, -1]))
-    assert_equal(0., f1_score([-1, -1], [-1, -1]))
-
-
-@ignore_warnings
-def test_precision_recall_f_extra_labels():
-    # Test handling of explicit additional (not in input) labels to PRF
-    y_true = [1, 3, 3, 2]
-    y_pred = [1, 1, 3, 2]
-    y_true_bin = label_binarize(y_true, classes=np.arange(5))
-    y_pred_bin = label_binarize(y_pred, classes=np.arange(5))
-    data = [(y_true, y_pred),
-            (y_true_bin, y_pred_bin)]
-
-    for i, (y_true, y_pred) in enumerate(data):
-        # No average: zeros in array
-        actual = recall_score(y_true, y_pred, labels=[0, 1, 2, 3, 4],
-                              average=None)
-        assert_array_almost_equal([0., 1., 1., .5, 0.], actual)
-
-        # Macro average is changed
-        actual = recall_score(y_true, y_pred, labels=[0, 1, 2, 3, 4],
-                              average='macro')
-        assert_array_almost_equal(np.mean([0., 1., 1., .5, 0.]), actual)
-
-        # No effect otheriwse
-        for average in ['micro', 'weighted', 'samples']:
-            if average == 'samples' and i == 0:
-                continue
-            assert_almost_equal(recall_score(y_true, y_pred,
-                                             labels=[0, 1, 2, 3, 4],
-                                             average=average),
-                                recall_score(y_true, y_pred, labels=None,
-                                             average=average))
-
-    # Error when introducing invalid label in multilabel case
-    # (although it would only affect performance if average='macro'/None)
-    for average in [None, 'macro', 'micro', 'samples']:
-        assert_raises(ValueError, recall_score, y_true_bin, y_pred_bin,
-                      labels=np.arange(6), average=average)
-        assert_raises(ValueError, recall_score, y_true_bin, y_pred_bin,
-                      labels=np.arange(-1, 4), average=average)
-
-
-@ignore_warnings
-def test_precision_recall_f_ignored_labels():
-    # Test a subset of labels may be requested for PRF
-    y_true = [1, 1, 2, 3]
-    y_pred = [1, 3, 3, 3]
-    y_true_bin = label_binarize(y_true, classes=np.arange(5))
-    y_pred_bin = label_binarize(y_pred, classes=np.arange(5))
-    data = [(y_true, y_pred),
-            (y_true_bin, y_pred_bin)]
-
-    for i, (y_true, y_pred) in enumerate(data):
-        recall_13 = partial(recall_score, y_true, y_pred, labels=[1, 3])
-        recall_all = partial(recall_score, y_true, y_pred, labels=None)
-
-        assert_array_almost_equal([.5, 1.], recall_13(average=None))
-        assert_almost_equal((.5 + 1.) / 2, recall_13(average='macro'))
-        assert_almost_equal((.5 * 2 + 1. * 1) / 3,
-                            recall_13(average='weighted'))
-        assert_almost_equal(2. / 3, recall_13(average='micro'))
-
-        # ensure the above were meaningful tests:
-        for average in ['macro', 'weighted', 'micro']:
-            assert_not_equal(recall_13(average=average),
-                             recall_all(average=average))
-
-
-@ignore_warnings
-def test_precision_recall_fscore_support_errors():
-    y_true, y_pred, _ = make_prediction(binary=True)
-
-    # Bad beta
-    assert_raises(ValueError, precision_recall_fscore_support,
-                  y_true, y_pred, beta=0.0)
-
-    # Bad pos_label
-    assert_raises(ValueError, precision_recall_fscore_support,
-                  y_true, y_pred, pos_label=2, average='binary')
-
-    # Bad average option
-    assert_raises(ValueError, precision_recall_fscore_support,
-                  [0, 1, 2], [1, 2, 0], average='mega')
-
-
-def test_precision_recall_f_unused_pos_label():
-    # Check warning that pos_label unused when set to non-default value
-    # but average != 'binary'; even if data is binary.
-    assert_warns_message(UserWarning,
-                         "Note that pos_label (set to 2) is "
-                         "ignored when average != 'binary' (got 'macro'). You "
-                         "may use labels=[pos_label] to specify a single "
-                         "positive class.", precision_recall_fscore_support,
-                         [1, 2, 1], [1, 2, 2], pos_label=2, average='macro')
+    sens, spec, supp = sensitivity_specificity_support(y_true, y_pred,
+                                                        average=None)
+    assert_array_almost_equal(sens, [0.88, 0.68], 2)
+    assert_array_almost_equal(spec, [0.73, 0.85], 2)
+    assert_array_equal(supp, [25, 25])
+
+# # individual scoring function that can be used for grid search: in the
+# # binary class case the score is the value of the measure for the positive
+# # class (e.g. label == 1). This is deprecated for average != 'binary'.
+# for kwargs, my_assert in [({}, assert_no_warnings),
+# ({'average': 'binary'}, assert_no_warnings)]:
+# ps = my_assert(precision_score, y_true, y_pred, **kwargs)
+# assert_array_almost_equal(ps, 0.85, 2)
+
+# rs = my_assert(recall_score, y_true, y_pred, **kwargs)
+# assert_array_almost_equal(rs, 0.68, 2)
+
+# fs = my_assert(f1_score, y_true, y_pred, **kwargs)
+# assert_array_almost_equal(fs, 0.76, 2)
+
+# assert_almost_equal(my_assert(fbeta_score, y_true, y_pred, beta=2,
+# **kwargs),
+# (1 + 2 ** 2) * ps * rs / (2 ** 2 * ps + rs), 2)
+
+
+# def test_precision_recall_f_binary_single_class():
+# # Test precision, recall and F1 score behave with a single positive or
+# # negative class
+# # Such a case may occur with non-stratified cross-validation
+# assert_equal(1., precision_score([1, 1], [1, 1]))
+# assert_equal(1., recall_score([1, 1], [1, 1]))
+# assert_equal(1., f1_score([1, 1], [1, 1]))
+
+# assert_equal(0., precision_score([-1, -1], [-1, -1]))
+# assert_equal(0., recall_score([-1, -1], [-1, -1]))
+# assert_equal(0., f1_score([-1, -1], [-1, -1]))
+
+
+# @ignore_warnings
+# def test_precision_recall_f_extra_labels():
+# # Test handling of explicit additional (not in input) labels to PRF
+# y_true = [1, 3, 3, 2]
+# y_pred = [1, 1, 3, 2]
+# y_true_bin = label_binarize(y_true, classes=np.arange(5))
+# y_pred_bin = label_binarize(y_pred, classes=np.arange(5))
+# data = [(y_true, y_pred),
+# (y_true_bin, y_pred_bin)]
+
+# for i, (y_true, y_pred) in enumerate(data):
+# # No average: zeros in array
+# actual = recall_score(y_true, y_pred, labels=[0, 1, 2, 3, 4],
+# average=None)
+# assert_array_almost_equal([0., 1., 1., .5, 0.], actual)
+
+# # Macro average is changed
+# actual = recall_score(y_true, y_pred, labels=[0, 1, 2, 3, 4],
+# average='macro')
+# assert_array_almost_equal(np.mean([0., 1., 1., .5, 0.]), actual)
+
+# # No effect otheriwse
+# for average in ['micro', 'weighted', 'samples']:
+# if average == 'samples' and i == 0:
+# continue
+# assert_almost_equal(recall_score(y_true, y_pred,
+# labels=[0, 1, 2, 3, 4],
+# average=average),
+# recall_score(y_true, y_pred, labels=None,
+# average=average))
+
+# # Error when introducing invalid label in multilabel case
+# # (although it would only affect performance if average='macro'/None)
+# for average in [None, 'macro', 'micro', 'samples']:
+# assert_raises(ValueError, recall_score, y_true_bin, y_pred_bin,
+# labels=np.arange(6), average=average)
+# assert_raises(ValueError, recall_score, y_true_bin, y_pred_bin,
+# labels=np.arange(-1, 4), average=average)
+
+
+# @ignore_warnings
+# def test_precision_recall_f_ignored_labels():
+# # Test a subset of labels may be requested for PRF
+# y_true = [1, 1, 2, 3]
+# y_pred = [1, 3, 3, 3]
+# y_true_bin = label_binarize(y_true, classes=np.arange(5))
+# y_pred_bin = label_binarize(y_pred, classes=np.arange(5))
+# data = [(y_true, y_pred),
+# (y_true_bin, y_pred_bin)]
+
+# for i, (y_true, y_pred) in enumerate(data):
+# recall_13 = partial(recall_score, y_true, y_pred, labels=[1, 3])
+# recall_all = partial(recall_score, y_true, y_pred, labels=None)
+
+# assert_array_almost_equal([.5, 1.], recall_13(average=None))
+# assert_almost_equal((.5 + 1.) / 2, recall_13(average='macro'))
+# assert_almost_equal((.5 * 2 + 1. * 1) / 3,
+# recall_13(average='weighted'))
+# assert_almost_equal(2. / 3, recall_13(average='micro'))
+
+# # ensure the above were meaningful tests:
+# for average in ['macro', 'weighted', 'micro']:
+# assert_not_equal(recall_13(average=average),
+# recall_all(average=average))
+
+
+# @ignore_warnings
+# def test_precision_recall_fscore_support_errors():
+# y_true, y_pred, _ = make_prediction(binary=True)
+
+# # Bad beta
+# assert_raises(ValueError, precision_recall_fscore_support,
+# y_true, y_pred, beta=0.0)
+
+# # Bad pos_label
+# assert_raises(ValueError, precision_recall_fscore_support,
+# y_true, y_pred, pos_label=2, average='binary')
+
+# # Bad average option
+# assert_raises(ValueError, precision_recall_fscore_support,
+# [0, 1, 2], [1, 2, 0], average='mega')
+
+
+# def test_precision_recall_f_unused_pos_label():
+# # Check warning that pos_label unused when set to non-default value
+# # but average != 'binary'; even if data is binary.
+# assert_warns_message(UserWarning,
+# "Note that pos_label (set to 2) is "
+# "ignored when average != 'binary' (got 'macro'). You "
+# "may use labels=[pos_label] to specify a single "
+# "positive class.", precision_recall_fscore_support,
+# [1, 2, 1], [1, 2, 2], pos_label=2, average='macro')
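
The new binary test above checks sensitivity [0.88, 0.68] against specificity [0.73, 0.85]. A short illustrative check with toy data (not from the commit): sensitivity is per-class recall, and in a two-class problem the specificity of one class equals the sensitivity of the other, which is why the two expected arrays are mirror images of each other.

# Illustrative sketch using scikit-learn's recall_score on toy data.
import numpy as np
from sklearn.metrics import recall_score

y_true = np.array([0, 0, 0, 0, 1, 1])
y_pred = np.array([0, 0, 0, 1, 1, 1])

sens = recall_score(y_true, y_pred, average=None)  # per-class sensitivity (recall)
spec = sens[::-1]                                  # specificity, by the binary mirror property
print(sens, spec)                                  # [0.75 1.  ] [1.   0.75]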
