Skip to content

Commit f369d3a

Browse files
committed
Fix RFE / RFECV estimator tags
1 parent 44f17b0 commit f369d3a

File tree

2 files changed

+19
-6
lines changed

2 files changed

+19
-6
lines changed

sklearn/feature_selection/rfe.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,10 @@ def __init__(self, estimator, n_features_to_select=None, step=1,
110110
self.estimator_params = estimator_params
111111
self.verbose = verbose
112112

113+
@property
114+
def _estimator_type(self):
115+
return self.estimator._estimator_type
116+
113117
def fit(self, X, y):
114118
"""Fit the RFE model and then the underlying estimator on the selected
115119
features.

sklearn/feature_selection/tests/test_rfe.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,16 @@
88
from scipy import sparse
99

1010
from sklearn.feature_selection.rfe import RFE, RFECV
11-
from sklearn.datasets import load_iris, make_friedman1, make_regression
11+
from sklearn.datasets import load_iris, make_friedman1
1212
from sklearn.metrics import zero_one_loss
1313
from sklearn.svm import SVC, SVR
14-
from sklearn.linear_model import LinearRegression
1514
from sklearn.ensemble import RandomForestClassifier
15+
from sklearn.cross_validation import cross_val_score
1616

1717
from sklearn.utils import check_random_state
1818
from sklearn.utils.testing import ignore_warnings
1919
from sklearn.utils.testing import assert_warns_message
20+
from sklearn.utils.testing import assert_greater
2021

2122
from sklearn.metrics import make_scorer
2223
from sklearn.metrics import get_scorer
@@ -94,7 +95,6 @@ def test_rfe_features_importance():
9495
assert_array_equal(rfe.get_support(), rfe_svc.get_support())
9596

9697

97-
9898
def test_rfe_deprecation_estimator_params():
9999
deprecation_message = ("The parameter 'estimator_params' is deprecated as "
100100
"of version 0.16 and will be removed in 0.18. The "
@@ -240,6 +240,15 @@ def test_rfecv_mockclassifier():
240240
assert_equal(len(rfecv.ranking_), X.shape[1])
241241

242242

243+
def test_rfe_estimator_tags():
244+
rfe = RFE(SVC(kernel='linear'))
245+
assert_equal(rfe._estimator_type, "classifier")
246+
# make sure that cross-validation is stratified
247+
iris = load_iris()
248+
score = cross_val_score(rfe, iris.data, iris.target)
249+
assert_greater(score.min(), .7)
250+
251+
243252
def test_rfe_min_step():
244253
n_features = 10
245254
X, y = make_friedman1(n_samples=50, n_features=n_features, random_state=0)
@@ -289,7 +298,7 @@ def formula2(n_features, n_features_to_select, step):
289298
X = generator.normal(size=(100, n_features))
290299
y = generator.rand(100).round()
291300
rfe = RFE(estimator=SVC(kernel="linear"),
292-
n_features_to_select=n_features_to_select, step=step)
301+
n_features_to_select=n_features_to_select, step=step)
293302
rfe.fit(X, y)
294303
# this number also equals to the maximum of ranking_
295304
assert_equal(np.max(rfe.ranking_),
@@ -317,6 +326,6 @@ def formula2(n_features, n_features_to_select, step):
317326
rfecv.fit(X, y)
318327

319328
assert_equal(rfecv.grid_scores_.shape[0],
320-
formula1(n_features, n_features_to_select, step))
329+
formula1(n_features, n_features_to_select, step))
321330
assert_equal(rfecv.grid_scores_.shape[0],
322-
formula2(n_features, n_features_to_select, step))
331+
formula2(n_features, n_features_to_select, step))

0 commit comments

Comments
 (0)