Skip to content

Commit d3002c0

Browse files
committed
fix name_scope issue when current_scope is exit.
upgrade tensorflow.dll to 1.14.0rc1.
1 parent 954713f commit d3002c0

File tree

7 files changed

+49
-12
lines changed

7 files changed

+49
-12
lines changed

src/TensorFlowNET.Core/Keras/Layers/Layer.cs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ namespace Tensorflow.Keras.Layers
1414
/// A layer is a class implementing common neural networks operations, such
1515
/// as convolution, batch norm, etc. These operations require managing weights,
1616
/// losses, updates, and inter-layer connectivity.
17+
///
18+
/// tensorflow\python\keras\engine\base_layer.py
1719
/// </summary>
1820
public class Layer : AutoTrackable
1921
{
@@ -55,9 +57,14 @@ public Layer(bool trainable = true,
5557
{
5658
this.trainable = trainable;
5759
this._dtype = dtype;
60+
// A stateful layer is a layer whose updates are run during inference too,
61+
// for instance stateful RNNs.
5862
stateful = false;
63+
// Indicates whether `build` needs to be called upon layer call, to create
64+
// the layer's weights.
5965
built = false;
6066
this.supports_masking = false;
67+
6168
_init_set_name(name);
6269
_trainable_weights = new List<RefVariable>();
6370
_compute_previous_mask = false;
@@ -154,7 +161,8 @@ protected void _maybe_build(Tensor input)
154161
if (_dtype == TF_DataType.DtInvalid)
155162
_dtype = input.dtype;
156163

157-
build(input.GetShape());
164+
var input_shapes = input.GetShape();
165+
build(input_shapes);
158166
built = true;
159167
}
160168

src/TensorFlowNET.Core/Layers/Layer.cs

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,13 @@ public Layer(bool trainable = true,
2222
TF_DataType dtype = TF_DataType.DtInvalid,
2323
bool? _reuse = null) : base(trainable: trainable, name: name, dtype: dtype)
2424
{
25+
// For backwards compatibility, legacy layers do not use `ResourceVariable`
26+
// by default.
2527
this._use_resource_variables = false;
2628
this._reuse = _reuse;
29+
30+
// Avoid an incorrect lint error
31+
_trainable_weights = new List<RefVariable>();
2732
this.built = false;
2833
_keras_style = false;
2934
}
@@ -130,13 +135,12 @@ protected virtual RefVariable add_weight(string name,
130135
initializer: initializer,
131136
trainable: trainable,
132137
getter: (name1, shape1, dtype1, initializer1, trainable1) =>
133-
{
134-
return tf.get_variable(name1,
138+
tf.get_variable(name1,
135139
shape: new TensorShape(shape1),
136140
dtype: dtype1,
137141
initializer: initializer1,
138-
trainable: trainable1);
139-
});
142+
trainable: trainable1)
143+
);
140144

141145
//if (init_graph != null)
142146
//var trainable_variables = variables.trainable_variables();

src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ public partial class Tensor
1515
public static Tensor operator -(Tensor x, Tensor y) => BinaryOpWrapper("sub", x, y);
1616
public static Tensor operator -(Tensor x, int y) => BinaryOpWrapper("sub", x, y);
1717
public static Tensor operator -(Tensor x, double y) => BinaryOpWrapper("sub", x, y);
18-
public static Tensor operator -(float x, Tensor y) => BinaryOpWrapper("Sub", x, y);
18+
public static Tensor operator -(float x, Tensor y) => BinaryOpWrapper("sub", x, y);
1919

2020
public static Tensor operator *(float x, Tensor y) => BinaryOpWrapper("mul", x, y);
2121
public static Tensor operator *(double x, Tensor y) => BinaryOpWrapper("mul", x, y);

src/TensorFlowNET.Core/Variables/PureVariableScope.cs

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using System;
22
using System.Collections.Generic;
3+
using System.Linq;
34
using System.Text;
45

56
namespace Tensorflow
@@ -16,7 +17,8 @@ public class PureVariableScope : IPython
1617
private _VariableScopeStore _var_scope_store;
1718
private VariableScope variable_scope_object;
1819
private VariableScope _cached_variable_scope_object;
19-
20+
VariableScope _last_variable_scope_object;
21+
Dictionary<string, int> _old_subscopes;
2022
public PureVariableScope(string name,
2123
string old_name_scope = null,
2224
TF_DataType dtype = TF_DataType.DtInvalid)
@@ -51,6 +53,7 @@ public void __enter__()
5153
if(_scope != null)
5254
{
5355
_var_scope_store.open_variable_scope(_new_name);
56+
_old_subscopes = _var_scope_store.variable_scopes_count.ToDictionary(kv => kv.Key, kv => kv.Value);
5457
variable_scope_object = _cached_variable_scope_object;
5558
}
5659
else
@@ -66,6 +69,7 @@ public void __enter__()
6669
_var_scope_store.open_variable_scope(_new_name);
6770
}
6871
_var_scope_store.current_scope = variable_scope_object;
72+
_last_variable_scope_object = variable_scope_object;
6973
}
7074

