Skip to content

Commit fe84821

Browse files
Make updating weights optional in SDCA.
Change: 135945044
1 parent 097e124 commit fe84821

5 files changed

Lines changed: 142 additions & 48 deletions

File tree

tensorflow/contrib/learn/python/learn/estimators/linear.py

Lines changed: 73 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@
3434
from tensorflow.contrib.layers.python.layers import target_column
3535
from tensorflow.contrib.learn.python.learn import evaluable
3636
from tensorflow.contrib.learn.python.learn import metric_spec
37-
from tensorflow.contrib.learn.python.learn import session_run_hook
3837
from tensorflow.contrib.learn.python.learn import trainable
3938
from tensorflow.contrib.learn.python.learn.estimators import dnn_linear_combined
4039
from tensorflow.contrib.learn.python.learn.estimators import estimator
@@ -54,6 +53,7 @@
5453
from tensorflow.python.ops import variable_scope
5554
from tensorflow.python.ops import variables
5655
from tensorflow.python.platform import tf_logging as logging
56+
from tensorflow.python.training import session_run_hook
5757
from tensorflow.python.training import training as train
5858

5959
_CLASSES = "classes"
@@ -253,11 +253,26 @@ def _linear_classifier_model_fn(features, targets, mode, params):
253253

254254

255255
def sdca_classifier_model_fn(features, targets, mode, params):
256-
"""Estimator's linear model_fn."""
256+
"""Estimator's model_fn for the SDCA optimizer.
257+
258+
Args:
259+
features: feature `Tensor` or `dict`. See the Estimator documentation.
260+
targets: targets `Tensor` or `dict`. See the Estimator documentation.
261+
mode: the mode. See the Estimator documentation.
262+
params: a `dict` with entries for "feature_columns", "optimizer",
263+
"weight_column_name", "loss_type", and optionally "update_weights_hook".
264+
265+
Returns:
266+
Tuple of predictions, loss, and train_op.
267+
268+
Raises:
269+
ValueError: if the parameters are invalid.
270+
"""
257271
feature_columns = params["feature_columns"]
258272
optimizer = params["optimizer"]
259273
weight_column_name = params["weight_column_name"]
260274
loss_type = params["loss_type"]
275+
update_weights_hook = params.get("update_weights_hook")
261276

262277
if not isinstance(optimizer, sdca_optimizer.SDCAOptimizer):
263278
raise ValueError("Optimizer must be of type SDCAOptimizer")
@@ -284,9 +299,12 @@ def sdca_classifier_model_fn(features, targets, mode, params):
284299
train_op = None
285300
if mode == estimator.ModeKeys.TRAIN:
286301
global_step = contrib_variables.get_global_step()
287-
train_op = optimizer.get_train_step(
288-
columns_to_variables, weight_column_name, loss_type, features,
289-
targets, global_step)
302+
sdca_model, train_op = optimizer.get_train_step(columns_to_variables,
303+
weight_column_name,
304+
loss_type, features,
305+
targets, global_step)
306+
if update_weights_hook is not None:
307+
update_weights_hook.set_parameters(sdca_model, train_op)
290308

291309
predictions = {}
292310
predictions[_LOGISTIC] = math_ops.sigmoid(logits)
@@ -303,6 +321,28 @@ def _get_default_optimizer(feature_columns):
303321
return train.FtrlOptimizer(learning_rate=learning_rate)
304322

305323

