3434from tensorflow .contrib .layers .python .layers import target_column
3535from tensorflow .contrib .learn .python .learn import evaluable
3636from tensorflow .contrib .learn .python .learn import metric_spec
37- from tensorflow .contrib .learn .python .learn import session_run_hook
3837from tensorflow .contrib .learn .python .learn import trainable
3938from tensorflow .contrib .learn .python .learn .estimators import dnn_linear_combined
4039from tensorflow .contrib .learn .python .learn .estimators import estimator
5453from tensorflow .python .ops import variable_scope
5554from tensorflow .python .ops import variables
5655from tensorflow .python .platform import tf_logging as logging
56+ from tensorflow .python .training import session_run_hook
5757from tensorflow .python .training import training as train
5858
5959_CLASSES = "classes"
@@ -253,11 +253,26 @@ def _linear_classifier_model_fn(features, targets, mode, params):
253253
254254
255255def sdca_classifier_model_fn (features , targets , mode , params ):
256- """Estimator's linear model_fn."""
256+ """Estimator's model_fn for the SDCA optimizer.
257+
258+ Args:
259+ features: feature `Tensor` or `dict`. See the Estimator documentation.
260+ targets: targets `Tensor` or `dict`. See the Estimator documentation.
261+ mode: the mode. See the Estimator documentation.
262+ params: a `dict` with entries for "feature_columns", "optimizer",
263+ "weight_column_name", "loss_type", and optionally "update_weights_hook".
264+
265+ Returns:
266+ Tuple of predictions, loss, and train_op.
267+
268+ Raises:
269+ ValueError: if the parameters are invalid.
270+ """
257271 feature_columns = params ["feature_columns" ]
258272 optimizer = params ["optimizer" ]
259273 weight_column_name = params ["weight_column_name" ]
260274 loss_type = params ["loss_type" ]
275+ update_weights_hook = params .get ("update_weights_hook" )
261276
262277 if not isinstance (optimizer , sdca_optimizer .SDCAOptimizer ):
263278 raise ValueError ("Optimizer must be of type SDCAOptimizer" )
@@ -284,9 +299,12 @@ def sdca_classifier_model_fn(features, targets, mode, params):
284299 train_op = None
285300 if mode == estimator .ModeKeys .TRAIN :
286301 global_step = contrib_variables .get_global_step ()
287- train_op = optimizer .get_train_step (
288- columns_to_variables , weight_column_name , loss_type , features ,
289- targets , global_step )
302+ sdca_model , train_op = optimizer .get_train_step (columns_to_variables ,
303+ weight_column_name ,
304+ loss_type , features ,
305+ targets , global_step )
306+ if update_weights_hook is not None :
307+ update_weights_hook .set_parameters (sdca_model , train_op )
290308
291309 predictions = {}
292310 predictions [_LOGISTIC ] = math_ops .sigmoid (logits )
@@ -303,6 +321,28 @@ def _get_default_optimizer(feature_columns):
303321 return train .FtrlOptimizer (learning_rate = learning_rate )
304322
305323
324+ class _SdcaUpdateWeightsHook (session_run_hook .SessionRunHook ):
325+ """SessionRunHook to update and shrink SDCA model weights."""
326+
327+ def __init__ (self ):
328+ pass
329+
330+ def set_parameters (self , sdca_model , train_op ):
331+ self ._sdca_model = sdca_model
332+ self ._train_op = train_op
333+
334+ def begin (self ):
335+ """Construct the update_weights op.
336+
337+ The op is implicitly added to the default graph.
338+ """
339+ self ._update_op = self ._sdca_model .update_weights (self ._train_op )
340+
341+ def before_run (self , run_context ):
342+ """Return the update_weights op so that it is executed during this run."""
343+ return session_run_hook .SessionRunArgs (self ._update_op )
344+
345+
306346class LinearClassifier (evaluable .Evaluable , trainable .Trainable ):
307347 """Linear classifier model.
308348
@@ -432,15 +472,23 @@ def __init__(self, # _joint_weight pylint: disable=invalid-name
432472 self ._optimizer = _get_optimizer (optimizer )
433473 num_ps_replicas = config .num_ps_replicas if config else 0
434474
475+ chief_hook = None
435476 if isinstance (optimizer , sdca_optimizer .SDCAOptimizer ):
436477 assert not _joint_weight , ("_joint_weight is incompatible with the"
437478 " SDCAOptimizer" )
438479 model_fn = sdca_classifier_model_fn
480+ # We use a hook to perform the weight update and shrink step only on the
481+ # chief. Because the SdcaModel constructed by the estimator within the
482+ # call to fit() but we need to pass the hook to fit(), we pass the hook
483+ # as a parameter to the model_fn and have that propagate the model to the
484+ # hook.
485+ chief_hook = _SdcaUpdateWeightsHook ()
439486 params = {
440487 "feature_columns" : feature_columns ,
441488 "optimizer" : self ._optimizer ,
442489 "weight_column_name" : weight_column_name ,
443490 "loss_type" : "logistic_loss" ,
491+ "update_weights_hook" : chief_hook ,
444492 }
445493 else :
446494 model_fn = _linear_classifier_model_fn
@@ -462,29 +510,35 @@ def __init__(self, # _joint_weight pylint: disable=invalid-name
462510 params = params ,
463511 feature_engineering_fn = feature_engineering_fn )
464512
513+ self ._additional_run_hook = None
514+ if self ._estimator .config .is_chief :
515+ self ._additional_run_hook = chief_hook
516+
465517 def get_estimator (self ):
466518 return self ._estimator
467519
468520 def fit (self , x = None , y = None , input_fn = None , steps = None , batch_size = None ,
469521 monitors = None , max_steps = None ):
470522 """See trainable.Trainable."""
471523 # TODO(roumposg): Remove when deprecated monitors are removed.
472- if monitors is not None :
473- deprecated_monitors = [
474- m for m in monitors
475- if not isinstance (m , session_run_hook .SessionRunHook )
476- ]
477- for monitor in deprecated_monitors :
478- monitor .set_estimator (self )
479- monitor ._lock_estimator () # pylint: disable=protected-access
480-
524+ if monitors is None :
525+ monitors = []
526+ deprecated_monitors = [
527+ m for m in monitors
528+ if not isinstance (m , session_run_hook .SessionRunHook )
529+ ]
530+ for monitor in deprecated_monitors :
531+ monitor .set_estimator (self )
532+ monitor ._lock_estimator () # pylint: disable=protected-access
533+
534+ if self ._additional_run_hook :
535+ monitors .append (self ._additional_run_hook )
481536 result = self ._estimator .fit (x = x , y = y , input_fn = input_fn , steps = steps ,
482537 batch_size = batch_size , monitors = monitors ,
483538 max_steps = max_steps )
484539
485- if monitors is not None :
486- for monitor in deprecated_monitors :
487- monitor ._unlock_estimator () # pylint: disable=protected-access
540+ for monitor in deprecated_monitors :
541+ monitor ._unlock_estimator () # pylint: disable=protected-access
488542
489543 return result
490544
@@ -751,9 +805,10 @@ def _get_train_ops(self, features, targets):
751805 columns_to_variables )
752806
753807 def _train_op_fn (unused_loss ):
754- return self ._linear_optimizer .get_train_step (
808+ sdca_model , train_op = self ._linear_optimizer .get_train_step (
755809 columns_to_variables , self ._weight_column_name ,
756810 self ._loss_type (), features , targets , global_step )
811+ return sdca_model .update_weights (train_op )
757812
758813 model_fn_ops = self ._head .head_ops (features , targets ,
759814 estimator .ModeKeys .TRAIN , _train_op_fn ,
0 commit comments