Skip to content

Commit 32b7075

Browse files
committed
Merge pull request scikit-learn#4680 from amueller/pipeline_named_steps_fix
[MRG+2-?] Pipeline named steps fix
2 parents bdef419 + 6bd0844 commit 32b7075

File tree

1 file changed

+21
-6
lines changed

1 file changed

+21
-6
lines changed

sklearn/pipeline.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,17 @@ class Pipeline(BaseEstimator):
3939
4040
Parameters
4141
----------
42-
steps: list
42+
steps : list
4343
List of (name, transform) tuples (implementing fit/transform) that are
4444
chained, in the order in which they are chained, with the last object
4545
an estimator.
4646
47+
Attributes
48+
----------
49+
named_steps : dict
50+
Read-only attribute to access any step by its user-given name.
51+
Keys are step names and values are the corresponding step objects.
52+
4753
Examples
4854
--------
4955
>>> from sklearn import svm
@@ -67,18 +73,23 @@ class Pipeline(BaseEstimator):
6773
>>> prediction = anova_svm.predict(X)
6874
>>> anova_svm.score(X, y) # doctest: +ELLIPSIS
6975
0.77...
76+
>>> # getting the features selected by anova_filter
77+
>>> anova_svm.named_steps['anova'].get_support()
78+
... # doctest: +NORMALIZE_WHITESPACE
79+
array([ True, True, True, False, False, True, False, True, True, True,
80+
False, False, True, False, True, False, False, False, False,
81+
True], dtype=bool)
7082
"""
7183

7284
# BaseEstimator interface
7385

7486
def __init__(self, steps):
75-
self.named_steps = dict(steps)
7687
names, estimators = zip(*steps)
77-
if len(self.named_steps) != len(steps):
78-
raise ValueError("Names provided are not unique: %s" % (names,))
88+
if len(dict(steps)) != len(steps):
89+
raise ValueError("Provided step names are not unique: %s" % (names,))
7990

8091
# shallow copy of steps
81-
self.steps = tosequence(zip(names, estimators))
92+
self.steps = tosequence(steps)
8293
transforms = estimators[:-1]
8394
estimator = estimators[-1]
8495

@@ -102,14 +113,18 @@ def get_params(self, deep=True):
102113
if not deep:
103114
return super(Pipeline, self).get_params(deep=False)
104115
else:
105-
out = self.named_steps.copy()
116+
out = self.named_steps
106117
for name, step in six.iteritems(self.named_steps):
107118
for key, value in six.iteritems(step.get_params(deep=True)):
108119
out['%s__%s' % (name, key)] = value
109120

110121
out.update(super(Pipeline, self).get_params(deep=False))
111122
return out
112123

124+
@property
125+
def named_steps(self):
126+
return dict(self.steps)
127+
113128
@property
114129
def _final_estimator(self):
115130
return self.steps[-1][1]

0 commit comments

Comments
 (0)