Skip to content

Commit 71e1fe6

Browse files
author
Esther2013
committed
RefVariable, variable_scope
1 parent 9f18881 commit 71e1fe6

File tree

12 files changed

+173
-32
lines changed

12 files changed

+173
-32
lines changed

src/TensorFlowNET.Core/Eager/Context.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
using System.Collections.Generic;
33
using System.Text;
44

5-
namespace Tensorflow.Eager
5+
namespace Tensorflow
66
{
77
public class Context
88
{

src/TensorFlowNET.Core/Graphs/Graph.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,11 @@ public bool is_fetchable<T>(T tensor_or_op)
152152
return false;
153153
}
154154

155+
public string get_name_scope()
156+
{
157+
return _name_stack;
158+
}
159+
155160
public string name_scope(string name)
156161
{
157162
string new_stack = "";

src/TensorFlowNET.Core/Operations/OpDefLibrary.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ namespace Tensorflow
1010
{
1111
public class OpDefLibrary
1212
{
13-
public unsafe Operation _apply_op_helper(string op_type_name, string name = "", Dictionary<string, object> keywords = null)
13+
public Operation _apply_op_helper(string op_type_name, string name = "", Dictionary<string, object> keywords = null)
1414
{
1515
var g = ops.get_default_graph();
1616
var op_def = g.GetOpDef(op_type_name);

src/TensorFlowNET.Core/Variables/RefVariable.cs

Lines changed: 48 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,22 +8,62 @@ public class RefVariable : VariableV1
88
{
99
public bool _in_graph_mode = true;
1010
public Tensor _initial_value;
11+
public string _graph_key;
12+
public bool _trainable;
13+
public Tensor _variable;
1114

12-
public RefVariable(object initial_value,
15+
public RefVariable(object initial_value,
16+
bool trainable = true,
17+
List<string> collections = null,
18+
bool validate_shape = true,
19+
string caching_device = "",
1320
string name = "",
14-
TF_DataType trainable = TF_DataType.DtInvalid,
15-
bool validate_shape = true) :
16-
base(initial_value, name, trainable, validate_shape)
21+
TF_DataType dtype = TF_DataType.DtInvalid) :
22+
base(initial_value, trainable, collections, validate_shape, caching_device, name, dtype)
1723
{
18-
_init_from_args(initial_value, name, trainable);
24+
_init_from_args(initial_value, trainable, collections, validate_shape, caching_device, name, dtype);
1925
}
2026

2127
private void _init_from_args(object initial_value,
28+
bool trainable = true,
29+
List<string> collections = null,
30+
bool validate_shape = true,
31+
string caching_device = "",
2232
string name = "",
23-
TF_DataType trainable = TF_DataType.DtInvalid)
33+
TF_DataType dtype = TF_DataType.DtInvalid)
2434
{
25-
name = ops.name_scope("", "Variable", initial_value);
26-
_initial_value = ops.convert_to_tensor(initial_value, name: "initial_value");
35+
if (initial_value is null)
36+
throw new ValueError("initial_value must be specified.");
37+
38+
var init_from_fn = false;
39+
40+
if(collections == null)
41+
{
42+
collections = new List<string> { ops.GraphKeys.GLOBAL_VARIABLES };
43+
}
44+
45+
// Store the graph key so optimizers know how to only retrieve variables from
46+
// this graph.
47+
_graph_key = ops.get_default_graph()._graph_key;
48+
49+
_trainable = trainable;
50+
if (!collections.Contains(ops.GraphKeys.TRAINABLE_VARIABLES))
51+
collections.Add(ops.GraphKeys.TRAINABLE_VARIABLES);
52+
53+
ops.init_scope();
54+
name = new ops.name_scope(name, "Variable", init_from_fn ? new List<object>() : new List<object> { initial_value });
55+
if (init_from_fn)
56+
{
57+
58+
}
59+
else
60+
{
61+
_initial_value = ops.convert_to_tensor(initial_value, name: "initial_value");
62+
}
63+
64+
var shape = _initial_value.shape;
65+
dtype = _initial_value.dtype;
66+
_variable = gen_state_ops.variable_v2(shape, dtype, name);
2767
}
2868
}
2969
}

src/TensorFlowNET.Core/Variables/VariableV1.cs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,13 @@ namespace Tensorflow
1616
/// </summary>
1717
public class VariableV1
1818
{
19-
public VariableV1(object initial_value, string name = "", TF_DataType trainable = TF_DataType.DtInvalid, bool validate_shape = true)
19+
public VariableV1(object initial_value,
20+
bool trainable = true,
21+
List<string> collections = null,
22+
bool validate_shape = true,
23+
string caching_device = "",
24+
string name = "",
25+
TF_DataType dtype = TF_DataType.DtInvalid)
2026
{
2127

2228
}
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow
6+
{
7+
public class gen_state_ops
8+
{
9+
public static OpDefLibrary _op_def_lib = new OpDefLibrary();
10+
11+
/// <summary>
12+
/// Holds state in the form of a tensor that persists across steps.
13+
/// Outputs a ref to the tensor state so it may be read or modified.
14+
/// </summary>
15+
/// <param name="shape">The shape of the variable tensor.</param>
16+
/// <param name="dtype">The type of elements in the variable tensor.</param>
17+
/// <param name="name"></param>
18+
/// <param name="container"></param>
19+
/// <param name="shared_name"></param>
20+
/// <returns></returns>
21+
public static Tensor variable_v2(long[] shape, TF_DataType dtype, string name = "", string container = "", string shared_name = "")
22+
{
23+
var keywords = new Dictionary<string, object>();
24+
keywords.Add("dtype", dtype);
25+
keywords.Add("shape", shape);
26+
27+
var _op = _op_def_lib._apply_op_helper("VariableV2", name: name, keywords: keywords);
28+
29+
var _result = _op.outputs;
30+
var _inputs_flat = _op.inputs;
31+
32+
return new Tensor(_op, 0, dtype);
33+
}
34+
}
35+
}

src/TensorFlowNET.Core/Variables/variable_scope.py.cs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,9 @@ public static RefVariable default_variable_creator(object initial_value, string
2626
}
2727
else
2828
{
29-
return new RefVariable(initial_value);
29+
return new RefVariable(initial_value,
30+
name: name,
31+
dtype: dtype);
3032
}
3133
}
3234

src/TensorFlowNET.Core/Variables/variables.py.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ public class variables
1212
/// <returns></returns>
1313
public static object trainable_variables()
1414
{
15-
return ops.get_collection(ops.GraphKey.TRAINABLE_VARIABLES);
15+
return ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES);
1616
}
1717
}
1818
}