324+
class _SdcaUpdateWeightsHook(session_run_hook.SessionRunHook):
325+
"""SessionRunHook to update and shrink SDCA model weights."""
326+
327+
def __init__(self):
328+
pass
329+
330+
def set_parameters(self, sdca_model, train_op):
331+
self._sdca_model = sdca_model
332+
self._train_op = train_op
333+
334+
def begin(self):
335+
"""Construct the update_weights op.
336+
337+
The op is implicitly added to the default graph.
338+
"""
339+
self._update_op = self._sdca_model.update_weights(self._train_op)
340+
341+
def before_run(self, run_context):
342+
"""Return the update_weights op so that it is executed during this run."""
343+
return session_run_hook.SessionRunArgs(self._update_op)
344+
345+
306346
class LinearClassifier(evaluable.Evaluable, trainable.Trainable):
307347
"""Linear classifier model.
308348
@@ -432,15 +472,23 @@ def __init__(self, # _joint_weight pylint: disable=invalid-name
432472
self._optimizer = _get_optimizer(optimizer)
433473
num_ps_replicas = config.num_ps_replicas if config else 0
434474

475+
chief_hook = None
435476
if isinstance(optimizer, sdca_optimizer.SDCAOptimizer):
436477
assert not _joint_weight, ("_joint_weight is incompatible with the"
437478
" SDCAOptimizer")
438479
model_fn = sdca_classifier_model_fn
480+
# We use a hook to perform the weight update and shrink step only on the
481+
chief. Because the SdcaModel is constructed by the estimator within the
482+
# call to fit() but we need to pass the hook to fit(), we pass the hook
483+
# as a parameter to the model_fn and have that propagate the model to the
484+
# hook.
485+
chief_hook = _SdcaUpdateWeightsHook()
439486
params = {
440487
"feature_columns": feature_columns,
441488
"optimizer": self._optimizer,
442489
"weight_column_name": weight_column_name,
443490
"loss_type": "logistic_loss",
491+
"update_weights_hook": chief_hook,
444492
}
445493
else:
446494
model_fn = _linear_classifier_model_fn
@@ -462,29 +510,35 @@ def __init__(self, # _joint_weight pylint: disable=invalid-name
462510
params=params,
463511
feature_engineering_fn=feature_engineering_fn)
464512

513+
self._additional_run_hook = None
514+
if self._estimator.config.is_chief:
515+
self._additional_run_hook = chief_hook
516+
465517
def get_estimator(self):
466518
return self._estimator
467519

468520
def fit(self, x=None, y=None, input_fn=None, steps=None, batch_size=None,
469521
monitors=None, max_steps=None):
470522
"""See trainable.Trainable."""
471523
# TODO(roumposg): Remove when deprecated monitors are removed.
472-
if monitors is not None:
473-
deprecated_monitors = [
474-
m for m in monitors
475-
if not isinstance(m, session_run_hook.SessionRunHook)
476-
]
477-
for monitor in deprecated_monitors:
478-
monitor.set_estimator(self)
479-
monitor._lock_estimator() # pylint: disable=protected-access
480-
524+
if monitors is None:
525+
monitors = []
526+
deprecated_monitors = [
527+
m for m in monitors
528+
if not isinstance(m, session_run_hook.SessionRunHook)
529+
]
530+
for monitor in deprecated_monitors:
531+
monitor.set_estimator(self)
532+
monitor._lock_estimator() # pylint: disable=protected-access
533+
534+
if self._additional_run_hook:
535+
monitors.append(self._additional_run_hook)
481536
result = self._estimator.fit(x=x, y=y, input_fn=input_fn, steps=steps,
482537
batch_size=batch_size, monitors=monitors,
483538
max_steps=max_steps)
484539

485-
if monitors is not None:
486-
for monitor in deprecated_monitors:
487-
monitor._unlock_estimator() # pylint: disable=protected-access
540+
for monitor in deprecated_monitors:
541+
monitor._unlock_estimator() # pylint: disable=protected-access
488542

489543
return result
490544

@@ -751,9 +805,10 @@ def _get_train_ops(self, features, targets):
751805
columns_to_variables)
752806

753807
def _train_op_fn(unused_loss):
754-
return self._linear_optimizer.get_train_step(
808+
sdca_model, train_op = self._linear_optimizer.get_train_step(
755809
columns_to_variables, self._weight_column_name,
756810
self._loss_type(), features, targets, global_step)
811+
return sdca_model.update_weights(train_op)
757812

758813
model_fn_ops = self._head.head_ops(features, targets,
759814
estimator.ModeKeys.TRAIN, _train_op_fn,

tensorflow/contrib/learn/python/learn/estimators/svm.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ class SVM(trainable.Trainable, evaluable.Evaluable):
5555
method), should be set to (#concurrent train ops per worker) x (#workers). If
5656
num_loss_partitions is larger or equal to this value, convergence is
5757
guaranteed but becomes slower as num_loss_partitions increases. If it is set
58-
to a smaller value, the optimizer is more agressive in reducing the global
58+
to a smaller value, the optimizer is more aggressive in reducing the global
5959
loss but convergence is not guaranteed. The recommended value in tf.learn
6060
(where there is one process per worker) is the number of workers running the
6161
train steps. It defaults to 1 (single machine).
@@ -146,6 +146,7 @@ def __init__(self,
146146

147147
self._feature_columns = feature_columns
148148
self._model_dir = model_dir or tempfile.mkdtemp()
149+
self._chief_hook = linear._SdcaUpdateWeightsHook() # pylint: disable=protected-access
149150
self._estimator = estimator.Estimator(
150151
model_fn=linear.sdca_classifier_model_fn,
151152
model_dir=self._model_dir,
@@ -155,12 +156,19 @@ def __init__(self,
155156
"optimizer": self._optimizer,
156157
"weight_column_name": weight_column_name,
157158
"loss_type": "hinge_loss",
159+
"update_weights_hook": self._chief_hook,
158160
},
159161
feature_engineering_fn=feature_engineering_fn)
162+
if not self._estimator.config.is_chief:
163+
self._chief_hook = None
160164

161165
def fit(self, x=None, y=None, input_fn=None, steps=None, batch_size=None,
162166
monitors=None, max_steps=None):
163167
"""See trainable.Trainable."""
168+
if monitors is None:
169+
monitors = []
170+
if self._chief_hook:
171+
monitors.append(self._chief_hook)
164172
return self._estimator.fit(x=x, y=y, input_fn=input_fn, steps=steps,
165173
batch_size=batch_size, monitors=monitors,
166174
max_steps=max_steps)

tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,7 @@ def testSimple(self):
189189
train_op = lr.minimize()
190190
for _ in range(_MAX_ITERATIONS):
191191
train_op.run()
192+
lr.update_weights(train_op).run()
192193
# The high tolerance in unregularized_loss comparisons is due to the
193194
# fact that it's possible to trade off unregularized_loss vs.
194195
# regularization and still have a sum that is quite close to the
@@ -248,6 +249,7 @@ def Minimize():
248249

249250
for t in threads:
250251
t.join()
252+
lr.update_weights(train_op).run()
251253

252254
# The high tolerance in unregularized_loss comparisons is due to the
253255
# fact that it's possible to trade off unregularized_loss vs.
@@ -294,6 +296,7 @@ def testSimpleNoL2(self):
294296
train_op = lr.minimize()
295297
for _ in range(_MAX_ITERATIONS):
296298
train_op.run()
299+
lr.update_weights(train_op).run()
297300

298301
# There is neither L1 nor L2 loss, so regularized and unregularized
299302
# losses should be exactly the same.
@@ -346,6 +349,7 @@ def testSomeUnweightedExamples(self):
346349
train_op = lr.minimize()
347350
for _ in range(_MAX_ITERATIONS):
348351
train_op.run()
352+
lr.update_weights(train_op).run()
349353

350354
self.assertAllClose(0.411608, unregularized_loss.eval(), atol=0.05)
351355
self.assertAllClose(0.525457, loss.eval(), atol=0.01)
@@ -416,6 +420,7 @@ def testImbalanced(self):
416420
train_op = lr.minimize()
417421
for _ in range(_MAX_ITERATIONS):
418422
train_op.run()
423+
lr.update_weights(train_op).run()
419424

420425
self.assertAllClose(0.226487 + 0.102902,
421426
unregularized_loss.eval(),
@@ -456,6 +461,7 @@ def testImbalancedWithExampleWeights(self):
456461
train_op = lr.minimize()
457462
for _ in range(_MAX_ITERATIONS):
458463
train_op.run()
464+
lr.update_weights(train_op).run()
459465

460466
self.assertAllClose(0.284860, unregularized_loss.eval(), atol=0.08)
461467
self.assertAllClose(0.408044, loss.eval(), atol=0.012)
@@ -494,6 +500,7 @@ def testInstancesOfOneClassOnly(self):
494500
train_op = lr.minimize()
495501
for _ in range(_MAX_ITERATIONS):
496502
train_op.run()
503+
lr.update_weights(train_op).run()
497504
self.assertAllClose(0.411608, unregularized_loss.eval(), atol=0.05)
498505
self.assertAllClose(0.525457, loss.eval(), atol=0.01)
499506
predicted_labels = get_binary_predictions_for_logistic(predictions)
@@ -580,6 +587,7 @@ def testSimple(self):
580587
train_op = lr.minimize()
581588
for _ in range(_MAX_ITERATIONS):
582589
train_op.run()
590+
lr.update_weights(train_op).run()
583591

584592
# Predictions should be 2/3 of label due to minimizing regularized loss:
585593
# (label - 2 * weight)^2 / 2 + L2 * 2 * weight^2
@@ -626,6 +634,7 @@ def testL2Regularization(self):
626634
train_op = lr.minimize()
627635
for _ in range(_MAX_ITERATIONS):
628636
train_op.run()
637+
lr.update_weights(train_op).run()
629638

630639
# Predictions should be 1/5 of label due to minimizing regularized loss:
631640
# (label - 2 * weight)^2 + L2 * 16 * weight^2
@@ -661,6 +670,7 @@ def testL1Regularization(self):
661670
train_op = lr.minimize()
662671
for _ in range(_MAX_ITERATIONS):
663672
train_op.run()
673+
lr.update_weights(train_op).run()
664674

665675
# Predictions should be -4.0, 48/5 due to minimizing regularized loss:
666676
# (label - 2 * weight)^2 / 2 + L2 * 2 * weight^2 + L1 * 4 * weight
@@ -696,6 +706,7 @@ def testFeatureValues(self):
696706
train_op = lr.minimize()
697707
for _ in range(_MAX_ITERATIONS):
698708
train_op.run()
709+
lr.update_weights(train_op).run()
699710

700711
# There are 4 (sparse) variable weights to be learned. 2 for age and 2 for
701712
# gender. Let w_1, w_2 be age weights, w_3, w_4 be gender weights, y_1,
@@ -729,6 +740,7 @@ def testDenseFeaturesWithDefaultWeights(self):
729740
train_op = lr.minimize()
730741
for _ in range(_MAX_ITERATIONS):
731742
train_op.run()
743+
lr.update_weights(train_op).run()
732744

733745
# The loss function for these particular features is given by:
734746
# 1/2(label_1-w_1)^2 + 1/2(label_2-w_2)^2 + \lambda/2 (w_1^2 + w_2^2). So,
@@ -759,6 +771,7 @@ def testDenseFeaturesWithArbitraryWeights(self):
759771
train_op = lr.minimize()
760772
for _ in range(_MAX_ITERATIONS):
761773
train_op.run()
774+
lr.update_weights(train_op).run()
762775

763776
# The loss function for these particular features is given by:
764777
# 1/2 s_1 (label_1-w_1)^2 + 1/2 s_2(label_2-w_2)^2 +
@@ -816,6 +829,7 @@ def testSimple(self):
816829
train_op = model.minimize()
817830
for _ in range(_MAX_ITERATIONS):
818831
train_op.run()
832+
model.update_weights(train_op).run()
819833

820834
binary_predictions = get_binary_predictions_for_hinge(predictions)
821835
self.assertAllEqual([-1.0, 1.0], predictions.eval())
@@ -841,6 +855,7 @@ def testDenseFeaturesPerfectlySeparable(self):
841855
train_op = model.minimize()
842856
for _ in range(_MAX_ITERATIONS):
843857
train_op.run()
858+
model.update_weights(train_op).run()
844859

845860
self.assertAllClose([1.0, -1.0], predictions.eval(), atol=0.05)
846861
self.assertAllEqual([1, 0], binary_predictions.eval())
@@ -871,6 +886,7 @@ def testDenseFeaturesSeparableWithinMargins(self):
871886
train_op = model.minimize()
872887
for _ in range(_MAX_ITERATIONS):
873888
train_op.run()
889+
model.update_weights(train_op).run()
874890

875891
# (1.0, 0.5) and (1.0, -0.5) are separable by x-axis but the datapoints
876892
# are within the margins so there is unregularized loss (1/2 per example).
@@ -899,6 +915,7 @@ def testDenseFeaturesWeightedExamples(self):
899915
train_op = model.minimize()
900916
for _ in range(_MAX_ITERATIONS):
901917
train_op.run()
918+
model.update_weights(train_op).run()
902919

903920
# Point (1.0, 0.5) has higher weight than (1.0, -0.5) so the model will
904921
# try to increase the margin from (1.0, 0.5). Due to regularization,
@@ -953,6 +970,7 @@ def testSimple(self):
953970
train_op = model.minimize()
954971
for _ in range(_MAX_ITERATIONS):
955972
train_op.run()
973+
model.update_weights(train_op).run()
956974

957975
binary_predictions = get_binary_predictions_for_hinge(predictions)
958976
self.assertAllClose([-0.67, 0.67], predictions.eval(), atol=0.05)

0 commit comments

Comments
 (0)