Skip to content

Commit bcb803d

Browse files
committed
fix add_collections
1 parent 71e1fe6 commit bcb803d

File tree

10 files changed

+114
-20
lines changed

10 files changed

+114
-20
lines changed

src/TensorFlowNET.Core/Graphs/Graph.cs

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -91,9 +91,18 @@ private T _as_graph_element_locked<T>(T obj, bool allow_tensor = true, bool allo
9191
throw new Exception($"Can not convert a {typeof(T).Name} into a {types_str}.");
9292
}
9393

94-
public void add_to_collection(string name, object value)
94+
public void add_to_collection<T>(string name, T value)
9595
{
96-
_collections[name] = value;
96+
if (_collections.ContainsKey(name))
97+
(_collections[name] as List<T>).Add(value);
98+
else
99+
_collections[name] = new List<T> { value };
100+
}
101+
102+
public void add_to_collections<T>(List<string> names, T value)
103+
{
104+
foreach (string name in names)
105+
add_to_collection(name, value);
97106
}
98107

99108
public unsafe Operation create_op(string op_type, List<Tensor> inputs, TF_DataType[] dtypes,
@@ -236,9 +245,9 @@ public Operation[] get_operations()
236245
return _nodes_by_name.Values.Select(x => x).ToArray();
237246
}
238247

239-
public Dictionary<string, object> get_collection(string name)
248+
public object get_collection(string name)
240249
{
241-
return _collections;
250+
return _collections.ContainsKey(name) ? _collections[name] : null;
242251
}
243252

244253
public void Dispose()

src/TensorFlowNET.Core/Operations/OpDefLibrary.cs

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ public Operation _apply_op_helper(string op_type_name, string name = "", Diction
2020
name = op_type_name;
2121
}
2222

23-
string scope = g.unique_name(name) + "/";
23+
string scope = new ops.name_scope(name);
2424

2525
var default_type_attr_map = new Dictionary<string, object>();
2626
foreach (var attr_def in op_def.Attr)
@@ -88,15 +88,22 @@ public Operation _apply_op_helper(string op_type_name, string name = "", Diction
8888

8989
switch (attr_def.Type)
9090
{
91+
case "string":
92+
attr_value.S = Google.Protobuf.ByteString.CopyFromUtf8((string)value);
93+
break;
9194
case "type":
9295
attr_value.Type = _MakeType((TF_DataType)value, attr_def);
9396
break;
9497
case "bool":
9598
attr_value.B = (bool)value;
9699
break;
97100
case "shape":
98-
attr_value.Shape = new TensorShapeProto();
101+
attr_value.Shape = value == null ?
102+
attr_def.DefaultValue.Shape :
103+
tensor_util.as_shape((long[])value);
99104
break;
105+
default:
106+
throw new InvalidDataException($"attr_def.Type {attr_def.Type}");
100107
}
101108

102109
attr_protos[key] = attr_value;

src/TensorFlowNET.Core/Tensors/tensor_util.cs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,22 @@ public static NDArray convert_to_numpy_ndarray(object values)
7373
return nd;
7474
}
7575

76+
public static TensorShapeProto as_shape(long[] dims)
77+
{
78+
TensorShapeProto shape = new TensorShapeProto();
79+
80+
for (int i = 0; i < dims.Length; i++)
81+
{
82+
var dim = new TensorShapeProto.Types.Dim();
83+
dim.Size = dims[i];
84+
dim.Name = $"dim_{i}";
85+
86+
shape.Dim.Add(dim);
87+
}
88+
89+
return shape;
90+
}
91+
7692
public static TensorShape as_shape(this IShape shape, int[] dims)
7793
{
7894
return new TensorShape(dims);

src/TensorFlowNET.Core/Train/Optimizer.cs

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,14 @@ public Optimizer(double learning_rate, bool use_locking, string name = "")
3030
/// </summary>
3131
/// <param name="loss"></param>
3232
/// <returns></returns>
33-
public Optimizer minimize(Tensor loss, GateGradientType gate_gradients = GateGradientType.GATE_OP)
33+
public Optimizer minimize(Tensor loss,
34+
GateGradientType gate_gradients = GateGradientType.GATE_OP,
35+
bool colocate_gradients_with_ops = false)
3436
{
35-
compute_gradients(loss, gate_gradients);
37+
compute_gradients(loss,
38+
gate_gradients: gate_gradients,
39+
colocate_gradients_with_ops: colocate_gradients_with_ops);
40+
3641
return this;
3742
}
3843

@@ -41,15 +46,30 @@ public Optimizer minimize(Tensor loss, GateGradientType gate_gradients = GateGra
4146
/// </summary>
4247
/// <param name="loss"></param>
4348
/// <param name="gate_gradients"></param>
44-
public List<KeyValuePair<object, object>> compute_gradients(Tensor loss, GateGradientType gate_gradients = GateGradientType.GATE_OP)
49+
public List<KeyValuePair<object, object>> compute_gradients(Tensor loss,
50+
List<RefVariable> var_list = null,
51+
GateGradientType gate_gradients = GateGradientType.GATE_OP,
52+
bool colocate_gradients_with_ops = false)
4553
{
4654
int num_towers = 1;
4755
if(distribute_lib.get_loss_reduction() == VariableAggregationType.MEAN)
4856
{
4957

5058
}
5159

52-
var var_list = variables.trainable_variables();
60+
var tmp = variables.trainable_variables();
61+
switch (tmp)
62+
{
63+
case List<RefVariable> values:
64+
var_list = values;
65+
break;
66+
}
67+
68+
foreach(var v in var_list)
69+
{
70+
71+
}
72+
5373
return null;
5474
}
5575
}

src/TensorFlowNET.Core/Variables/RefVariable.cs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,28 @@ private void _init_from_args(object initial_value,
6464
var shape = _initial_value.shape;
6565
dtype = _initial_value.dtype;
6666
_variable = gen_state_ops.variable_v2(shape, dtype, name);
67+
68+
// Manually overrides the variable's shape with the initial value's.
69+
if (validate_shape)
70+
{
71+
var initial_value_shape = _initial_value.shape;
72+
}
73+
74+
// If 'initial_value' makes use of other variables, make sure we don't
75+
// have an issue if these other variables aren't initialized first by
76+
// using their initialized_value() method.
77+
78+
ops.add_to_collections(collections, this);
79+
}
80+
81+
public static implicit operator _VariableScopeStore(RefVariable variable)
82+
{
83+
return null;
84+
}
85+
86+
public static implicit operator RefVariable(_VariableScopeStore store)
87+
{
88+
return null;
6789
}
6890
}
6991
}

