Skip to content

Commit c4a585c

Browse files
committed
remove global static Graph instance.
1 parent d7c7d3d commit c4a585c

File tree

7 files changed

+76
-22
lines changed

7 files changed

+76
-22
lines changed
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Linq;
4+
using System.Text;
5+
6+
namespace Tensorflow
7+
{
8+
public class DefaultGraphStack
9+
{
10+
List<StackModel> stack = new List<StackModel>();
11+
12+
public void set_controller(Graph @default)
13+
{
14+
if (!stack.Exists(x => x.Graph == @default))
15+
stack.Add(new StackModel { Graph = @default, IsDefault = true });
16+
17+
foreach (var s in stack)
18+
s.IsDefault = s.Graph == @default;
19+
}
20+
21+
public Graph get_controller()
22+
{
23+
if (stack.Count == 0)
24+
stack.Add(new StackModel { Graph = tf.Graph(), IsDefault = true });
25+
26+
return stack.First(x => x.IsDefault).Graph;
27+
}
28+
29+
public void reset()
30+
{
31+
stack.Clear();
32+
}
33+
}
34+
35+
public class StackModel
36+
{
37+
public Graph Graph { get; set; }
38+
public bool IsDefault { get; set; }
39+
}
40+
}

src/TensorFlowNET.Core/Graphs/Graph.cs

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ public partial class Graph : IPython, IDisposable
8787
private Dictionary<string, object> _collections = new Dictionary<string, object>();
8888

8989
public bool building_function;
90-
90+
9191
public Graph()
9292
{
9393
_handle = c_api.TF_NewGraph();
@@ -113,7 +113,14 @@ public ITensorOrOperation as_graph_element(object obj, bool allow_tensor = true,
113113
return _as_graph_element_locked(obj, allow_tensor, allow_operation);
114114
}
115115

116-
public Graph as_default() => ops.set_default_graph(this);
116+
/// <summary>
117+
/// Returns a context manager that makes this `Graph` the default graph.
118+
/// </summary>
119+
/// <returns></returns>
120+
public Graph as_default()
121+
{
122+
return ops.set_default_graph(this);
123+
}
117124

118125
private Tensor _as_graph_element(object obj)
119126
{

src/TensorFlowNET.Core/Layers/Layer.cs

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -172,13 +172,12 @@ private void _set_scope(VariableScope scope = null)
172172
}
173173
else
174174
{
175-
with(tf.variable_scope(scope, default_name: _base_name),
176-
captured_scope =>
177-
{
178-
_scope = captured_scope;
179-
});
175+
with(tf.variable_scope(scope, default_name: _base_name), captured_scope =>
176+
{
177+
// convert variable_scope to VariableScope
178+
_scope = captured_scope;
179+
});
180180
}
181-
182181
}
183182
}
184183
}

src/TensorFlowNET.Core/Variables/variable_scope.py.cs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ public class variable_scope : IPython
2626
private bool? _reuse;
2727
bool _in_graph_mode;
2828
protected Graph _graph;
29+
bool _building_function;
2930

3031
public variable_scope(string name,
3132
string default_name = "",
@@ -70,6 +71,17 @@ public variable_scope(VariableScope scope,
7071

7172
public void __enter__()
7273
{
74+
// If the default graph is building a function, then we should not replace it
75+
// with the cached graph.
76+
if (ops.get_default_graph().building_function)
77+
_building_function = true;
78+
else
79+
_building_function = false;
80+
if (_in_graph_mode && !_building_function)
81+
{
82+
_graph.as_default();
83+
}
84+
7385
_scope = _enter_scope_uncached();
7486
}
7587

src/TensorFlowNET.Core/ops.name_scope.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,12 +54,13 @@ public void __enter__()
5454
public void Dispose()
5555
{
5656
var g = get_default_graph();
57-
// Console.WriteLine($"name_scope: {g._name_stack} -> {old_stack}");
5857
g._name_stack = old_stack;
58+
// Console.WriteLine($"name_scope: {g._name_stack} -> {old_stack}");
5959
}
6060

6161
public void __exit__()
6262
{
63+
6364
}
6465

6566
/// <summary>

src/TensorFlowNET.Core/ops.py.cs

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@ public static object get_collection_ref(string key)
5050
return get_default_graph().get_collection_ref(key);
5151
}
5252

53-
private static Graph default_graph;
53+
public static DefaultGraphStack default_graph_stack = new DefaultGraphStack();
54+
5455
/// <summary>
5556
/// Returns the default graph for the current thread.
5657
///
@@ -68,15 +69,13 @@ public static Graph get_default_graph()
6869
{
6970
//TODO: original source indicates there should be a _default_graph_stack!
7071
//return _default_graph_stack.get_default()
71-
if (default_graph == null)
72-
default_graph = tf.Graph();
73-
return default_graph;
72+
return default_graph_stack.get_controller();
7473
}
7574
public static Graph set_default_graph(Graph graph)
7675
{
7776
//TODO: original source does not have a 'set_default_graph' and indicates there should be a _default_graph_stack!
78-
default_graph = graph;
79-
return default_graph;
77+
default_graph_stack.set_controller(graph);
78+
return default_graph_stack.get_controller();
8079
}
8180

8281
/// <summary>
@@ -96,10 +95,7 @@ public static void reset_default_graph()
9695
// throw new InvalidOperationException("Do not use tf.reset_default_graph() to clear " +
9796
// "nested graphs. If you need a cleared graph, " +
9897
// "exit the nesting and create a new graph.");
99-
//_default_graph_stack.reset();
100-
if (default_graph!=null)
101-
default_graph.Dispose();
102-
default_graph = tf.Graph();
98+
default_graph_stack.reset();
10399
}
104100

105101
public static Graph _get_graph_from_inputs(params Tensor[] op_input_list)

test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ public Graph BuildGraph()
195195
return graph;
196196
}
197197

198-
private bool RunWithImportedGraph(Session sess, Graph graph)
198+
private bool Train(Session sess, Graph graph)
199199
{
200200
var stopwatch = Stopwatch.StartNew();
201201

@@ -274,8 +274,7 @@ public bool Train()
274274
{
275275
var graph = IsImportingGraph ? ImportGraph() : BuildGraph();
276276

277-
return with(tf.Session(graph), sess
278-
=> RunWithImportedGraph(sess, graph));
277+
return with(tf.Session(graph), sess => Train(sess, graph));
279278
}
280279

281280
public bool Predict()

0 commit comments

Comments
 (0)