Skip to content

Commit e51afd3

Browse files
author
Guillaume Lemaitre
committed
Add the testing for geometric mean
1 parent ac3d0de commit e51afd3

File tree

3 files changed

+33
-2
lines changed

3 files changed

+33
-2
lines changed

imblearn/metrics/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,11 @@
66
from .classification import sensitivity_specificity_support
77
from .classification import sensitivity_score
88
from .classification import specificity_score
9+
from .classification import geometric_mean_score
910

1011
__all__ = [
1112
'sensitivity_specificity_support',
1213
'sensitivity_score',
13-
'specificity_score'
14+
'specificity_score',
15+
'geometric_mean_score'
1416
]

imblearn/metrics/classification.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -448,4 +448,6 @@ def geometric_mean_score(y_true, y_pred, labels=None, pos_label=1,
448448
'specificity'),
449449
sample_weight=sample_weight)
450450

451+
LOGGER.debug('The sensitivity and specificity are : %s - %s' % (sen, spe))
452+
451453
return np.sqrt(sen * spe)

imblearn/metrics/tests/test_classification.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from imblearn.metrics import sensitivity_specificity_support
2222
from imblearn.metrics import sensitivity_score
2323
from imblearn.metrics import specificity_score
24+
from imblearn.metrics import geometric_mean_score
2425

2526
RND_SEED = 42
2627

@@ -185,7 +186,7 @@ def test_sensitivity_specificity_unused_pos_label():
185186

186187

187188
def test_sensitivity_specificity_multiclass():
188-
# Test Precision Recall and F1 Score for multiclass classification task
189+
# Test sensitivity and specificity for multiclass classification task
189190
y_true, y_pred, _ = make_prediction(binary=False)
190191

191192
# compute scores with default labels introspection
@@ -216,3 +217,29 @@ def test_sensitivity_specificity_multiclass():
216217
assert_array_almost_equal(spec, [0.92, 0.55, 0.86], 2)
217218
assert_array_almost_equal(sens, [0.79, 0.90, 0.10], 2)
218219
assert_array_equal(supp, [24, 20, 31])
220+
221+
222+
def test_geometric_mean_support_binary():
223+
"""Test the geometric mean for binary classification task"""
224+
y_true, y_pred, _ = make_prediction(binary=True)
225+
226+
# compute the geometric mean for the binary problem
227+
geo_mean = geometric_mean_score(y_true, y_pred)
228+
229+
assert_almost_equal(geo_mean, 0.77, 2)
230+
231+
232+
def test_geometric_mean_multiclass():
233+
# Test geometric mean for multiclass classification task
234+
y_true, y_pred, _ = make_prediction(binary=False)
235+
236+
# Compute the geometric mean for each of the classes
237+
geo_mean = geometric_mean_score(y_true, y_pred, average=None)
238+
assert_array_almost_equal(geo_mean, [0.85, 0.29, 0.7], 2)
239+
240+
# average tests
241+
geo_mean = geometric_mean_score(y_true, y_pred, average='macro')
242+
assert_almost_equal(geo_mean, 0.68, 2)
243+
244+
geo_mean = geometric_mean_score(y_true, y_pred, average='weighted')
245+
assert_array_almost_equal(geo_mean, 0.65, 2)

0 commit comments

Comments
 (0)