Skip to content

Commit ef64f81

Browse files
author
Guillaume Lemaitre
committed
add an example for multiclass
1 parent 3e2998b commit ef64f81

File tree

2 files changed

+41
-2
lines changed

2 files changed

+41
-2
lines changed
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
"""
2+
=============================================
3+
Multiclass classification with under-sampling
4+
=============================================
5+
6+
Some balancing methods can balance datasets with multiple classes.
7+
We provide an example to illustrate the use of those methods, which do
8+
not differ from the binary case.
9+
10+
"""
11+
12+
from sklearn.datasets import load_iris
13+
from sklearn.svm import LinearSVC
14+
from sklearn.model_selection import train_test_split
15+
16+
from imblearn.under_sampling import NearMiss
17+
from imblearn.pipeline import make_pipeline
18+
from imblearn.metrics import classification_report_imbalanced
19+
20+
print(__doc__)
21+
22+
RANDOM_STATE = 42
23+
24+
# Load the iris dataset
25+
iris = load_iris()
26+
# Make the dataset imbalanced
27+
# Drop the first 25 samples (half of the first class) and the final sample
28+
iris.data = iris.data[25:-1, :]
29+
iris.target = iris.target[25:-1]
30+
31+
X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target,
32+
random_state=RANDOM_STATE)
33+
34+
# Create a pipeline
35+
pipeline = make_pipeline(NearMiss(version=2), LinearSVC())
36+
pipeline.fit(X_train, y_train)
37+
38+
# Classify and report the results
39+
print(classification_report_imbalanced(y_test, pipeline.predict(X_test)))

examples/evaluation/plot_classification_report.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from sklearn.model_selection import train_test_split
1515

1616
from imblearn import over_sampling as os
17-
from imblearn import under_sampling as us
1817
from imblearn import pipeline as pl
1918
from imblearn.metrics import classification_report_imbalanced
2019

@@ -32,7 +31,8 @@
3231
pipeline = pl.make_pipeline(os.SMOTE(), LinearSVC())
3332

3433
# Split the data
35-
X_train, X_test, y_train, y_test = train_test_split(X, y)
34+
X_train, X_test, y_train, y_test = train_test_split(X, y,
35+
random_state=RANDOM_STATE)
3636

3737
# Train the classifier with balancing
3838
pipeline.fit(X_train, y_train)

0 commit comments

Comments
 (0)