Skip to content

Commit e7707cb

Browse files
committed
added GraphTests, all Igored at the moment
1 parent 30e4b21 commit e7707cb

File tree

4 files changed

+3300
-1
lines changed

4 files changed

+3300
-1
lines changed

src/TensorFlowNET.Core/Graphs/Graph.cs

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,51 @@ namespace Tensorflow
1212
/// then create a TensorFlow session to run parts of the graph across a set of local and remote devices.
1313
/// https://www.tensorflow.org/guide/graphs
1414
/// </summary>
15+
/*
16+
A TensorFlow computation, represented as a dataflow graph.
17+
18+
A `Graph` contains a set of
19+
`tf.Operation` objects,
20+
which represent units of computation; and
21+
`tf.Tensor` objects, which represent
22+
the units of data that flow between operations.
23+
24+
A default `Graph` is always registered, and accessible by calling
25+
`tf.get_default_graph`.
26+
To add an operation to the default graph, simply call one of the functions
27+
that defines a new `Operation`:
28+
29+
```python
30+
c = tf.constant(4.0)
31+
assert c.graph is tf.get_default_graph()
32+
```
33+
34+
Another typical usage involves the
35+
`tf.Graph.as_default`
36+
context manager, which overrides the current default graph for the
37+
lifetime of the context:
38+
39+
```python
40+
g = tf.Graph()
41+
with g.as_default():
42+
# Define operations and tensors in `g`.
43+
c = tf.constant(30.0)
44+
assert c.graph is g
45+
```
46+
47+
Important note: This class *is not* thread-safe for graph construction. All
48+
operations should be created from a single thread, or external
49+
synchronization must be provided. Unless otherwise specified, all methods
50+
are not thread-safe.
51+
52+
A `Graph` instance supports an arbitrary number of "collections"
53+
that are identified by name. For convenience when building a large
54+
graph, collections can store groups of related objects: for
55+
example, the `tf.Variable` uses a collection (named
56+
`tf.GraphKeys.GLOBAL_VARIABLES`) for
57+
all variables that are created during the construction of a graph. The caller
58+
may define additional collections by specifying a new name.
59+
*/
1560
public partial class Graph : IPython, IDisposable
1661
{
1762
private IntPtr _handle;

src/TensorFlowNET.Core/ops.py.cs

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,19 +49,59 @@ public static object get_collection_ref(string key)
4949
return get_default_graph().get_collection_ref(key);
5050
}
5151

52-
private static Graph default_graph;
52+
private static Graph default_graph;
53+
/// <summary>
54+
/// Returns the default graph for the current thread.
55+
///
56+
/// The returned graph will be the innermost graph on which a
57+
/// `Graph.as_default()` context has been entered, or a global default
58+
/// graph if none has been explicitly created.
59+
///
60+
/// NOTE: The default graph is a property of the current thread.If you
61+
/// create a new thread, and wish to use the default graph in that
62+
/// thread, you must explicitly add a `with g.as_default():` in that
63+
/// thread's function.
64+
/// </summary>
65+
/// <returns></returns>
5366
public static Graph get_default_graph()
5467
{
68+
//TODO: original source indicates there should be a _default_graph_stack!
69+
//return _default_graph_stack.get_default()
5570
if (default_graph == null)
5671
default_graph = tf.Graph();
5772
return default_graph;
5873
}
5974
public static Graph set_default_graph(Graph graph)
6075
{
76+
//TODO: original source does not have a 'set_default_graph' and indicates there should be a _default_graph_stack!
6177
default_graph = graph;
6278
return default_graph;
79+
}
80+
81+
/// <summary>
82+
/// Clears the default graph stack and resets the global default graph.
83+
///
84+
/// NOTE: The default graph is a property of the current thread.This
85+
/// function applies only to the current thread.Calling this function while
86+
/// a `tf.Session` or `tf.InteractiveSession` is active will result in undefined
87+
/// behavior. Using any previously created `tf.Operation` or `tf.Tensor` objects
88+
/// after calling this function will result in undefined behavior.
89+
/// </summary>
90+
/// <returns></returns>
91+
public static void reset_default_graph()
92+
{
93+
//TODO: original source indicates there should be a _default_graph_stack!
94+
//if (!_default_graph_stack.is_cleared())
95+
// throw new InvalidOperationException("Do not use tf.reset_default_graph() to clear " +
96+
// "nested graphs. If you need a cleared graph, " +
97+
// "exit the nesting and create a new graph.");
98+
//_default_graph_stack.reset();
99+
if (default_graph!=null)
100+
default_graph.Dispose();
101+
default_graph = tf.Graph();
63102
}
64103

