Skip to content

Commit 6bd0844

Browse files
committed
ENH make pipeline.named_steps a property, fix pipeline.named_steps doctest
1 parent b17b3aa commit 6bd0844

File tree

1 file changed

+13
-6
lines changed

1 file changed

+13
-6
lines changed

sklearn/pipeline.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -74,19 +74,22 @@ class Pipeline(BaseEstimator):
7474
>>> anova_svm.score(X, y) # doctest: +ELLIPSIS
7575
0.77...
7676
>>> # getting the selected features chosen by anova_filter
77-
>>> support = anova_svm.named_steps.get_support()
77+
>>> anova_svm.named_steps['anova'].get_support()
78+
... # doctest: +NORMALIZE_WHITESPACE
79+
array([ True, True, True, False, False, True, False, True, True, True,
80+
False, False, True, False, True, False, False, False, False,
81+
True], dtype=bool)
7882
"""
7983

8084
# BaseEstimator interface
8185

8286
def __init__(self, steps):
83-
self.named_steps = dict(steps)
8487
names, estimators = zip(*steps)
85-
if len(self.named_steps) != len(steps):
86-
raise ValueError("Names provided are not unique: %s" % (names,))
88+
if len(dict(steps)) != len(steps):
89+
raise ValueError("Provided step names are not unique: %s" % (names,))
8790

8891
# shallow copy of steps
89-
self.steps = tosequence(zip(names, estimators))
92+
self.steps = tosequence(steps)
9093
transforms = estimators[:-1]
9194
estimator = estimators[-1]
9295

@@ -110,14 +113,18 @@ def get_params(self, deep=True):
110113
if not deep:
111114
return super(Pipeline, self).get_params(deep=False)
112115
else:
113-
out = self.named_steps.copy()
116+
out = self.named_steps
114117
for name, step in six.iteritems(self.named_steps):
115118
for key, value in six.iteritems(step.get_params(deep=True)):
116119
out['%s__%s' % (name, key)] = value
117120

118121
out.update(super(Pipeline, self).get_params(deep=False))
119122
return out
120123

124+
@property
125+
def named_steps(self):
126+
return dict(self.steps)
127+
121128
@property
122129
def _final_estimator(self):
123130
return self.steps[-1][1]

0 commit comments

Comments
 (0)