7175
public void Dispose()
@@ -75,7 +79,12 @@ public void Dispose()
7579

7680
public void __exit__()
7781
{
78-
82+
// If jumping out from a non-prolonged scope, restore counts.
83+
if (_scope != null)
84+
_var_scope_store.variable_scopes_count = _old_subscopes;
85+
else
86+
_var_scope_store.close_variable_subscopes(_new_name);
87+
_var_scope_store.current_scope = _old;
7988
}
8089

8190
public static implicit operator VariableScope(PureVariableScope scope)

src/TensorFlowNET.Core/Variables/_VariableScopeStore.cs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ namespace Tensorflow
77
public class _VariableScopeStore
88
{
99
public VariableScope current_scope { get; set; }
10-
private Dictionary<string, int> variable_scopes_count;
10+
public Dictionary<string, int> variable_scopes_count;
1111

1212
public _VariableScopeStore()
1313
{
@@ -23,6 +23,13 @@ public void open_variable_scope(string scope_name)
2323
variable_scopes_count[scope_name] = 1;
2424
}
2525

26+
public void close_variable_subscopes(string scope_name)
27+
{
28+
foreach (var k in variable_scopes_count.Keys)
29+
if (scope_name == null || k.StartsWith(scope_name + "/"))
30+
variable_scopes_count[k] = 0;
31+
}
32+
2633
public int variable_scope_count(string scope_name)
2734
{
2835
if (variable_scopes_count.ContainsKey(scope_name))

src/TensorFlowNET.Core/Variables/variable_scope.py.cs

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -106,13 +106,13 @@ private VariableScope _enter_scope_uncached()
106106

107107
if (_name != null || _scope != null)
108108
{
109-
var name_scope = _name == null ? _scope.name.Split('/').Last() : _name;
109+
var name_scope = _scope.name.Split('/').Last();
110110
if (current_name_scope == null)
111111
current_name_scope = ops.name_scope(name_scope);
112112
current_name_scope.__enter__();
113113
var current_name_scope_name = current_name_scope;
114114
_current_name_scope = current_name_scope;
115-
string old_name_scope = current_name_scope_name;
115+
string old_name_scope = _scope.original_name_scope;
116116

117117
if(_scope == null)
118118
pure_variable_scope = new PureVariableScope(_name, old_name_scope: old_name_scope);
@@ -139,14 +139,22 @@ private VariableScope _enter_scope_uncached()
139139
}
140140
}
141141

142+
/// <summary>
143+
/// Get a name with the given prefix unique in the current variable scope.
144+
/// </summary>
145+
/// <param name="prefix"></param>
146+
/// <returns></returns>
142147
public static string _get_unique_variable_scope(string prefix)
143148
{
144149
var var_scope_store = get_variable_scope_store();
145150
var current_scope = get_variable_scope();
146151
string name = !string.IsNullOrEmpty(current_scope.name) ? current_scope.name + "/" + prefix : prefix;
147152
if (var_scope_store.variable_scope_count(name) == 0)
148153
return prefix;
149-
throw new NotImplementedException("_get_unique_variable_scope");
154+
var idx = 1;
155+
while (var_scope_store.variable_scope_count($"{name}_{idx}") > 0)
156+
idx += 1;
157+
return $"{prefix}_{idx}";
150158
}
151159

152160
public static RefVariable default_variable_creator(object initial_value,
@@ -250,6 +258,7 @@ public static implicit operator VariableScope(variable_scope scope)
250258

251259
public void __exit__()
252260
{
261+
_cached_pure_variable_scope.__exit__();
253262
if (_current_name_scope != null)
254263
_current_name_scope.__exit__();
255264
}
1.64 MB
Binary file not shown.

0 commit comments

Comments
 (0)