using System;
using System.Collections.Generic;
using System.Text;
+using Tensorflow.Framework;
+using static Tensorflow.Python;

namespace Tensorflow.Train
{
@@ -10,9 +12,10 @@ namespace Tensorflow.Train
    /// </summary>
    public class AdamOptimizer : Optimizer
    {
-        private float _beta1;
-        private float _beta2;
-        private float _epsilon;
+        float _beta1;
+        float _beta2;
+        float _epsilon;
+        Tensor _lr_t, _beta1_t, _beta2_t, _epsilon_t;

        public AdamOptimizer(float learning_rate, float beta1 = 0.9f, float beta2 = 0.999f, float epsilon = 1e-8f, bool use_locking = false, string name = "Adam")
            : base(learning_rate, use_locking, name)
@@ -21,5 +24,51 @@ public AdamOptimizer(float learning_rate, float beta1 = 0.9f, float beta2 = 0.99
            _beta2 = beta2;
            _epsilon = epsilon;
        }
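+
+        // Minimal usage sketch (hypothetical loss tensor; assumes the
+        // tf.train.AdamOptimizer factory and Optimizer.minimize(...) that
+        // TensorFlow.NET exposes for graph-mode training):
+        //   var opt = tf.train.AdamOptimizer(0.01f);
+        //   var train_op = opt.minimize(loss);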
+
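+        /// <summary>
+        /// Sparse Adam step: only the rows named by grad.indices touch the slot
+        /// accumulators, via a ref-variable scatter_add (see _apply_sparse_shared).
+        /// </summary>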
+        public override Operation _apply_sparse(IndexedSlices grad, RefVariable var)
+        {
+            return _apply_sparse_shared(grad.values, var, grad.indices, (x, i, v) =>
+            {
+                return state_ops.scatter_add(x, i, v, use_locking: _use_locking);
+            });
+        }
+
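+        /// <summary>
+        /// Shared sparse update, following the shape of the Python reference adam.py:
+        ///   m_t = beta1 * m + (1 - beta1) * g           (at grad indices)
+        ///   v_t = beta2 * v + (1 - beta2) * g * g       (at grad indices)
+        ///   var -= lr_t * sqrt(1 - beta2^t) / (1 - beta1^t) * m_t / (sqrt(v_t) + epsilon)
+        /// The scatter_add delegate abstracts the concrete scatter op, as in Python.
+        /// </summary>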
+        private Operation _apply_sparse_shared(Tensor grad, RefVariable var, Tensor indices, Func<RefVariable, Tensor, Tensor, Tensor> scatter_add)
+        {
+            var (beta1_power_v, beta2_power_v) = _get_beta_accumulators();
+            Tensor beta1_power = math_ops.cast(beta1_power_v, var.dtype.as_base_dtype());
+            Tensor beta2_power = math_ops.cast(beta2_power_v, var.dtype.as_base_dtype());
+            var lr_t = math_ops.cast(_lr_t, var.dtype.as_base_dtype());
+            var beta1_t = math_ops.cast(_beta1_t, var.dtype.as_base_dtype());
+            var beta2_t = math_ops.cast(_beta2_t, var.dtype.as_base_dtype());
+            var epsilon_t = math_ops.cast(_epsilon_t, var.dtype.as_base_dtype());
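+            // Bias-corrected step size; beta1_power / beta2_power hold beta^t
+            // from the non-slot accumulators fetched above.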
+            var lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power));
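+            // First moment: decay m in place, then scatter-add (1 - beta1) * g at
+            // the given indices; the control dependency orders the two writes.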
+            var m = get_slot(var, "m");
+            var m_scaled_g_values = grad * (1 - beta1_t);
+            var m_t = state_ops.assign(m, m * beta1_t, use_locking: _use_locking);
+            with(ops.control_dependencies(new[] { m_t }), delegate
+            {
+                m_t = scatter_add(m, indices, m_scaled_g_values);
+            });
+
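+            // Second moment: the same decay-then-scatter pattern with g * g.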
+            var v = get_slot(var, "v");
+            var v_scaled_g_values = (grad * grad) * (1 - beta2_t);
+            var v_t = state_ops.assign(v, v * beta2_t, use_locking: _use_locking);
+            with(ops.control_dependencies(new[] { v_t }), delegate
+            {
+                v_t = scatter_add(v, indices, v_scaled_g_values);
+            });
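+            // Apply the update and group the three ops into a single train op.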
+            var v_sqrt = math_ops.sqrt(v_t);
+            var var_update = state_ops.assign_sub(var, lr * m_t / (v_sqrt + epsilon_t), use_locking: _use_locking);
+            return control_flow_ops.group(new[] { var_update, m_t, v_t });
+        }
+
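+        /// <summary>
+        /// Returns the beta1_power / beta2_power non-slot variables that track
+        /// beta1^t and beta2^t. ops.init_scope() mirrors the Python
+        /// `with ops.init_scope()` so the lookup targets the outer graph.
+        /// </summary>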
+        private (RefVariable, RefVariable) _get_beta_accumulators()
+        {
+            ops.init_scope();
+            var graph = ops.get_default_graph();
+            return (_get_non_slot_variable("beta1_power", graph: graph),
+                _get_non_slot_variable("beta2_power", graph: graph));
+        }
    }
}