Commit 01aff46

Merge pull request scikit-learn#4347 from amueller/class_weight_auto
[MRG+1] Use more natural class_weight="auto" heuristic
2 parents 85986cd + 4ca3878 commit 01aff46

File tree

21 files changed: +348 -179 lines changed

doc/modules/svm.rst

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -405,7 +405,7 @@ Tips on Practical Use
405405
approximates the fraction of training errors and support vectors.
406406

407407
* In :class:`SVC`, if data for classification are unbalanced (e.g. many
408-
positive and few negative), set ``class_weight='auto'`` and/or try
408+
positive and few negative), set ``class_weight='balanced'`` and/or try
409409
different penalty parameters ``C``.
410410

411411
* The underlying :class:`LinearSVC` implementation uses a random

doc/whats_new.rst

Lines changed: 7 additions & 1 deletion
@@ -56,6 +56,11 @@ Enhancements
5656
:class:`linear_model.LogisticRegression`, by avoiding loss computation.
5757
By `Mathieu Blondel`_ and `Tom Dupre la Tour`_.
5858

59+
- The ``class_weight="auto"`` heuristic in classifiers supporting
60+
``class_weight`` was deprecated and replaced by the ``class_weight="balanced"``
61+
option, which has a simpler formula and interpretation.
62+
By Hanna Wallach and `Andreas Müller`_.
63+
5964
Bug fixes
6065
.........
6166

@@ -339,6 +344,7 @@ Enhancements
339344
- :class:`svm.SVC` fitted on sparse input now implements ``decision_function``.
340345
By `Rob Zinkov`_ and `Andreas Müller`_.
341346

347+
342348
Documentation improvements
343349
..........................
344350

@@ -462,7 +468,7 @@ Bug fixes
462468
in GMM. By `Alexis Mignon`_.
463469

464470
- Fixed an error in the computation of conditional probabilities in
465-
:class:`naive_bayes.BernoulliNB`. By `Hanna Wallach`_.
471+
:class:`naive_bayes.BernoulliNB`. By Hanna Wallach.
466472

467473
- Make the method ``radius_neighbors`` of
468474
:class:`neighbors.NearestNeighbors` return the samples lying on the
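
The "balanced" formula quoted in this changelog entry's docstrings, ``n_samples / (n_classes * np.bincount(y))``, can be illustrated with a minimal standard-library sketch. The function name `balanced_class_weights` is hypothetical and is not part of scikit-learn; the code only mirrors the arithmetic of the formula and assumes Python 3 true division.

```python
from collections import Counter


def balanced_class_weights(y):
    """Hypothetical helper: weight = n_samples / (n_classes * class_count)."""
    counts = Counter(y)          # number of samples per class label
    n_samples = len(y)
    n_classes = len(counts)
    return {label: n_samples / (n_classes * count)
            for label, count in counts.items()}


# Imbalanced toy labels: three samples of class 0, one sample of class 1.
weights = balanced_class_weights([0, 0, 0, 1])
print(weights)
```

With these weights the rare class counts more per sample, while the total weight summed over all samples stays equal to ``n_samples``.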

examples/applications/face_recognition.py

Lines changed: 1 addition & 1 deletion
@@ -105,7 +105,7 @@
105105
t0 = time()
106106
param_grid = {'C': [1e3, 5e3, 1e4, 5e4, 1e5],
107107
'gamma': [0.0001, 0.0005, 0.001, 0.005, 0.01, 0.1], }
108-
clf = GridSearchCV(SVC(kernel='rbf', class_weight='auto'), param_grid)
108+
clf = GridSearchCV(SVC(kernel='rbf', class_weight='balanced'), param_grid)
109109
clf = clf.fit(X_train_pca, y_train)
110110
print("done in %0.3fs" % (time() - t0))
111111
print("Best estimator found by grid search:")

sklearn/ensemble/forest.py

Lines changed: 34 additions & 16 deletions
@@ -41,7 +41,9 @@ class calls the ``fit`` method of each sub-estimator on random samples
4141

4242
from __future__ import division
4343

44+
import warnings
4445
from warnings import warn
46+
4547
from abc import ABCMeta, abstractmethod
4648

4749
import numpy as np
@@ -89,7 +91,11 @@ def _parallel_build_trees(tree, forest, X, y, sample_weight, tree_idx, n_trees,
8991
curr_sample_weight *= sample_counts
9092

9193
if class_weight == 'subsample':
92-
curr_sample_weight *= compute_sample_weight('auto', y, indices)
94+
with warnings.catch_warnings():
95+
warnings.simplefilter('ignore', DeprecationWarning)
96+
curr_sample_weight *= compute_sample_weight('auto', y, indices)
97+
elif class_weight == 'balanced_subsample':
98+
curr_sample_weight *= compute_sample_weight('balanced', y, indices)
9399

94100
tree.fit(X, y, sample_weight=curr_sample_weight, check_input=False)
95101

@@ -414,30 +420,40 @@ def _validate_y_class_weight(self, y):
414420
self.n_classes_.append(classes_k.shape[0])
415421

416422
if self.class_weight is not None:
417-
valid_presets = ('auto', 'subsample')
423+
valid_presets = ('auto', 'balanced', 'balanced_subsample', 'subsample')
418424
if isinstance(self.class_weight, six.string_types):
419425
if self.class_weight not in valid_presets:
420426
raise ValueError('Valid presets for class_weight include '
421-
'"auto" and "subsample". Given "%s".'
427+
'"balanced" and "balanced_subsample". Given "%s".'
422428
% self.class_weight)
429+
if self.class_weight == "subsample":
430+
warn("class_weight='subsample' is deprecated and will be removed in 0.18."
431+
" It was replaced by class_weight='balanced_subsample' "
432+
"using the balanced strategy.", DeprecationWarning)
423433
if self.warm_start:
424-
warn('class_weight presets "auto" or "subsample" are '
434+
warn('class_weight presets "balanced" or "balanced_subsample" are '
425435
'not recommended for warm_start if the fitted data '
426436
'differs from the full dataset. In order to use '
427-
'"auto" weights, use compute_class_weight("auto", '
437+
'"balanced" weights, use compute_class_weight("balanced", '
428438
'classes, y). In place of y you can use a large '
429439
'enough sample of the full training set target to '
430440
'properly estimate the class frequency '
431441
'distributions. Pass the resulting weights as the '
432442
'class_weight parameter.')
433443

434-
if self.class_weight != 'subsample' or not self.bootstrap:
444+
if (self.class_weight not in ['subsample', 'balanced_subsample'] or
445+
not self.bootstrap):
435446
if self.class_weight == 'subsample':
436447
class_weight = 'auto'
448+
elif self.class_weight == "balanced_subsample":
449+
class_weight = "balanced"
437450
else:
438451
class_weight = self.class_weight
439-
expanded_class_weight = compute_sample_weight(class_weight,
440-
y_original)
452+
with warnings.catch_warnings():
453+
if class_weight == "auto":
454+
warnings.simplefilter('ignore', DeprecationWarning)
455+
expanded_class_weight = compute_sample_weight(class_weight,
456+
y_original)
441457

442458
return y, expanded_class_weight
443459

@@ -758,17 +774,18 @@ class RandomForestClassifier(ForestClassifier):
758774
and add more estimators to the ensemble, otherwise, just fit a whole
759775
new forest.
760776
761-
class_weight : dict, list of dicts, "auto", "subsample" or None, optional
777+
class_weight : dict, list of dicts, "balanced", "balanced_subsample" or None, optional
762778
763779
Weights associated with classes in the form ``{class_label: weight}``.
764780
If not given, all classes are supposed to have weight one. For
765781
multi-output problems, a list of dicts can be provided in the same
766782
order as the columns of y.
767783
768-
The "auto" mode uses the values of y to automatically adjust
769-
weights inversely proportional to class frequencies in the input data.
784+
The "balanced" mode uses the values of y to automatically adjust
785+
weights inversely proportional to class frequencies in the input data
786+
as ``n_samples / (n_classes * np.bincount(y))``
770787
771-
The "subsample" mode is the same as "auto" except that weights are
788+
The "balanced_subsample" mode is the same as "balanced" except that weights are
772789
computed based on the bootstrap sample for every tree grown.
773790
774791
For multi-output, the weights of each column of y will be multiplied.
@@ -1100,17 +1117,18 @@ class ExtraTreesClassifier(ForestClassifier):
11001117
and add more estimators to the ensemble, otherwise, just fit a whole
11011118
new forest.
11021119
1103-
class_weight : dict, list of dicts, "auto", "subsample" or None, optional
1120+
class_weight : dict, list of dicts, "balanced", "balanced_subsample" or None, optional
11041121
11051122
Weights associated with classes in the form ``{class_label: weight}``.
11061123
If not given, all classes are supposed to have weight one. For
11071124
multi-output problems, a list of dicts can be provided in the same
11081125
order as the columns of y.
11091126
1110-
The "auto" mode uses the values of y to automatically adjust
1111-
weights inversely proportional to class frequencies in the input data.
1127+
The "balanced" mode uses the values of y to automatically adjust
1128+
weights inversely proportional to class frequencies in the input data
1129+
as ``n_samples / (n_classes * np.bincount(y))``
11121130
1113-
The "subsample" mode is the same as "auto" except that weights are
1131+
The "balanced_subsample" mode is the same as "balanced" except that weights are
11141132
computed based on the bootstrap sample for every tree grown.
11151133
11161134
For multi-output, the weights of each column of y will be multiplied.
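
The docstrings above say that "balanced_subsample" is the same as "balanced" except that the weights are computed on the bootstrap sample of each tree. A rough standard-library sketch of that difference follows. The helper names are hypothetical, the bootstrap is a plain draw with replacement, and nothing here reproduces the actual `compute_sample_weight` implementation.

```python
import random
from collections import Counter


def balanced_weights(labels):
    """Hypothetical helper: n_samples / (n_classes * class_count) per class."""
    counts = Counter(labels)
    n_samples, n_classes = len(labels), len(counts)
    return {c: n_samples / (n_classes * k) for c, k in counts.items()}


def bootstrap_balanced_weights(y, rng):
    """Draw one bootstrap sample of y and compute balanced weights on it."""
    indices = [rng.randrange(len(y)) for _ in y]   # sample with replacement
    return balanced_weights([y[i] for i in indices])


y = [0] * 8 + [1] * 2

# "balanced": one set of weights, computed once from the full target.
full = balanced_weights(y)

# "balanced_subsample": weights recomputed for each tree's bootstrap sample.
# In this sketch a class missing from a sample simply has no entry.
rng = random.Random(0)
per_tree = [bootstrap_balanced_weights(y, rng) for _ in range(3)]

print(full)
print(per_tree)
```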

sklearn/ensemble/tests/test_forest.py

Lines changed: 14 additions & 9 deletions
@@ -24,6 +24,7 @@
2424
from sklearn.utils.testing import assert_greater_equal
2525
from sklearn.utils.testing import assert_raises
2626
from sklearn.utils.testing import assert_warns
27+
from sklearn.utils.testing import assert_warns_message
2728
from sklearn.utils.testing import ignore_warnings
2829

2930
from sklearn import datasets
@@ -749,10 +750,10 @@ def check_class_weights(name):
749750
# Check class_weights resemble sample_weights behavior.
750751
ForestClassifier = FOREST_CLASSIFIERS[name]
751752

752-
# Iris is balanced, so no effect expected for using 'auto' weights
753+
# Iris is balanced, so no effect expected for using 'balanced' weights
753754
clf1 = ForestClassifier(random_state=0)
754755
clf1.fit(iris.data, iris.target)
755-
clf2 = ForestClassifier(class_weight='auto', random_state=0)
756+
clf2 = ForestClassifier(class_weight='balanced', random_state=0)
756757
clf2.fit(iris.data, iris.target)
757758
assert_almost_equal(clf1.feature_importances_, clf2.feature_importances_)
758759

@@ -765,8 +766,8 @@ def check_class_weights(name):
765766
random_state=0)
766767
clf3.fit(iris.data, iris_multi)
767768
assert_almost_equal(clf2.feature_importances_, clf3.feature_importances_)
768-
# Check against multi-output "auto" which should also have no effect
769-
clf4 = ForestClassifier(class_weight='auto', random_state=0)
769+
# Check against multi-output "balanced" which should also have no effect
770+
clf4 = ForestClassifier(class_weight='balanced', random_state=0)
770771
clf4.fit(iris.data, iris_multi)
771772
assert_almost_equal(clf3.feature_importances_, clf4.feature_importances_)
772773

@@ -793,22 +794,26 @@ def test_class_weights():
793794
yield check_class_weights, name
794795

795796

796-
def check_class_weight_auto_and_bootstrap_multi_output(name):
797-
# Test class_weight works for multi-output
797+
def check_class_weight_balanced_and_bootstrap_multi_output(name):
798+
# Test class_weight works for multi-output
798799
ForestClassifier = FOREST_CLASSIFIERS[name]
799800
_y = np.vstack((y, np.array(y) * 2)).T
800-
clf = ForestClassifier(class_weight='auto', random_state=0)
801+
clf = ForestClassifier(class_weight='balanced', random_state=0)
801802
clf.fit(X, _y)
802803
clf = ForestClassifier(class_weight=[{-1: 0.5, 1: 1.}, {-2: 1., 2: 1.}],
803804
random_state=0)
804805
clf.fit(X, _y)
806+
# smoke test for subsample and balanced subsample
807+
clf = ForestClassifier(class_weight='balanced_subsample', random_state=0)
808+
clf.fit(X, _y)
805809
clf = ForestClassifier(class_weight='subsample', random_state=0)
810+
#assert_warns_message(DeprecationWarning, "balanced_subsample", clf.fit, X, _y)
806811
clf.fit(X, _y)
807812

808813

809-
def test_class_weight_auto_and_bootstrap_multi_output():
814+
def test_class_weight_balanced_and_bootstrap_multi_output():
810815
for name in FOREST_CLASSIFIERS:
811-
yield check_class_weight_auto_and_bootstrap_multi_output, name
816+
yield check_class_weight_balanced_and_bootstrap_multi_output, name
812817

813818

814819
def check_class_weight_errors(name):
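
The test above leaves the `assert_warns_message` check for the deprecated 'subsample' preset commented out. As an illustration only, the same kind of check can be sketched with the standard `warnings` module. The function `warn_if_deprecated` is a hypothetical stand-in for the estimator's validation step, not scikit-learn code, and its message text is merely modelled on the one in the diff.

```python
import warnings


def warn_if_deprecated(class_weight):
    """Hypothetical stand-in: emit a DeprecationWarning for 'subsample'."""
    if class_weight == "subsample":
        warnings.warn("class_weight='subsample' is deprecated. It was "
                      "replaced by class_weight='balanced_subsample'.",
                      DeprecationWarning)
    return class_weight


# Record warnings instead of printing them, so they can be inspected.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")   # make sure nothing is filtered out
    warn_if_deprecated("subsample")

messages = [str(w.message) for w in caught
            if issubclass(w.category, DeprecationWarning)]
print(messages)
```

Replacing `simplefilter("always")` with `simplefilter("ignore", DeprecationWarning)` gives the suppression pattern the forest code uses around its internal 'auto' calls.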

sklearn/linear_model/logistic.py

Lines changed: 37 additions & 27 deletions
@@ -473,11 +473,13 @@ def logistic_regression_path(X, y, pos_class=None, Cs=10, fit_intercept=True,
473473
is called repeatedly with the same data, as y is modified
474474
along the path.
475475
476-
class_weight : {dict, 'auto'}, optional
477-
Over-/undersamples the samples of each class according to the given
478-
weights. If not given, all classes are supposed to have weight one.
479-
The 'auto' mode selects weights inversely proportional to class
480-
frequencies in the training set.
476+
class_weight : dict or 'balanced', optional
477+
Weights associated with classes in the form ``{class_label: weight}``.
478+
If not given, all classes are supposed to have weight one.
479+
480+
The "balanced" mode uses the values of y to automatically adjust
481+
weights inversely proportional to class frequencies in the input data
482+
as ``n_samples / (n_classes * np.bincount(y))``
481483
482484
dual : bool
483485
Dual or primal formulation. Dual formulation is only implemented for
@@ -734,11 +736,13 @@ def _log_reg_scoring_path(X, y, train, test, pos_class=None, Cs=10,
734736
tol : float
735737
Tolerance for stopping criteria.
736738
737-
class_weight : {dict, 'auto'}, optional
738-
Over-/undersamples the samples of each class according to the given
739-
weights. If not given, all classes are supposed to have weight one.
740-
The 'auto' mode selects weights inversely proportional to class
741-
frequencies in the training set.
739+
class_weight : dict or 'balanced', optional
740+
Weights associated with classes in the form ``{class_label: weight}``.
741+
If not given, all classes are supposed to have weight one.
742+
743+
The "balanced" mode uses the values of y to automatically adjust
744+
weights inversely proportional to class frequencies in the input data
745+
as ``n_samples / (n_classes * np.bincount(y))``
742746
743747
verbose : int
744748
For the liblinear and lbfgs solvers set verbose to any positive
@@ -903,11 +907,13 @@ class LogisticRegression(BaseEstimator, LinearClassifierMixin,
903907
To lessen the effect of regularization on synthetic feature weight
904908
(and therefore on the intercept) intercept_scaling has to be increased.
905909
906-
class_weight : {dict, 'auto'}, optional
907-
Over-/undersamples the samples of each class according to the given
908-
weights. If not given, all classes are supposed to have weight one.
909-
The 'auto' mode selects weights inversely proportional to class
910-
frequencies in the training set.
910+
class_weight : dict or 'balanced', optional
911+
Weights associated with classes in the form ``{class_label: weight}``.
912+
If not given, all classes are supposed to have weight one.
913+
914+
The "balanced" mode uses the values of y to automatically adjust
915+
weights inversely proportional to class frequencies in the input data
916+
as ``n_samples / (n_classes * np.bincount(y))``
911917
912918
max_iter : int
913919
Useful only for the newton-cg and lbfgs solvers. Maximum number of
@@ -1150,11 +1156,13 @@ class LogisticRegressionCV(LogisticRegression, BaseEstimator,
11501156
Specifies if a constant (a.k.a. bias or intercept) should be
11511157
added the decision function.
11521158
1153-
class_weight : {dict, 'auto'}, optional
1154-
Over-/undersamples the samples of each class according to the given
1155-
weights. If not given, all classes are supposed to have weight one.
1156-
The 'auto' mode selects weights inversely proportional to class
1157-
frequencies in the training set.
1159+
class_weight : dict or 'balanced', optional
1160+
Weights associated with classes in the form ``{class_label: weight}``.
1161+
If not given, all classes are supposed to have weight one.
1162+
1163+
The "balanced" mode uses the values of y to automatically adjust
1164+
weights inversely proportional to class frequencies in the input data
1165+
as ``n_samples / (n_classes * np.bincount(y))``
11581166
11591167
cv : integer or cross-validation generator
11601168
The default cross-validation generator used is Stratified K-Folds.
@@ -1185,11 +1193,13 @@ class LogisticRegressionCV(LogisticRegression, BaseEstimator,
11851193
max_iter : int, optional
11861194
Maximum number of iterations of the optimization algorithm.
11871195
1188-
class_weight : {dict, 'auto'}, optional
1189-
Over-/undersamples the samples of each class according to the given
1190-
weights. If not given, all classes are supposed to have weight one.
1191-
The 'auto' mode selects weights inversely proportional to class
1192-
frequencies in the training set.
1196+
class_weight : dict or 'balanced', optional
1197+
Weights associated with classes in the form ``{class_label: weight}``.
1198+
If not given, all classes are supposed to have weight one.
1199+
1200+
The "balanced" mode uses the values of y to automatically adjust
1201+
weights inversely proportional to class frequencies in the input data
1202+
as ``n_samples / (n_classes * np.bincount(y))``
11931203
11941204
n_jobs : int, optional
11951205
Number of CPU cores used during the cross-validation loop. If given
@@ -1363,9 +1373,9 @@ def fit(self, X, y):
13631373
iter_labels = [None]
13641374

13651375
if self.class_weight and not(isinstance(self.class_weight, dict) or
1366-
self.class_weight == 'auto'):
1376+
self.class_weight in ['balanced', 'auto']):
13671377
raise ValueError("class_weight provided should be a "
1368-
"dict or 'auto'")
1378+
"dict or 'balanced'")
13691379

13701380
path_func = delayed(_log_reg_scoring_path)
13711381
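
The validation in the last hunk accepts a dict, 'balanced', or the deprecated 'auto', and rejects any other truthy value. A minimal sketch of that condition as a standalone function is given below; the name `check_class_weight` is hypothetical and the function only mirrors the boolean test shown in the diff.

```python
def check_class_weight(class_weight):
    """Hypothetical validator mirroring the condition in the hunk above."""
    # None (or any other falsy value) means "all classes have weight one".
    if class_weight and not (isinstance(class_weight, dict)
                             or class_weight in ("balanced", "auto")):
        raise ValueError("class_weight provided should be a "
                         "dict or 'balanced'")
    return class_weight


accepted = [check_class_weight(cw)
            for cw in (None, {0: 1.0, 1: 2.0}, "balanced", "auto")]

try:
    check_class_weight("bogus")
    rejected = False
except ValueError:
    rejected = True

print(accepted, rejected)
```

Note that 'auto' still passes the check even though the error message no longer mentions it, which keeps the deprecated option usable during the transition.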

sklearn/linear_model/perceptron.py

Lines changed: 4 additions & 3 deletions
@@ -44,14 +44,15 @@ class Perceptron(BaseSGDClassifier, _LearntSelectorMixin):
4444
eta0 : double
4545
Constant by which the updates are multiplied. Defaults to 1.
4646
47-
class_weight : dict, {class_label: weight} or "auto" or None, optional
47+
class_weight : dict, {class_label: weight} or "balanced" or None, optional
4848
Preset for the class_weight fit parameter.
4949
5050
Weights associated with classes. If not given, all classes
5151
are supposed to have weight one.
5252
53-
The "auto" mode uses the values of y to automatically adjust
54-
weights inversely proportional to class frequencies.
53+
The "balanced" mode uses the values of y to automatically adjust
54+
weights inversely proportional to class frequencies in the input data
55+
as ``n_samples / (n_classes * np.bincount(y))``
5556
5657
warm_start : bool, optional
5758
When set to True, reuse the solution of the previous call to fit as

sklearn/linear_model/ridge.py

Lines changed: 14 additions & 8 deletions
@@ -507,10 +507,13 @@ class RidgeClassifier(LinearClassifierMixin, _BaseRidge):
507507
``(2*C)^-1`` in other linear models such as LogisticRegression or
508508
LinearSVC.
509509
510-
class_weight : dict, optional
511-
Weights associated with classes in the form
512-
``{class_label : weight}``. If not given, all classes are
513-
supposed to have weight one.
510+
class_weight : dict or 'balanced', optional
511+
Weights associated with classes in the form ``{class_label: weight}``.
512+
If not given, all classes are supposed to have weight one.
513+
514+
The "balanced" mode uses the values of y to automatically adjust
515+
weights inversely proportional to class frequencies in the input data
516+
as ``n_samples / (n_classes * np.bincount(y))``
514517
515518
copy_X : boolean, optional, default True
516519
If True, X will be copied; else, it may be overwritten.
@@ -994,10 +997,13 @@ class RidgeClassifierCV(LinearClassifierMixin, _BaseRidgeCV):
994997
If None, Generalized Cross-Validation (efficient Leave-One-Out)
995998
will be used.
996999
997-
class_weight : dict, optional
998-
Weights associated with classes in the form
999-
``{class_label : weight}``. If not given, all classes are
1000-
supposed to have weight one.
1000+
class_weight : dict or 'balanced', optional
1001+
Weights associated with classes in the form ``{class_label: weight}``.
1002+
If not given, all classes are supposed to have weight one.
1003+
1004+
The "balanced" mode uses the values of y to automatically adjust
1005+
weights inversely proportional to class frequencies in the input data
1006+
as ``n_samples / (n_classes * np.bincount(y))``
10011007
10021008
Attributes
10031009
----------
