|
8 | 8 | from scipy import sparse |
9 | 9 |
|
10 | 10 | from sklearn.feature_selection.rfe import RFE, RFECV |
11 | | -from sklearn.datasets import load_iris, make_friedman1, make_regression |
| 11 | +from sklearn.datasets import load_iris, make_friedman1 |
12 | 12 | from sklearn.metrics import zero_one_loss |
13 | 13 | from sklearn.svm import SVC, SVR |
14 | | -from sklearn.linear_model import LinearRegression |
15 | 14 | from sklearn.ensemble import RandomForestClassifier |
| 15 | +from sklearn.cross_validation import cross_val_score |
16 | 16 |
|
17 | 17 | from sklearn.utils import check_random_state |
18 | 18 | from sklearn.utils.testing import ignore_warnings |
19 | 19 | from sklearn.utils.testing import assert_warns_message |
| 20 | +from sklearn.utils.testing import assert_greater |
20 | 21 |
|
21 | 22 | from sklearn.metrics import make_scorer |
22 | 23 | from sklearn.metrics import get_scorer |
@@ -94,7 +95,6 @@ def test_rfe_features_importance(): |
94 | 95 | assert_array_equal(rfe.get_support(), rfe_svc.get_support()) |
95 | 96 |
|
96 | 97 |
|
97 | | - |
98 | 98 | def test_rfe_deprecation_estimator_params(): |
99 | 99 | deprecation_message = ("The parameter 'estimator_params' is deprecated as " |
100 | 100 | "of version 0.16 and will be removed in 0.18. The " |
@@ -240,6 +240,15 @@ def test_rfecv_mockclassifier(): |
240 | 240 | assert_equal(len(rfecv.ranking_), X.shape[1]) |
241 | 241 |
|
242 | 242 |
|
| 243 | +def test_rfe_estimator_tags(): |
| 244 | + rfe = RFE(SVC(kernel='linear')) |
| 245 | + assert_equal(rfe._estimator_type, "classifier") |
| 246 | + # make sure that cross-validation is stratified |
| 247 | + iris = load_iris() |
| 248 | + score = cross_val_score(rfe, iris.data, iris.target) |
| 249 | + assert_greater(score.min(), .7) |
| 250 | + |
| 251 | + |
243 | 252 | def test_rfe_min_step(): |
244 | 253 | n_features = 10 |
245 | 254 | X, y = make_friedman1(n_samples=50, n_features=n_features, random_state=0) |
@@ -289,7 +298,7 @@ def formula2(n_features, n_features_to_select, step): |
289 | 298 | X = generator.normal(size=(100, n_features)) |
290 | 299 | y = generator.rand(100).round() |
291 | 300 | rfe = RFE(estimator=SVC(kernel="linear"), |
292 | | - n_features_to_select=n_features_to_select, step=step) |
| 301 | + n_features_to_select=n_features_to_select, step=step) |
293 | 302 | rfe.fit(X, y) |
294 | 303 | # this number also equals to the maximum of ranking_ |
295 | 304 | assert_equal(np.max(rfe.ranking_), |
@@ -317,6 +326,6 @@ def formula2(n_features, n_features_to_select, step): |
317 | 326 | rfecv.fit(X, y) |
318 | 327 |
|
319 | 328 | assert_equal(rfecv.grid_scores_.shape[0], |
320 | | - formula1(n_features, n_features_to_select, step)) |
| 329 | + formula1(n_features, n_features_to_select, step)) |
321 | 330 | assert_equal(rfecv.grid_scores_.shape[0], |
322 | | - formula2(n_features, n_features_to_select, step)) |
| 331 | + formula2(n_features, n_features_to_select, step)) |
0 commit comments