104+
65105
public static Graph _get_graph_from_inputs(List<Tensor> op_input_list, Graph graph = null)
66106
{
67107
foreach(var op_input in op_input_list)
Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,200 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Linq;
4+
using System.Text;
5+
using Microsoft.VisualStudio.TestTools.UnitTesting;
6+
using Tensorflow;
7+
using Tensorflow.Operations;
8+
9+
namespace TensorFlowNET.UnitTest.ops_test
10+
{
11+
/// <summary>
12+
/// excerpt of tensorflow/python/framework/ops_test.py
13+
/// </summary>
14+
[TestClass]
15+
public class GraphTest : PythonTest
16+
{
17+
18+
[TestInitialize]
19+
public void SetUp()
20+
{
21+
ops.reset_default_graph();
22+
}
23+
24+
[TestCleanup]
25+
public void TearDown()
26+
{
27+
ops.reset_default_graph();
28+
}
29+
30+
private void _AssertDefault(Graph expected) {
31+
Assert.AreSame(ops.get_default_graph(), expected);
32+
}
33+
34+
35+
[Ignore("Todo: Port")]
36+
[TestMethod]
37+
public void testResetDefaultGraphNesting()
38+
{
39+
/*
40+
def testResetDefaultGraphNesting(self):
41+
g0 = ops.Graph()
42+
with self.assertRaises(AssertionError):
43+
with g0.as_default():
44+
ops.reset_default_graph()
45+
*/
46+
}
47+
48+
[Ignore("Todo: Port")]
49+
[TestMethod]
50+
public void testGraphContextManagerCancelsEager()
51+
{
52+
/*
53+
def testGraphContextManagerCancelsEager(self):
54+
with context.eager_mode():
55+
with ops.Graph().as_default():
56+
self.assertFalse(context.executing_eagerly())
57+
*/
58+
}
59+
60+
61+
[Ignore("Todo: Port")]
62+
[TestMethod]
63+
public void testGraphContextManager()
64+
{
65+
/*
66+
def testGraphContextManager(self):
67+
g0 = ops.Graph()
68+
with g0.as_default() as g1:
69+
self.assertIs(g0, g1)
70+
*/
71+
}
72+
73+
[Ignore("Todo: Port")]
74+
[TestMethod]
75+
public void testDefaultGraph()
76+
{
77+
/*
78+
def testDefaultGraph(self):
79+
orig = ops.get_default_graph()
80+
self._AssertDefault(orig)
81+
g0 = ops.Graph()
82+
self._AssertDefault(orig)
83+
context_manager_0 = g0.as_default()
84+
self._AssertDefault(orig)
85+
with context_manager_0 as g0:
86+
self._AssertDefault(g0)
87+
with ops.Graph().as_default() as g1:
88+
self._AssertDefault(g1)
89+
self._AssertDefault(g0)
90+
self._AssertDefault(orig)
91+
*/
92+
}
93+
94+
[Ignore("Todo: Port")]
95+
[TestMethod]
96+
public void testPreventFeeding()
97+
{
98+
/*
99+
def testPreventFeeding(self):
100+
g = ops.Graph()
101+
a = constant_op.constant(2.0)
102+
self.assertTrue(g.is_feedable(a))
103+
g.prevent_feeding(a)
104+
self.assertFalse(g.is_feedable(a))
105+
*/
106+
}
107+
108+
109+
[Ignore("Todo: Port")]
110+
[TestMethod]
111+
public void testAsGraphElementConversions()
112+
{
113+
/*
114+
def testAsGraphElementConversions(self):
115+
116+
class ConvertibleObj(object):
117+
118+
def _as_graph_element(self):
119+
return "FloatOutput:0"
120+
121+
class NonConvertibleObj(object):
122+
123+
pass
124+
125+
g = ops.Graph()
126+
a = _apply_op(g, "FloatOutput", [], [dtypes.float32])
127+
self.assertEqual(a, g.as_graph_element(ConvertibleObj()))
128+
with self.assertRaises(TypeError):
129+
g.as_graph_element(NonConvertibleObj())
130+
*/
131+
}
132+
133+
[Ignore("Todo: Port")]
134+
[TestMethod]
135+
public void testGarbageCollected()
136+
{
137+
/*
138+
# Regression test against creating custom __del__ functions in classes
139+
# involved in cyclic references, e.g. Graph and Operation. (Python won't gc
140+
# cycles that require calling a __del__ method, because the __del__ method can
141+
# theoretically increase the object's refcount to "save" it from gc, and any
142+
# already-deleted objects in the cycle would have be to restored.)
143+
def testGarbageCollected(self):
144+
# Create a graph we can delete and a weak reference to monitor if it's gc'd
145+
g = ops.Graph()
146+
g_ref = weakref.ref(g)
147+
# Create some ops
148+
with g.as_default():
149+
a = constant_op.constant(2.0)
150+
b = constant_op.constant(3.0)
151+
c = math_ops.add(a, b)
152+
# Create a session we can delete
153+
with session.Session(graph=g) as sess:
154+
self.evaluate(c)
155+
# Delete all references and trigger gc
156+
del g
157+
del a
158+
del b
159+
del c
160+
del sess
161+
gc.collect()
162+
self.assertIsNone(g_ref())
163+
*/
164+
}
165+
166+
[Ignore("Todo: Port")]
167+
[TestMethod]
168+
public void testRunnableAfterInvalidShape()
169+
{
170+
/*
171+
def testRunnableAfterInvalidShape(self):
172+
with ops.Graph().as_default():
173+
with self.assertRaises(ValueError):
174+
math_ops.add([1, 2], [1, 2, 3])
175+
a = constant_op.constant(1)
176+
with session.Session() as sess:
177+
self.evaluate(a)
178+
*/
179+
}
180+
181+
[Ignore("Todo: Port")]
182+
[TestMethod]
183+
public void testRunnableAfterInvalidShapeWithKernelLabelMap()
184+
{
185+
/*
186+
def testRunnableAfterInvalidShapeWithKernelLabelMap(self):
187+
g = ops.Graph()
188+
with g.as_default():
189+
with g._kernel_label_map({"KernelLabelRequired": "overload_1"}):
190+
with self.assertRaises(ValueError):
191+
test_ops.kernel_label_required(1)
192+
a = constant_op.constant(1)
193+
with session.Session() as sess:
194+
self.evaluate(a)
195+
*/
196+
}
197+
198+
199+
}
200+
}

0 commit comments

Comments
 (0)