@@ -30,9 +30,14 @@ public Optimizer(double learning_rate, bool use_locking, string name = "")
3030 /// </summary>
3131 /// <param name="loss"></param>
3232 /// <returns></returns>
33- public Optimizer minimize ( Tensor loss , GateGradientType gate_gradients = GateGradientType . GATE_OP )
33+ public Optimizer minimize ( Tensor loss ,
34+ GateGradientType gate_gradients = GateGradientType . GATE_OP ,
35+ bool colocate_gradients_with_ops = false )
3436 {
35- compute_gradients ( loss , gate_gradients ) ;
37+ compute_gradients ( loss ,
38+ gate_gradients : gate_gradients ,
39+ colocate_gradients_with_ops : colocate_gradients_with_ops ) ;
40+
3641 return this ;
3742 }
3843
@@ -41,15 +46,30 @@ public Optimizer minimize(Tensor loss, GateGradientType gate_gradients = GateGra
4146 /// </summary>
4247 /// <param name="loss"></param>
4348 /// <param name="gate_gradients"></param>
44- public List < KeyValuePair < object , object > > compute_gradients ( Tensor loss , GateGradientType gate_gradients = GateGradientType . GATE_OP )
49+ public List < KeyValuePair < object , object > > compute_gradients ( Tensor loss ,
50+ List < RefVariable > var_list = null ,
51+ GateGradientType gate_gradients = GateGradientType . GATE_OP ,
52+ bool colocate_gradients_with_ops = false )
4553 {
4654 int num_towers = 1 ;
4755 if ( distribute_lib . get_loss_reduction ( ) == VariableAggregationType . MEAN )
4856 {
4957
5058 }
5159
52- var var_list = variables . trainable_variables ( ) ;
60+ var tmp = variables . trainable_variables ( ) ;
61+ switch ( tmp )
62+ {
63+ case List < RefVariable > values :
64+ var_list = values ;
65+ break ;
66+ }
67+
68+ foreach ( var v in var_list )
69+ {
70+
71+ }
72+
5373 return null ;
5474 }
5575 }
0 commit comments