|
3 | 3 | from __future__ import division, print_function
|
4 | 4 |
|
5 | 5 | import warnings
|
6 |
| - |
7 | 6 | from collections import Counter
|
8 | 7 |
|
9 | 8 | import numpy as np
|
10 |
| - |
| 9 | +from six import string_types |
| 10 | +import sklearn |
11 | 11 | from sklearn.base import ClassifierMixin
|
12 | 12 | from sklearn.ensemble import RandomForestClassifier
|
13 |
| -from sklearn.cross_validation import StratifiedKFold |
14 |
| - |
15 |
| -from six import string_types |
16 | 13 |
|
17 | 14 | from ..base import BaseBinarySampler
|
18 | 15 |
|
19 | 16 |
|
def _get_cv_splits(X, y, cv, random_state):
    """Build stratified k-fold splits across scikit-learn API versions.

    Parameters
    ----------
    X : array-like
        Samples to split. Only used by the ``sklearn.model_selection`` API.

    y : array-like
        Target labels used for stratification.

    cv : int
        Number of folds.

    random_state : int, RandomState instance or None
        Kept for interface compatibility. The folds are built with
        ``shuffle=False``, so the seed cannot influence them; it is only
        forwarded to the legacy ``sklearn.cross_validation`` API, which
        accepts it.

    Returns
    -------
    cv_iterator : iterable
        Iterable yielding ``(train_indices, test_indices)`` pairs.

    """
    # Try the import itself rather than ``hasattr(sklearn, ...)``: a
    # submodule is only an attribute of the package once something has
    # imported it, so the attribute check can wrongly select the legacy path.
    try:
        from sklearn.model_selection import StratifiedKFold
    except ImportError:
        # Older scikit-learn releases only ship ``cross_validation``, whose
        # StratifiedKFold takes the labels in the constructor and is itself
        # the iterable of splits.
        from sklearn.cross_validation import StratifiedKFold
        return StratifiedKFold(
            y, n_folds=cv, shuffle=False, random_state=random_state)

    # ``random_state`` is deliberately not passed: it has no effect when
    # ``shuffle=False`` and recent scikit-learn releases raise ValueError
    # for that combination.
    return StratifiedKFold(n_splits=cv, shuffle=False).split(X, y)
| 28 | + |
| 29 | + |
20 | 30 | class InstanceHardnessThreshold(BaseBinarySampler):
|
21 | 31 | """Class to perform under-sampling based on the instance hardness
|
22 | 32 | threshold.
|
@@ -225,8 +235,7 @@ def _sample(self, X, y):
|
225 | 235 | """
|
226 | 236 |
|
227 | 237 | # Create the different folds
|
228 |
| - skf = StratifiedKFold( |
229 |
| - y, n_folds=self.cv, shuffle=False, random_state=self.random_state) |
| 238 | + skf = _get_cv_splits(X, y, self.cv, self.random_state) |
230 | 239 |
|
231 | 240 | probabilities = np.zeros(y.shape[0], dtype=float)
|
232 | 241 |
|
|
0 commit comments