Skip to content

Commit dd91139

Browse files
committed
2 parents e3c5bb7 + fc24b92 commit dd91139

File tree

10 files changed

+3424
-32
lines changed

10 files changed

+3424
-32
lines changed

src/TensorFlowNET.Core/Graphs/Graph.cs

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,51 @@ namespace Tensorflow
1212
/// then create a TensorFlow session to run parts of the graph across a set of local and remote devices.
1313
/// https://www.tensorflow.org/guide/graphs
1414
/// </summary>
15+
/*
16+
A TensorFlow computation, represented as a dataflow graph.
17+
18+
A `Graph` contains a set of
19+
`tf.Operation` objects,
20+
which represent units of computation; and
21+
`tf.Tensor` objects, which represent
22+
the units of data that flow between operations.
23+
24+
A default `Graph` is always registered, and accessible by calling
25+
`tf.get_default_graph`.
26+
To add an operation to the default graph, simply call one of the functions
27+
that defines a new `Operation`:
28+
29+
```python
30+
c = tf.constant(4.0)
31+
assert c.graph is tf.get_default_graph()
32+
```
33+
34+
Another typical usage involves the
35+
`tf.Graph.as_default`
36+
context manager, which overrides the current default graph for the
37+
lifetime of the context:
38+
39+
```python
40+
g = tf.Graph()
41+
with g.as_default():
42+
# Define operations and tensors in `g`.
43+
c = tf.constant(30.0)
44+
assert c.graph is g
45+
```
46+
47+
Important note: This class *is not* thread-safe for graph construction. All
48+
operations should be created from a single thread, or external
49+
synchronization must be provided. Unless otherwise specified, all methods
50+
are not thread-safe.
51+
52+
A `Graph` instance supports an arbitrary number of "collections"
53+
that are identified by name. For convenience when building a large
54+
graph, collections can store groups of related objects: for
55+
example, the `tf.Variable` uses a collection (named
56+
`tf.GraphKeys.GLOBAL_VARIABLES`) for
57+
all variables that are created during the construction of a graph. The caller
58+
may define additional collections by specifying a new name.
59+
*/
1560
public partial class Graph : IPython, IDisposable
1661
{
1762
private IntPtr _handle;

src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using System;
22
using System.Collections.Generic;
3+
using System.Linq;
34
using System.Text;
45

56
namespace Tensorflow.Operations
@@ -92,13 +93,15 @@ with ops.control_dependencies(new_summaries):
9293

9394
switch (original_result)
9495
{
96+
case Tensor result:
97+
return (original_result, _BuildCondTensor(new[] { result.op }));
9598
case Operation[] results:
9699
return (original_result, _BuildCondTensor(results));
97-
case Tensor tensor:
98-
return (original_result, tensor);
99100
case float[] fv:
101+
{
100102
var result = ops.convert_to_tensor(fv[0]);
101103
return (original_result, result );
104+
}
102105
default:
103106
return (original_result, null);
104107
}
@@ -114,7 +117,7 @@ with ops.control_dependencies(new_summaries):
114117
switch (original_result)
115118
{
116119
case Tensor[] results:
117-
return (original_result, results);
120+
return (original_result, new Tensor[] { _BuildCondTensor(results.Select(t=>t.op).ToArray())});
118121
case Operation[] results:
119122
return (original_result, new Tensor[] { _BuildCondTensor (results) });
120123
case float[] fv:

src/TensorFlowNET.Core/Operations/Operation.Input.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,9 @@ public InputList inputs
2727

2828
for (int i = 0; i < NumInputs; i++)
2929
{
30-
var tf_outpus = Input(i);
31-
var op = new Operation(tf_outpus.oper);
32-
retval[i] = op.outputs[tf_outpus.index];
30+
var tf_outputs = Input(i);
31+
var op = new Operation(tf_outputs.oper);
32+
retval[i] = op.outputs[tf_outputs.index];
3333
}
3434

3535
_inputs = new InputList(retval);

src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs

Lines changed: 48 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -142,23 +142,45 @@ public static Tensor[] tuple(Tensor[] tensors, string name = null, Operation[] c
142142

143143
return tpl.ToArray();
144144
});
145-
}
146-
145+
}
146+
147+
/// <summary>
148+
/// Produces the content of `output_tensor` only after `dependencies`.
149+
///
150+
/// In some cases, a user may want the output of an operation to be
151+
/// consumed externally only after some other dependencies have run
152+
/// first.This function ensures returns `output_tensor`, but only after all
153+
/// operations in `dependencies` have run.Note that this means that there is
154+
/// no guarantee that `output_tensor` will be evaluated after any `dependencies`
155+
/// have run.
156+
///
157+
/// See also `tf.tuple` and `tf.group`.
158+
/// </summary>
159+
/// <param name="dependencies">Iterable of operations to run before this op finishes.</param>
160+
/// <param name="output_tensor">A `Tensor` or `IndexedSlices` that will be returned.</param>
161+
/// <param name="name">(Optional) A name for this operation.</param>
162+
/// <returns>Same as `output_tensor`.</returns>
147163
public static Tensor with_dependencies(Operation[] dependencies, Tensor output_tensor, string name = null)
148164
{
165+
//TODO: missing original code
166+
//if context.executing_eagerly():
167+
// return output_tensor
149168
var values = new List<object>();
150169
values.AddRange(dependencies);
151170
values.Add(output_tensor);
152171

153172
return with(ops.name_scope(name, "control_dependency", values), scope =>
154173
{
155174
name = scope;
156-
157-
return with(ops.control_dependencies(dependencies), ctl =>
175+
// TODO: missing original code
176+
//with ops.colocate_with(output_tensor):
158177
{
159-
output_tensor = ops.convert_to_tensor_or_composite(output_tensor);
160-
return _Identity(output_tensor, name: name);
161-
});
178+
return with(ops.control_dependencies(dependencies), ctl =>
179+
{
180+
output_tensor = ops.convert_to_tensor_or_composite(output_tensor);
181+
return _Identity(output_tensor, name: name);
182+
});
183+
}
162184
});
163185
}
164186

