|
21 | 21 | from imblearn.metrics import sensitivity_specificity_support
|
22 | 22 | from imblearn.metrics import sensitivity_score
|
23 | 23 | from imblearn.metrics import specificity_score
|
| 24 | +from imblearn.metrics import geometric_mean_score |
24 | 25 |
|
25 | 26 | RND_SEED = 42
|
26 | 27 |
|
@@ -185,7 +186,7 @@ def test_sensitivity_specificity_unused_pos_label():
|
185 | 186 |
|
186 | 187 |
|
187 | 188 | def test_sensitivity_specificity_multiclass():
|
188 |
| - # Test Precision Recall and F1 Score for multiclass classification task |
| 189 | + # Test sensitivity and specificity for multiclass classification task |
189 | 190 | y_true, y_pred, _ = make_prediction(binary=False)
|
190 | 191 |
|
191 | 192 | # compute scores with default labels introspection
|
@@ -216,3 +217,29 @@ def test_sensitivity_specificity_multiclass():
|
216 | 217 | assert_array_almost_equal(spec, [0.92, 0.55, 0.86], 2)
|
217 | 218 | assert_array_almost_equal(sens, [0.79, 0.90, 0.10], 2)
|
218 | 219 | assert_array_equal(supp, [24, 20, 31])
|
| 220 | + |
| 221 | + |
| 222 | +def test_geometric_mean_support_binary(): |
| 223 | + """Test the geometric mean for binary classification task""" |
| 224 | + y_true, y_pred, _ = make_prediction(binary=True) |
| 225 | + |
| 226 | + # compute the geometric mean for the binary problem |
| 227 | + geo_mean = geometric_mean_score(y_true, y_pred) |
| 228 | + |
| 229 | + assert_almost_equal(geo_mean, 0.77, 2) |
| 230 | + |
| 231 | + |
| 232 | +def test_geometric_mean_multiclass(): |
| 233 | + # Test geometric mean for multiclass classification task |
| 234 | + y_true, y_pred, _ = make_prediction(binary=False) |
| 235 | + |
| 236 | + # Compute the geometric mean for each of the classes |
| 237 | + geo_mean = geometric_mean_score(y_true, y_pred, average=None) |
| 238 | + assert_array_almost_equal(geo_mean, [0.85, 0.29, 0.7], 2) |
| 239 | + |
| 240 | + # average tests |
| 241 | + geo_mean = geometric_mean_score(y_true, y_pred, average='macro') |
| 242 | + assert_almost_equal(geo_mean, 0.68, 2) |
| 243 | + |
| 244 | + geo_mean = geometric_mean_score(y_true, y_pred, average='weighted') |
| 245 | + assert_array_almost_equal(geo_mean, 0.65, 2) |
0 commit comments