src/TensorFlowNET.Core/ops.GraphKeys.py.cs renamed to src/TensorFlowNET.Core/ops.GraphKeys.cs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,18 @@ public partial class ops
1515
/// specified, but it is also possible to pass an explicit list of
1616
/// variables.
1717
/// </summary>
18-
public static class GraphKey
18+
public static class GraphKeys
1919
{
2020
/// <summary>
2121
/// the subset of `Variable` objects that will be trained by an optimizer.
2222
/// </summary>
2323
public static string TRAINABLE_VARIABLES = "trainable_variables";
24+
25+
/// <summary>
26+
/// Key to collect Variable objects that are global (shared across machines).
27+
/// Default collection for all variables, except local ones.
28+
/// </summary>
29+
public static string GLOBAL_VARIABLES = "variables";
2430
}
2531
}
2632
}
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow
6+
{
7+
public partial class ops
8+
{
9+
public class name_scope
10+
{
11+
public string _name;
12+
public string _default_name;
13+
public object _values;
14+
public Context _ctx;
15+
public string _name_scope;
16+
17+
public name_scope(string name, string default_name, List<object> values)
18+
{
19+
_name = name;
20+
_default_name = default_name;
21+
_values = values;
22+
_ctx = new Context();
23+
24+
_name_scope = __enter__();
25+
}
26+
27+
public string __enter__()
28+
{
29+
if (String.IsNullOrEmpty(_name))
30+
{
31+
_name = _default_name;
32+
}
33+
34+
var g = get_default_graph();
35+
return g.name_scope(_name);
36+
}
37+
38+
public static implicit operator string(name_scope ns)
39+
{
40+
return ns._name_scope;
41+
}
42+
}
43+
}
44+
}

0 commit comments

Comments
 (0)