@@ -8,22 +8,62 @@ public class RefVariable : VariableV1
88 {
99 public bool _in_graph_mode = true ;
1010 public Tensor _initial_value ;
11+ public string _graph_key ;
12+ public bool _trainable ;
13+ public Tensor _variable ;
1114
12- public RefVariable ( object initial_value ,
15+ public RefVariable ( object initial_value ,
16+ bool trainable = true ,
17+ List < string > collections = null ,
18+ bool validate_shape = true ,
19+ string caching_device = "" ,
1320 string name = "" ,
14- TF_DataType trainable = TF_DataType . DtInvalid ,
15- bool validate_shape = true ) :
16- base ( initial_value , name , trainable , validate_shape )
21+ TF_DataType dtype = TF_DataType . DtInvalid ) :
22+ base ( initial_value , trainable , collections , validate_shape , caching_device , name , dtype )
1723 {
18- _init_from_args ( initial_value , name , trainable ) ;
24+ _init_from_args ( initial_value , trainable , collections , validate_shape , caching_device , name , dtype ) ;
1925 }
2026
2127 private void _init_from_args ( object initial_value ,
28+ bool trainable = true ,
29+ List < string > collections = null ,
30+ bool validate_shape = true ,
31+ string caching_device = "" ,
2232 string name = "" ,
23- TF_DataType trainable = TF_DataType . DtInvalid )
33+ TF_DataType dtype = TF_DataType . DtInvalid )
2434 {
25- name = ops . name_scope ( "" , "Variable" , initial_value ) ;
26- _initial_value = ops . convert_to_tensor ( initial_value , name : "initial_value" ) ;
35+ if ( initial_value is null )
36+ throw new ValueError ( "initial_value must be specified." ) ;
37+
38+ var init_from_fn = false ;
39+
40+ if ( collections == null )
41+ {
42+ collections = new List < string > { ops . GraphKeys . GLOBAL_VARIABLES } ;
43+ }
44+
45+ // Store the graph key so optimizers know how to only retrieve variables from
46+ // this graph.
47+ _graph_key = ops . get_default_graph ( ) . _graph_key ;
48+
49+ _trainable = trainable ;
50+ if ( ! collections . Contains ( ops . GraphKeys . TRAINABLE_VARIABLES ) )
51+ collections . Add ( ops . GraphKeys . TRAINABLE_VARIABLES ) ;
52+
53+ ops . init_scope ( ) ;
54+ name = new ops . name_scope ( name , "Variable" , init_from_fn ? new List < object > ( ) : new List < object > { initial_value } ) ;
55+ if ( init_from_fn )
56+ {
57+
58+ }
59+ else
60+ {
61+ _initial_value = ops . convert_to_tensor ( initial_value , name : "initial_value" ) ;
62+ }
63+
64+ var shape = _initial_value . shape ;
65+ dtype = _initial_value . dtype ;
66+ _variable = gen_state_ops . variable_v2 ( shape , dtype , name ) ;
2767 }
2868 }
2969}
0 commit comments