src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ public static Tensor variable_v2(long[] shape, TF_DataType dtype, string name =
2323
var keywords = new Dictionary<string, object>();
2424
keywords.Add("dtype", dtype);
2525
keywords.Add("shape", shape);
26+
keywords.Add("container", container);
27+
keywords.Add("shared_name", shared_name);
2628

2729
var _op = _op_def_lib._apply_op_helper("VariableV2", name: name, keywords: keywords);
2830

src/TensorFlowNET.Core/Variables/variable_scope.py.cs

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,18 +39,30 @@ public static VariableScope get_variable_scope()
3939

4040
public static _VariableScopeStore get_variable_scope_store()
4141
{
42+
_VariableScopeStore ret = null;
4243
var scope_store = ops.get_collection(_VARSCOPESTORE_KEY);
4344
if (scope_store == null)
4445
{
45-
scope_store = new _VariableScopeStore();
46-
ops.add_to_collection(_VARSCOPESTORE_KEY, scope_store);
46+
ret = new _VariableScopeStore();
47+
ops.add_to_collection(_VARSCOPESTORE_KEY, ret);
4748
}
4849
else
4950
{
50-
// scope_store = scope_store[0];
51+
switch (scope_store)
52+
{
53+
case List<RefVariable> values:
54+
ret = values[0];
55+
break;
56+
case List<_VariableScopeStore> values:
57+
ret = values[0];
58+
break;
59+
default:
60+
throw new InvalidOperationException("get_variable_scope_store");
61+
}
62+
5163
}
5264

53-
return scope_store;
65+
return ret;
5466
}
5567

5668
public static bool _get_trainable_value(VariableSynchronization synchronization, bool? trainable = null)

src/TensorFlowNET.Core/ops.name_scope.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ public class name_scope
1414
public Context _ctx;
1515
public string _name_scope;
1616

17-
public name_scope(string name, string default_name, List<object> values)
17+
public name_scope(string name, string default_name = "", List<object> values = null)
1818
{
1919
_name = name;
2020
_default_name = default_name;

src/TensorFlowNET.Core/ops.py.cs

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,21 @@ namespace Tensorflow
1212
{
1313
public partial class ops
1414
{
15-
public static void add_to_collection(string name, object value)
15+
public static void add_to_collection<T>(string name, T value)
1616
{
1717
var graph = tf.get_default_graph();
1818
graph.add_to_collection(name, value);
1919
}
2020

21-
public static _VariableScopeStore get_collection(string key)
21+
public static void add_to_collections<T>(List<string> names, T value)
2222
{
23-
return null;// get_default_graph().get_collection(key);
23+
var graph = tf.get_default_graph();
24+
graph.add_to_collections(names, value);
25+
}
26+
27+
public static object get_collection(string key)
28+
{
29+
return get_default_graph().get_collection(key);
2430
}
2531

2632
public static Graph get_default_graph()

test/TensorFlowNET.Examples/LinearRegression.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,12 @@ public void Run()
2727
var train_Y = np.array(1.7, 2.76, 2.09, 3.19, 1.694, 1.573, 3.366, 2.596, 2.53, 1.221,
2828
2.827, 3.465, 1.65, 2.904, 2.42, 2.94, 1.3);
2929
var n_samples = train_X.shape[0];
30-
30+
3131
// tf Graph Input
3232
var X = tf.placeholder(tf.float64);
3333
var Y = tf.placeholder(tf.float64);
3434

35-
// Set model weights
35+
// Set model weights
3636
var W = tf.Variable(rng.randn<double>(), name: "weight");
3737
var b = tf.Variable(rng.randn<double>(), name: "bias");
3838

0 commit comments

Comments
 (0)