Skip to content

Commit eedc1cd

Browse files
committed
Use more natural class_weight="auto" heuristic
1 parent 2c79a98 commit eedc1cd

File tree

21 files changed

+323
-178
lines changed

21 files changed

+323
-178
lines changed

doc/modules/svm.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -405,7 +405,7 @@ Tips on Practical Use
405405
approximates the fraction of training errors and support vectors.
406406

407407
* In :class:`SVC`, if data for classification are unbalanced (e.g. many
408-
positive and few negative), set ``class_weight='auto'`` and/or try
408+
positive and few negative), set ``class_weight='balanced'`` and/or try
409409
different penalty parameters ``C``.
410410

411411
* The underlying :class:`LinearSVC` implementation uses a random

doc/whats_new.rst

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,9 @@ Enhancements
5656
:class:`linear_model.LogisticRegression`, by avoiding loss computation.
5757
By `Mathieu Blondel`_ and `Tom Dupre la Tour`_.
5858

59+
- Improved heuristic for ``class_weight="auto"`` for classifiers supporting
60+
``class_weight`` by Hanna Wallach and `Andreas Müller`_
61+
5962
Bug fixes
6063
.........
6164

@@ -339,6 +342,7 @@ Enhancements
339342
- :class:`svm.SVC` fitted on sparse input now implements ``decision_function``.
340343
By `Rob Zinkov`_ and `Andreas Müller`_.
341344

345+
342346
Documentation improvements
343347
..........................
344348

@@ -462,7 +466,7 @@ Bug fixes
462466
in GMM. By `Alexis Mignon`_.
463467

464468
- Fixed a error in the computation of conditional probabilities in
465-
:class:`naive_bayes.BernoulliNB`. By `Hanna Wallach`_.
469+
:class:`naive_bayes.BernoulliNB`. By Hanna Wallach.
466470

467471
- Make the method ``radius_neighbors`` of
468472
:class:`neighbors.NearestNeighbors` return the samples lying on the

examples/applications/face_recognition.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@
105105
t0 = time()
106106
param_grid = {'C': [1e3, 5e3, 1e4, 5e4, 1e5],
107107
'gamma': [0.0001, 0.0005, 0.001, 0.005, 0.01, 0.1], }
108-
clf = GridSearchCV(SVC(kernel='rbf', class_weight='auto'), param_grid)
108+
clf = GridSearchCV(SVC(kernel='rbf', class_weight='balanced'), param_grid)
109109
clf = clf.fit(X_train_pca, y_train)
110110
print("done in %0.3fs" % (time() - t0))
111111
print("Best estimator found by grid search:")

