Commit ca8e7f4
[MRG] Example for JMLR (#211)
* Add a classification report example
* Add an example for multiclass classification
* Finish the example
* Use signature instead of popping kwargs
* Solve the issue with the doc
* Correct misspellings
* Add a README for the dataset examples
1 parent ac16d91 commit ca8e7f4

7 files changed: +192 −23 lines

doc/api.rst

Lines changed: 2 additions & 1 deletion

@@ -133,7 +133,8 @@ Metrics
 Functions
 ---------
 .. autosummary::
-   :toctree: generated/
+   :toctree: generated/
+
    metrics.sensitivity_specificity_support
    metrics.sensitivity_score
    metrics.specificity_score

Lines changed: 40 additions & 0 deletions (new file)

"""
=============================================
Multiclass classification with under-sampling
=============================================

Some balancing methods allow for balancing datasets with multiple classes.
This example illustrates the use of such a method; it does not differ
from the binary case.

"""

from sklearn.datasets import load_iris
from sklearn.svm import LinearSVC
from sklearn.model_selection import train_test_split

from imblearn.under_sampling import NearMiss
from imblearn.pipeline import make_pipeline
from imblearn.metrics import classification_report_imbalanced

print(__doc__)

RANDOM_STATE = 42

# Load the Iris dataset
iris = load_iris()
# Make the dataset imbalanced:
# select only half of the first class
iris.data = iris.data[25:-1, :]
iris.target = iris.target[25:-1]

X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target,
                                                    random_state=RANDOM_STATE)

# Create a pipeline that under-samples before training the classifier
pipeline = make_pipeline(NearMiss(version=2, random_state=RANDOM_STATE),
                         LinearSVC(random_state=RANDOM_STATE))
pipeline.fit(X_train, y_train)

# Classify and report the results
print(classification_report_imbalanced(y_test, pipeline.predict(X_test)))
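
To see what the under-sampler does on its own, here is a minimal sketch that
applies NearMiss-2 outside the pipeline and prints the class counts. It
assumes the sampler's `fit_sample` method of this era of imblearn (renamed
`fit_resample` in later releases):

from collections import Counter

from sklearn.datasets import load_iris

from imblearn.under_sampling import NearMiss

iris = load_iris()
X, y = iris.data[25:-1, :], iris.target[25:-1]
print('Original class counts: {}'.format(Counter(y)))

# NearMiss-2 keeps the majority samples whose average distance to the
# farthest samples of the other classes is smallest.
nm2 = NearMiss(version=2, random_state=42)
X_res, y_res = nm2.fit_sample(X, y)  # fit_resample in imblearn >= 0.4
print('Resampled class counts: {}'.format(Counter(y_res)))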

examples/datasets/README.txt

Lines changed: 6 additions & 0 deletions (new file)

.. _dataset_examples:

Dataset examples
----------------

Examples concerning the :mod:`imblearn.datasets` module.
Lines changed: 45 additions & 0 deletions (new file)

"""
=============================================
Evaluate classification by compiling a report
=============================================

Specific metrics have been developed to evaluate classifiers trained on
imbalanced data. `imblearn` provides a classification report similar to
`sklearn`, with additional metrics specific to imbalanced learning
problems.
"""

from sklearn import datasets
from sklearn.svm import LinearSVC
from sklearn.model_selection import train_test_split

from imblearn import over_sampling as os
from imblearn import pipeline as pl
from imblearn.metrics import classification_report_imbalanced

print(__doc__)

RANDOM_STATE = 42

# Generate a dataset
X, y = datasets.make_classification(n_classes=2, class_sep=2,
                                    weights=[0.1, 0.9], n_informative=10,
                                    n_redundant=1, flip_y=0, n_features=20,
                                    n_clusters_per_class=4, n_samples=5000,
                                    random_state=RANDOM_STATE)

pipeline = pl.make_pipeline(os.SMOTE(random_state=RANDOM_STATE),
                            LinearSVC(random_state=RANDOM_STATE))

# Split the data
X_train, X_test, y_train, y_test = train_test_split(X, y,
                                                    random_state=RANDOM_STATE)

# Train the classifier with balancing
pipeline.fit(X_train, y_train)

# Test the classifier and get the prediction
y_pred_bal = pipeline.predict(X_test)

# Show the classification report
print(classification_report_imbalanced(y_test, y_pred_bal))
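
The columns of the report can also be reproduced one metric at a time. A
minimal sketch, reusing `y_test` and `y_pred_bal` from the example above and
assuming the metric functions listed in doc/api.rst follow scikit-learn's
usual `average` conventions:

from imblearn.metrics import (geometric_mean_score, sensitivity_score,
                              specificity_score)

# Per-class sensitivity (recall) and specificity, as in the report
sen = sensitivity_score(y_test, y_pred_bal, average=None)
spe = specificity_score(y_test, y_pred_bal, average=None)
# Geometric mean for the binary positive class
gmean = geometric_mean_score(y_test, y_pred_bal, average='binary')
print('sensitivity: {}\nspecificity: {}\nG-mean: {}'.format(sen, spe, gmean))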

examples/evaluation/plot_metrics.py

Lines changed: 74 additions & 0 deletions (new file)

"""
=======================================
Metrics specific to imbalanced learning
=======================================

Specific metrics have been developed to evaluate classifiers trained on
imbalanced data. `imblearn` mainly provides two additional metrics that are
not implemented in `sklearn`: (i) the geometric mean and (ii) the index
balanced accuracy.
"""

from sklearn import datasets
from sklearn.svm import LinearSVC
from sklearn.model_selection import train_test_split

from imblearn import over_sampling as os
from imblearn import pipeline as pl
from imblearn.metrics import (geometric_mean_score,
                              make_index_balanced_accuracy)

print(__doc__)

RANDOM_STATE = 42

# Generate a dataset
# NOTE: with only two weights given for three classes, scikit-learn assigns
# the remaining class a weight of 1 - 0.1 - 0.9 = 0, so the third class is
# effectively empty.
X, y = datasets.make_classification(n_classes=3, class_sep=2,
                                    weights=[0.1, 0.9], n_informative=10,
                                    n_redundant=1, flip_y=0, n_features=20,
                                    n_clusters_per_class=4, n_samples=5000,
                                    random_state=RANDOM_STATE)

pipeline = pl.make_pipeline(os.SMOTE(random_state=RANDOM_STATE),
                            LinearSVC(random_state=RANDOM_STATE))

# Split the data
X_train, X_test, y_train, y_test = train_test_split(X, y,
                                                     random_state=RANDOM_STATE)

# Train the classifier with balancing
pipeline.fit(X_train, y_train)

# Test the classifier and get the prediction
y_pred_bal = pipeline.predict(X_test)

###############################################################################
# The geometric mean corresponds to the square root of the product of the
# sensitivity and specificity. Combining the two metrics accounts for
# performance on both classes of an imbalanced dataset.

print('The geometric mean is {}'.format(geometric_mean_score(
    y_test,
    y_pred_bal)))

###############################################################################
# The index balanced accuracy can weight any metric so that it can be used
# on imbalanced learning problems.

alpha = 0.1
geo_mean = make_index_balanced_accuracy(alpha=alpha, squared=True)(
    geometric_mean_score)

print('The IBA using alpha = {} and the geometric mean: {}'.format(
    alpha, geo_mean(
        y_test,
        y_pred_bal)))

alpha = 0.5
geo_mean = make_index_balanced_accuracy(alpha=alpha, squared=True)(
    geometric_mean_score)

print('The IBA using alpha = {} and the geometric mean: {}'.format(
    alpha, geo_mean(
        y_test,
        y_pred_bal)))
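
Given the weighting applied inside make_index_balanced_accuracy (see the
`dom = sen - spe` computation in imblearn/metrics/classification.py below),
the squared IBA of the geometric mean can be checked by hand:
IBA = (1 + alpha * (sensitivity - specificity)) * G-mean ** 2. A minimal
sketch, reusing `y_test` and `y_pred_bal` from the example above and
assuming binary averaging throughout:

import numpy as np

from imblearn.metrics import (geometric_mean_score,
                              make_index_balanced_accuracy,
                              sensitivity_score, specificity_score)

alpha = 0.1
sen = sensitivity_score(y_test, y_pred_bal, average='binary')
spe = specificity_score(y_test, y_pred_bal, average='binary')
gmean = geometric_mean_score(y_test, y_pred_bal, average='binary')

# The dominance (sen - spe) weights the squared score towards sensitivity
iba_manual = (1. + alpha * (sen - spe)) * gmean ** 2

iba_func = make_index_balanced_accuracy(alpha=alpha, squared=True)(
    geometric_mean_score)
assert np.isclose(iba_manual, iba_func(y_test, y_pred_bal, average='binary'))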

examples/model_selection/plot_validation_curve.py

Lines changed: 1 addition & 1 deletion

@@ -29,7 +29,7 @@
                                     weights=[0.1, 0.9], n_informative=10,
                                     n_redundant=1, flip_y=0, n_features=20,
                                     n_clusters_per_class=4, n_samples=5000,
-                                    random_state=10)
+                                    random_state=RANDOM_STATE)
 smote = os.SMOTE(random_state=RANDOM_STATE)
 cart = tree.DecisionTreeClassifier(random_state=RANDOM_STATE)
 pipeline = pl.make_pipeline(smote, cart)

imblearn/metrics/classification.py

Lines changed: 24 additions & 21 deletions

@@ -14,6 +14,8 @@
 import logging
 import functools
 
+from inspect import getcallargs
+
 import numpy as np
 
 from sklearn.metrics.classification import (_check_targets, _prf_divide,
@@ -22,6 +24,12 @@
 from sklearn.utils.fixes import bincount
 from sklearn.utils.multiclass import unique_labels
 
+try:
+    from inspect import signature
+except ImportError:
+    from sklearn.externals.funcsigs import signature
+
+
 LOGGER = logging.getLogger(__name__)
 
 
@@ -563,10 +571,10 @@ def geometric_mean_score(y_true,
 
 
 def make_index_balanced_accuracy(alpha=0.1, squared=True):
-    """Balance any scoring function using the indexed balanced accuracy
+    """Balance any scoring function using the index balanced accuracy
 
     This factory function wraps scoring function to express it as the
-    indexed balanced accuracy (IBA). You need to use this function to
+    index balanced accuracy (IBA). You need to use this function to
     decorate any scoring function.
 
     Parameters
@@ -582,7 +590,7 @@ def make_index_balanced_accuracy(alpha=0.1, squared=True):
     -------
     iba_scoring_func : callable,
         Returns the scoring metric decorated which will automatically compute
-        the indexed balanced accuracy.
+        the index balanced accuracy.
 
     Examples
     --------
@@ -603,21 +611,16 @@ def compute_score(*args, **kwargs):
             # Square if desired
             if squared:
                 _score = np.power(_score, 2)
-            # args will contain the y_pred and y_true
-            # kwargs will contain the other parameters
-            labels = kwargs.get('labels', None)
-            pos_label = kwargs.get('pos_label', 1)
-            average = kwargs.get('average', 'binary')
-            sample_weight = kwargs.get('sample_weight', None)
-            # Compute the sensitivity and specificity
-            dict_sen_spe = {
-                'labels': labels,
-                'pos_label': pos_label,
-                'average': average,
-                'sample_weight': sample_weight
-            }
-            sen, spe, _ = sensitivity_specificity_support(*args,
-                                                          **dict_sen_spe)
+            # Collect the arguments of the scoring function call by name
+            tags_scoring_func = getcallargs(scoring_func, *args, **kwargs)
+            # Get the signature of the sens/spec function
+            sens_spec_sig = signature(sensitivity_specificity_support)
+            # Bind the inputs required by the sens/spec function
+            tags_sens_spec = sens_spec_sig.bind(**tags_scoring_func)
+            # Call the sens/spec function
+            sen, spe, _ = sensitivity_specificity_support(
+                *tags_sens_spec.args,
+                **tags_sens_spec.kwargs)
             # Compute the dominance
             dom = sen - spe
             return (1. + alpha * dom) * _score
@@ -640,7 +643,7 @@ def classification_report_imbalanced(y_true,
     Specific metrics have been proposed to evaluate the classification
     performed on imbalanced dataset. This report compiles the
     state-of-the-art metrics: precision/recall/specificity, geometric
-    mean, and indexed balanced accuracy of the
+    mean, and index balanced accuracy of the
     geometric mean.
 
     Parameters
@@ -674,7 +677,7 @@ def classification_report_imbalanced(y_true,
     -------
     report : string
         Text summary of the precision, recall, specificity, geometric mean,
-        and indexed balanced accuracy.
+        and index balanced accuracy.
 
     Examples
     --------
@@ -746,7 +749,7 @@ class 2 1.00 0.67 1.00 0.80 0.82 0.69\
         labels=labels,
         average=None,
         sample_weight=sample_weight)
-    # Indexed balanced accuracy
+    # Index balanced accuracy
     iba_gmean = make_index_balanced_accuracy(
         alpha=alpha, squared=True)(geometric_mean_score)
     iba = iba_gmean(
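
The refactored compute_score above replaces manual kwargs.get calls with
introspection: getcallargs recovers every argument of the scoring-function
call by name (defaults included), and Signature.bind re-dispatches them to
sensitivity_specificity_support. A minimal standalone sketch of the same
technique, using hypothetical stand-in functions:

from inspect import getcallargs, signature


def scoring_func(y_true, y_pred, average='binary'):
    """Hypothetical scoring function (stand-in for, e.g., a G-mean)."""
    return 0.0


def helper(y_true, y_pred, average='binary'):
    """Hypothetical stand-in for sensitivity_specificity_support; its
    parameters must cover scoring_func's for bind() to succeed."""
    return 'helper called with average={!r}'.format(average)


# Map every argument of the call to its parameter name, defaults included
call_args = getcallargs(scoring_func, [0, 1], [1, 1], average='macro')
# Validate the names against helper's signature and split them back into
# positional and keyword arguments
bound = signature(helper).bind(**call_args)
print(helper(*bound.args, **bound.kwargs))  # -> average='macro'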