@@ -393,8 +415,27 @@ public static Tensor[] _convert_flows_to_tensorarrays<T>(T[] tensors_or_tensorar
393415
return tensors_or_flows;
394416
}
395417

418+
/// <summary>
419+
/// Returns the value of an available element of `inputs`.
420+
///
421+
/// This op tests each of the tensors in `inputs` in turn to determine if any of
422+
/// them is available.If it finds an available tensor, it returns it and its
423+
/// index in `inputs`.
424+
///
425+
/// It is an error if more than one tensor in `inputs` is available.If no tensor
426+
/// in `inputs` is available, the returned tensor and index are not set.
427+
///
428+
/// This op handles both `Tensor`s and `IndexedSlices`. If inputs has a mix of
429+
/// `Tensor`s and `IndexedSlices`, all inputs are converted to IndexedSlices
430+
/// before merging.
431+
/// </summary>
432+
/// <param name="inputs">inputs: The input tensors, at most one of which is available.</param>
433+
/// <param name="name">A name for this operation (optional).</param>
434+
/// <returns></returns>
396435
public static Tensor merge(Tensor[] inputs, string name = null)
397436
{
437+
if (inputs.Any(x => x == null))
438+
throw new ValueError($"At least one of the merge inputs is null: {inputs}");
398439
return with(ops.name_scope(name, "Merge", inputs), scope =>
399440
{
400441
name = scope;

src/TensorFlowNET.Core/ops.py.cs

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,19 +49,59 @@ public static object get_collection_ref(string key)
4949
return get_default_graph().get_collection_ref(key);
5050
}
5151

52-
private static Graph default_graph;
52+
private static Graph default_graph;
53+
/// <summary>
54+
/// Returns the default graph for the current thread.
55+
///
56+
/// The returned graph will be the innermost graph on which a
57+
/// `Graph.as_default()` context has been entered, or a global default
58+
/// graph if none has been explicitly created.
59+
///
60+
/// NOTE: The default graph is a property of the current thread.If you
61+
/// create a new thread, and wish to use the default graph in that
62+
/// thread, you must explicitly add a `with g.as_default():` in that
63+
/// thread's function.
64+
/// </summary>
65+
/// <returns></returns>
5366
public static Graph get_default_graph()
5467
{
68+
//TODO: original source indicates there should be a _default_graph_stack!
69+
//return _default_graph_stack.get_default()
5570
if (default_graph == null)
5671
default_graph = tf.Graph();
5772
return default_graph;
5873
}
5974
public static Graph set_default_graph(Graph graph)
6075
{
76+
//TODO: original source does not have a 'set_default_graph' and indicates there should be a _default_graph_stack!
6177
default_graph = graph;
6278
return default_graph;
79+
}
80+
81+
/// <summary>
82+
/// Clears the default graph stack and resets the global default graph.
83+
///
84+
/// NOTE: The default graph is a property of the current thread.This
85+
/// function applies only to the current thread.Calling this function while
86+
/// a `tf.Session` or `tf.InteractiveSession` is active will result in undefined
87+
/// behavior. Using any previously created `tf.Operation` or `tf.Tensor` objects
88+
/// after calling this function will result in undefined behavior.
89+
/// </summary>
90+
/// <returns></returns>
91+
public static void reset_default_graph()
92+
{
93+
//TODO: original source indicates there should be a _default_graph_stack!
94+
//if (!_default_graph_stack.is_cleared())
95+
// throw new InvalidOperationException("Do not use tf.reset_default_graph() to clear " +
96+
// "nested graphs. If you need a cleared graph, " +
97+
// "exit the nesting and create a new graph.");
98+
//_default_graph_stack.reset();
99+
if (default_graph!=null)
100+
default_graph.Dispose();
101+
default_graph = tf.Graph();
63102
}
64103

104+
65105
public static Graph _get_graph_from_inputs(List<Tensor> op_input_list, Graph graph = null)
66106
{
67107
foreach(var op_input in op_input_list)

test/TensorFlowNET.UnitTest/PythonTest.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ namespace TensorFlowNET.UnitTest
1313
/// </summary>
1414
public class PythonTest : Python
1515
{
16-
public void assertItemsEqual(ICollection expected, ICollection given)
16+
public void assertItemsEqual(ICollection given, ICollection expected)
1717
{
1818
Assert.IsNotNull(expected);
1919
Assert.IsNotNull(given);

test/TensorFlowNET.UnitTest/ControlDependenciesTest.cs renamed to test/TensorFlowNET.UnitTest/ops_test/ControlDependenciesTest.cs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
using Tensorflow;
77
using Tensorflow.Eager;
88

9-
namespace TensorFlowNET.UnitTest
9+
namespace TensorFlowNET.UnitTest.ops_test
1010
{
1111
/// <summary>
1212
/// excerpt of tensorflow/python/framework/ops_test.py
@@ -157,8 +157,8 @@ public void TestNested()
157157
});
158158
});
159159
});
160-
assertItemsEqual(new[] { a_1.op, a_2.op, a_3.op, a_4.op }, b_1.op.control_inputs);
161-
assertItemsEqual(b_1.op.control_inputs, b_2.op.control_inputs);
160+
assertItemsEqual(b_1.op.control_inputs, new[] { a_1.op, a_2.op, a_3.op, a_4.op });
161+
assertItemsEqual(b_2.op.control_inputs, b_1.op.control_inputs);
162162
}
163163

164164
[TestMethod]
@@ -200,6 +200,7 @@ public void TestClear()
200200
b_none2 = constant_op.constant(12.0);
201201
});
202202
});
203+
// Note assertItemsEqual(given, expected), expected and given parameters should be swapped below
203204
assertItemsEqual(new[] { a_3.op, a_4.op }, b_3_4.op.control_inputs);
204205
assertItemsEqual(new[] { a_3.op }, b_3.op.control_inputs);
205206
assertItemsEqual(new object[0], b_none.op.control_inputs);
@@ -256,6 +257,7 @@ public void TestComplex()
256257
});
257258
});
258259

260+
// Note assertItemsEqual(given, expected), expected and given parameters should be swapped below
259261
assertItemsEqual(new[] {a_1.op}, b_1.op.control_inputs);
260262
assertItemsEqual(new[] {a_1.op, a_2.op}, b_2.op.control_inputs);
261263
assertItemsEqual(new[] { a_1.op, a_2.op}, b_3.op.control_inputs);

0 commit comments

Comments
 (0)