sklearn/ensemble/forest.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def _parallel_build_trees(tree, forest, X, y, sample_weight, tree_idx, n_trees,
8989
curr_sample_weight *= sample_counts
9090

9191
if class_weight == 'subsample':
92-
curr_sample_weight *= compute_sample_weight('auto', y, indices)
92+
curr_sample_weight *= compute_sample_weight('balanced', y, indices)
9393

9494
tree.fit(X, y, sample_weight=curr_sample_weight, check_input=False)
9595

@@ -414,17 +414,17 @@ def _validate_y_class_weight(self, y):
414414
self.n_classes_.append(classes_k.shape[0])
415415

416416
if self.class_weight is not None:
417-
valid_presets = ('auto', 'subsample')
417+
valid_presets = ('auto', 'balanced', 'subsample', 'auto')
418418
if isinstance(self.class_weight, six.string_types):
419419
if self.class_weight not in valid_presets:
420420
raise ValueError('Valid presets for class_weight include '
421-
'"auto" and "subsample". Given "%s".'
421+
'"balanced" and "subsample". Given "%s".'
422422
% self.class_weight)
423423
if self.warm_start:
424-
warn('class_weight presets "auto" or "subsample" are '
424+
warn('class_weight presets "balanced" or "subsample" are '
425425
'not recommended for warm_start if the fitted data '
426426
'differs from the full dataset. In order to use '
427-
'"auto" weights, use compute_class_weight("auto", '
427+
'"auto" weights, use compute_class_weight("balanced", '
428428
'classes, y). In place of y you can use a large '
429429
'enough sample of the full training set target to '
430430
'properly estimate the class frequency '
@@ -433,7 +433,7 @@ def _validate_y_class_weight(self, y):
433433

434434
if self.class_weight != 'subsample' or not self.bootstrap:
435435
if self.class_weight == 'subsample':
436-
class_weight = 'auto'
436+
class_weight = 'balanced'
437437
else:
438438
class_weight = self.class_weight
439439
expanded_class_weight = compute_sample_weight(class_weight,
@@ -758,17 +758,18 @@ class RandomForestClassifier(ForestClassifier):
758758
and add more estimators to the ensemble, otherwise, just fit a whole
759759
new forest.
760760
761-
class_weight : dict, list of dicts, "auto", "subsample" or None, optional
761+
class_weight : dict, list of dicts, "balanced", "subsample" or None, optional
762762
763763
Weights associated with classes in the form ``{class_label: weight}``.
764764
If not given, all classes are supposed to have weight one. For
765765
multi-output problems, a list of dicts can be provided in the same
766766
order as the columns of y.
767767
768-
The "auto" mode uses the values of y to automatically adjust
769-
weights inversely proportional to class frequencies in the input data.
768+
The "balanced" mode uses the values of y to automatically adjust
769+
weights inversely proportional to class frequencies in the input data
770+
as ``n_samples / (n_classes * np.bincount(y))``
770771
771-
The "subsample" mode is the same as "auto" except that weights are
772+
The "subsample" mode is the same as "balanced" except that weights are
772773
computed based on the bootstrap sample for every tree grown.
773774
774775
For multi-output, the weights of each column of y will be multiplied.
@@ -1100,17 +1101,18 @@ class ExtraTreesClassifier(ForestClassifier):
11001101
and add more estimators to the ensemble, otherwise, just fit a whole
11011102
new forest.
11021103
1103-
class_weight : dict, list of dicts, "auto", "subsample" or None, optional
1104+
class_weight : dict, list of dicts, "balanced", "subsample" or None, optional
11041105
11051106
Weights associated with classes in the form ``{class_label: weight}``.
11061107
If not given, all classes are supposed to have weight one. For
11071108
multi-output problems, a list of dicts can be provided in the same
11081109
order as the columns of y.
11091110
1110-
The "auto" mode uses the values of y to automatically adjust
1111-
weights inversely proportional to class frequencies in the input data.
1111+
The "balanced" mode uses the values of y to automatically adjust
1112+
weights inversely proportional to class frequencies in the input data
1113+
as ``n_samples / (n_classes * np.bincount(y))``
11121114
1113-
The "subsample" mode is the same as "auto" except that weights are
1115+
The "subsample" mode is the same as "balanced" except that weights are
11141116
computed based on the bootstrap sample for every tree grown.
11151117
11161118
For multi-output, the weights of each column of y will be multiplied.

sklearn/ensemble/tests/test_forest.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -329,7 +329,7 @@ def test_parallel():
329329
yield check_parallel, name, iris.data, iris.target
330330

331331
for name in FOREST_REGRESSORS:
332-
yield check_parallel, name, boston.data, boston.target
332+
yield check_parallel, name, boston.data, boston.target
333333

334334

335335
def check_pickle(name, X, y):
@@ -352,7 +352,7 @@ def test_pickle():
352352
yield check_pickle, name, iris.data[::2], iris.target[::2]
353353

354354
for name in FOREST_REGRESSORS:
355-
yield check_pickle, name, boston.data[::2], boston.target[::2]
355+
yield check_pickle, name, boston.data[::2], boston.target[::2]
356356

357357

358358
def check_multioutput(name):
@@ -749,10 +749,10 @@ def check_class_weights(name):
749749
# Check class_weights resemble sample_weights behavior.
750750
ForestClassifier = FOREST_CLASSIFIERS[name]
751751

752-
# Iris is balanced, so no effect expected for using 'auto' weights
752+
# Iris is balanced, so no effect expected for using 'balanced' weights
753753
clf1 = ForestClassifier(random_state=0)
754754
clf1.fit(iris.data, iris.target)
755-
clf2 = ForestClassifier(class_weight='auto', random_state=0)
755+
clf2 = ForestClassifier(class_weight='balanced', random_state=0)
756756
clf2.fit(iris.data, iris.target)
757757
assert_almost_equal(clf1.feature_importances_, clf2.feature_importances_)
758758

@@ -765,8 +765,8 @@ def check_class_weights(name):
765765
random_state=0)
766766
clf3.fit(iris.data, iris_multi)
767767
assert_almost_equal(clf2.feature_importances_, clf3.feature_importances_)
768-
# Check against multi-output "auto" which should also have no effect
769-
clf4 = ForestClassifier(class_weight='auto', random_state=0)
768+
# Check against multi-output "balanced" which should also have no effect
769+
clf4 = ForestClassifier(class_weight='balanced', random_state=0)
770770
clf4.fit(iris.data, iris_multi)
771771
assert_almost_equal(clf3.feature_importances_, clf4.feature_importances_)
772772

@@ -782,7 +782,7 @@ def check_class_weights(name):
782782

783783
# Check that sample_weight and class_weight are multiplicative
784784
clf1 = ForestClassifier(random_state=0)
785-
clf1.fit(iris.data, iris.target, sample_weight**2)
785+
clf1.fit(iris.data, iris.target, sample_weight ** 2)
786786
clf2 = ForestClassifier(class_weight=class_weight, random_state=0)
787787
clf2.fit(iris.data, iris.target, sample_weight)
788788
assert_almost_equal(clf1.feature_importances_, clf2.feature_importances_)
@@ -793,11 +793,11 @@ def test_class_weights():
793793
yield check_class_weights, name
794794

795795

796-
def check_class_weight_auto_and_bootstrap_multi_output(name):
797-
# Test class_weight works for multi-output
796+
def check_class_weight_balanced_and_bootstrap_multi_output(name):
797+
# Test class_weight works for multi-output"""
798798
ForestClassifier = FOREST_CLASSIFIERS[name]
799799
_y = np.vstack((y, np.array(y) * 2)).T
800-
clf = ForestClassifier(class_weight='auto', random_state=0)
800+
clf = ForestClassifier(class_weight='balanced', random_state=0)
801801
clf.fit(X, _y)
802802
clf = ForestClassifier(class_weight=[{-1: 0.5, 1: 1.}, {-2: 1., 2: 1.}],
803803
random_state=0)
@@ -806,9 +806,9 @@ def check_class_weight_auto_and_bootstrap_multi_output(name):
806806
clf.fit(X, _y)
807807

808808

809-
def test_class_weight_auto_and_bootstrap_multi_output():
809+
def test_class_weight_balanced_and_bootstrap_multi_output():
810810
for name in FOREST_CLASSIFIERS:
811-
yield check_class_weight_auto_and_bootstrap_multi_output, name
811+
yield check_class_weight_balanced_and_bootstrap_multi_output, name
812812

813813

814814
def check_class_weight_errors(name):

sklearn/linear_model/logistic.py

Lines changed: 37 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -473,11 +473,13 @@ def logistic_regression_path(X, y, pos_class=None, Cs=10, fit_intercept=True,
473473
is called repeatedly with the same data, as y is modified
474474
along the path.
475475
476-
class_weight : {dict, 'auto'}, optional
477-
Over-/undersamples the samples of each class according to the given
478-
weights. If not given, all classes are supposed to have weight one.
479-
The 'auto' mode selects weights inversely proportional to class
480-
frequencies in the training set.
476+
class_weight : dict or 'balanced', optional
477+
Weights associated with classes in the form ``{class_label: weight}``.
478+
If not given, all classes are supposed to have weight one.
479+
480+
The "balanced" mode uses the values of y to automatically adjust
481+
weights inversely proportional to class frequencies in the input data
482+
as ``n_samples / (n_classes * np.bincount(y))``
481483
482484
dual : bool
483485
Dual or primal formulation. Dual formulation is only implemented for
@@ -734,11 +736,13 @@ def _log_reg_scoring_path(X, y, train, test, pos_class=None, Cs=10,
734736
tol : float
735737
Tolerance for stopping criteria.
736738
737-
class_weight : {dict, 'auto'}, optional
738-
Over-/undersamples the samples of each class according to the given
739-
weights. If not given, all classes are supposed to have weight one.
740-
The 'auto' mode selects weights inversely proportional to class
741-
frequencies in the training set.
739+
class_weight : dict or 'balanced', optional
740+
Weights associated with classes in the form ``{class_label: weight}``.
741+
If not given, all classes are supposed to have weight one.
742+
743+
The "balanced" mode uses the values of y to automatically adjust
744+
weights inversely proportional to class frequencies in the input data
745+
as ``n_samples / (n_classes * np.bincount(y))``
742746
743747
verbose : int
744748
For the liblinear and lbfgs solvers set verbose to any positive
@@ -903,11 +907,13 @@ class LogisticRegression(BaseEstimator, LinearClassifierMixin,
903907
To lessen the effect of regularization on synthetic feature weight
904908
(and therefore on the intercept) intercept_scaling has to be increased.
905909
906-
class_weight : {dict, 'auto'}, optional
907-
Over-/undersamples the samples of each class according to the given
908-
weights. If not given, all classes are supposed to have weight one.
909-
The 'auto' mode selects weights inversely proportional to class
910-
frequencies in the training set.
910+
class_weight : dict or 'balanced', optional
911+
Weights associated with classes in the form ``{class_label: weight}``.
912+
If not given, all classes are supposed to have weight one.
913+
914+
The "balanced" mode uses the values of y to automatically adjust
915+
weights inversely proportional to class frequencies in the input data
916+
as ``n_samples / (n_classes * np.bincount(y))``
911917
912918
max_iter : int
913919
Useful only for the newton-cg and lbfgs solvers. Maximum number of
@@ -1150,11 +1156,13 @@ class LogisticRegressionCV(LogisticRegression, BaseEstimator,
11501156
Specifies if a constant (a.k.a. bias or intercept) should be
11511157
added the decision function.
11521158
1153-
class_weight : {dict, 'auto'}, optional
1154-
Over-/undersamples the samples of each class according to the given
1155-
weights. If not given, all classes are supposed to have weight one.
1156-
The 'auto' mode selects weights inversely proportional to class
1157-
frequencies in the training set.
1159+
class_weight : dict or 'balanced', optional
1160+
Weights associated with classes in the form ``{class_label: weight}``.
1161+
If not given, all classes are supposed to have weight one.
1162+
1163+
The "balanced" mode uses the values of y to automatically adjust
1164+
weights inversely proportional to class frequencies in the input data
1165+
as ``n_samples / (n_classes * np.bincount(y))``
11581166
11591167
cv : integer or cross-validation generator
11601168
The default cross-validation generator used is Stratified K-Folds.
@@ -1185,11 +1193,13 @@ class LogisticRegressionCV(LogisticRegression, BaseEstimator,
11851193
max_iter : int, optional
11861194
Maximum number of iterations of the optimization algorithm.
11871195
1188-
class_weight : {dict, 'auto'}, optional
1189-
Over-/undersamples the samples of each class according to the given
1190-
weights. If not given, all classes are supposed to have weight one.
1191-
The 'auto' mode selects weights inversely proportional to class
1192-
frequencies in the training set.
1196+
class_weight : dict or 'balanced', optional
1197+
Weights associated with classes in the form ``{class_label: weight}``.
1198+
If not given, all classes are supposed to have weight one.
1199+
1200+
The "balanced" mode uses the values of y to automatically adjust
1201+
weights inversely proportional to class frequencies in the input data
1202+
as ``n_samples / (n_classes * np.bincount(y))``
11931203
11941204
n_jobs : int, optional
11951205
Number of CPU cores used during the cross-validation loop. If given
@@ -1363,9 +1373,9 @@ def fit(self, X, y):
13631373
iter_labels = [None]
13641374

13651375
if self.class_weight and not(isinstance(self.class_weight, dict) or
1366-
self.class_weight == 'auto'):
1376+
self.class_weight in ['balanced', 'auto']):
13671377
raise ValueError("class_weight provided should be a "
1368-
"dict or 'auto'")
1378+
"dict or 'balanced'")
13691379

13701380
path_func = delayed(_log_reg_scoring_path)
13711381

sklearn/linear_model/perceptron.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,14 +44,15 @@ class Perceptron(BaseSGDClassifier, _LearntSelectorMixin):
4444
eta0 : double
4545
Constant by which the updates are multiplied. Defaults to 1.
4646
47-
class_weight : dict, {class_label: weight} or "auto" or None, optional
47+
class_weight : dict, {class_label: weight} or "balanced" or None, optional
4848
Preset for the class_weight fit parameter.
4949
5050
Weights associated with classes. If not given, all classes
5151
are supposed to have weight one.
5252
53-
The "auto" mode uses the values of y to automatically adjust
54-
weights inversely proportional to class frequencies.
53+
The "balanced" mode uses the values of y to automatically adjust
54+
weights inversely proportional to class frequencies in the input data
55+
as ``n_samples / (n_classes * np.bincount(y))``
5556
5657
warm_start : bool, optional
5758
When set to True, reuse the solution of the previous call to fit as

sklearn/linear_model/ridge.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -507,10 +507,13 @@ class RidgeClassifier(LinearClassifierMixin, _BaseRidge):
507507
``(2*C)^-1`` in other linear models such as LogisticRegression or
508508
LinearSVC.
509509
510-
class_weight : dict, optional
511-
Weights associated with classes in the form
512-
``{class_label : weight}``. If not given, all classes are
513-
supposed to have weight one.
510+
class_weight : dict or 'balanced', optional
511+
Weights associated with classes in the form ``{class_label: weight}``.
512+
If not given, all classes are supposed to have weight one.
513+
514+
The "balanced" mode uses the values of y to automatically adjust
515+
weights inversely proportional to class frequencies in the input data
516+
as ``n_samples / (n_classes * np.bincount(y))``
514517
515518
copy_X : boolean, optional, default True
516519
If True, X will be copied; else, it may be overwritten.
@@ -994,10 +997,13 @@ class RidgeClassifierCV(LinearClassifierMixin, _BaseRidgeCV):
994997
If None, Generalized Cross-Validation (efficient Leave-One-Out)
995998
will be used.
996999
997-
class_weight : dict, optional
998-
Weights associated with classes in the form
999-
``{class_label : weight}``. If not given, all classes are
1000-
supposed to have weight one.
1000+
class_weight : dict or 'balanced', optional
1001+
Weights associated with classes in the form ``{class_label: weight}``.
1002+
If not given, all classes are supposed to have weight one.
1003+
1004+
The "balanced" mode uses the values of y to automatically adjust
1005+
weights inversely proportional to class frequencies in the input data
1006+
as ``n_samples / (n_classes * np.bincount(y))``
10011007
10021008
Attributes
10031009
----------

0 commit comments

Comments
 (0)