Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions src/TensorFlowNET.Core/Graphs/Graph.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,51 @@ namespace Tensorflow
/// then create a TensorFlow session to run parts of the graph across a set of local and remote devices.
/// https://www.tensorflow.org/guide/graphs
/// </summary>
/*
A TensorFlow computation, represented as a dataflow graph.

A `Graph` contains a set of
`tf.Operation` objects,
which represent units of computation; and
`tf.Tensor` objects, which represent
the units of data that flow between operations.

A default `Graph` is always registered, and accessible by calling
`tf.get_default_graph`.
To add an operation to the default graph, simply call one of the functions
that defines a new `Operation`:

```python
c = tf.constant(4.0)
assert c.graph is tf.get_default_graph()
```

Another typical usage involves the
`tf.Graph.as_default`
context manager, which overrides the current default graph for the
lifetime of the context:

```python
g = tf.Graph()
with g.as_default():
# Define operations and tensors in `g`.
c = tf.constant(30.0)
assert c.graph is g
```

Important note: This class *is not* thread-safe for graph construction. All
operations should be created from a single thread, or external
synchronization must be provided. Unless otherwise specified, all methods
are not thread-safe.

A `Graph` instance supports an arbitrary number of "collections"
that are identified by name. For convenience when building a large
graph, collections can store groups of related objects: for
example, the `tf.Variable` uses a collection (named
`tf.GraphKeys.GLOBAL_VARIABLES`) for
all variables that are created during the construction of a graph. The caller
may define additional collections by specifying a new name.
*/
public partial class Graph : IPython, IDisposable
{
private IntPtr _handle;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;

namespace Tensorflow.Operations
Expand Down Expand Up @@ -92,13 +93,15 @@ with ops.control_dependencies(new_summaries):

switch (original_result)
{
case Tensor result:
return (original_result, _BuildCondTensor(new[] { result.op }));
case Operation[] results:
return (original_result, _BuildCondTensor(results));
case Tensor tensor:
return (original_result, tensor);
case float[] fv:
{
var result = ops.convert_to_tensor(fv[0]);
return (original_result, result );
}
default:
return (original_result, null);
}
Expand All @@ -114,7 +117,7 @@ with ops.control_dependencies(new_summaries):
switch (original_result)
{
case Tensor[] results:
return (original_result, results);
return (original_result, new Tensor[] { _BuildCondTensor(results.Select(t=>t.op).ToArray())});
case Operation[] results:
return (original_result, new Tensor[] { _BuildCondTensor (results) });
case float[] fv:
Expand Down
6 changes: 3 additions & 3 deletions src/TensorFlowNET.Core/Operations/Operation.Input.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ public InputList inputs

for (int i = 0; i < NumInputs; i++)
{
var tf_outpus = Input(i);
var op = new Operation(tf_outpus.oper);
retval[i] = op.outputs[tf_outpus.index];
var tf_outputs = Input(i);
var op = new Operation(tf_outputs.oper);
retval[i] = op.outputs[tf_outputs.index];
}

_inputs = new InputList(retval);
Expand Down
55 changes: 48 additions & 7 deletions src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs
Original file line number Diff line number Diff line change
Expand Up @@ -142,23 +142,45 @@ public static Tensor[] tuple(Tensor[] tensors, string name = null, Operation[] c

return tpl.ToArray();
});
}

}

/// <summary>
/// Produces the content of `output_tensor` only after `dependencies`.
///
/// In some cases, a user may want the output of an operation to be
/// consumed externally only after some other dependencies have run
/// first.This function ensures returns `output_tensor`, but only after all
/// operations in `dependencies` have run.Note that this means that there is
/// no guarantee that `output_tensor` will be evaluated after any `dependencies`
/// have run.
///
/// See also `tf.tuple` and `tf.group`.
/// </summary>
/// <param name="dependencies">Iterable of operations to run before this op finishes.</param>
/// <param name="output_tensor">A `Tensor` or `IndexedSlices` that will be returned.</param>
/// <param name="name">(Optional) A name for this operation.</param>
/// <returns>Same as `output_tensor`.</returns>
public static Tensor with_dependencies(Operation[] dependencies, Tensor output_tensor, string name = null)
{
//TODO: missing original code
//if context.executing_eagerly():
// return output_tensor
var values = new List<object>();
values.AddRange(dependencies);
values.Add(output_tensor);

return with(ops.name_scope(name, "control_dependency", values), scope =>
{
name = scope;

return with(ops.control_dependencies(dependencies), ctl =>
// TODO: missing original code
//with ops.colocate_with(output_tensor):
{
output_tensor = ops.convert_to_tensor_or_composite(output_tensor);
return _Identity(output_tensor, name: name);
});
return with(ops.control_dependencies(dependencies), ctl =>
{
output_tensor = ops.convert_to_tensor_or_composite(output_tensor);
return _Identity(output_tensor, name: name);
});
}
});
}

Expand Down Expand Up @@ -393,8 +415,27 @@ public static Tensor[] _convert_flows_to_tensorarrays<T>(T[] tensors_or_tensorar
return tensors_or_flows;
}

/// <summary>
/// Returns the value of an available element of `inputs`.
///
/// This op tests each of the tensors in `inputs` in turn to determine if any of
/// them is available.If it finds an available tensor, it returns it and its
/// index in `inputs`.
///
/// It is an error if more than one tensor in `inputs` is available.If no tensor
/// in `inputs` is available, the returned tensor and index are not set.
///
/// This op handles both `Tensor`s and `IndexedSlices`. If inputs has a mix of
/// `Tensor`s and `IndexedSlices`, all inputs are converted to IndexedSlices
/// before merging.
/// </summary>
/// <param name="inputs">inputs: The input tensors, at most one of which is available.</param>
/// <param name="name">A name for this operation (optional).</param>
/// <returns></returns>
public static Tensor merge(Tensor[] inputs, string name = null)
{
if (inputs.Any(x => x == null))
throw new ValueError($"At least one of the merge inputs is null: {inputs}");
return with(ops.name_scope(name, "Merge", inputs), scope =>
{
name = scope;
Expand Down
42 changes: 41 additions & 1 deletion src/TensorFlowNET.Core/ops.py.cs
Original file line number Diff line number Diff line change
Expand Up @@ -49,19 +49,59 @@ public static object get_collection_ref(string key)
return get_default_graph().get_collection_ref(key);
}

private static Graph default_graph;
private static Graph default_graph;
/// <summary>
/// Returns the default graph for the current thread.
///
/// The returned graph will be the innermost graph on which a
/// `Graph.as_default()` context has been entered, or a global default
/// graph if none has been explicitly created.
///
/// NOTE: The default graph is a property of the current thread.If you
/// create a new thread, and wish to use the default graph in that
/// thread, you must explicitly add a `with g.as_default():` in that
/// thread's function.
/// </summary>
/// <returns></returns>
public static Graph get_default_graph()
{
//TODO: original source indicates there should be a _default_graph_stack!
//return _default_graph_stack.get_default()
if (default_graph == null)
default_graph = tf.Graph();
return default_graph;
}
public static Graph set_default_graph(Graph graph)
{
//TODO: original source does not have a 'set_default_graph' and indicates there should be a _default_graph_stack!
default_graph = graph;
return default_graph;
}

/// <summary>
/// Clears the default graph stack and resets the global default graph.
///
/// NOTE: The default graph is a property of the current thread.This
/// function applies only to the current thread.Calling this function while
/// a `tf.Session` or `tf.InteractiveSession` is active will result in undefined
/// behavior. Using any previously created `tf.Operation` or `tf.Tensor` objects
/// after calling this function will result in undefined behavior.
/// </summary>
/// <returns></returns>
public static void reset_default_graph()
{
//TODO: original source indicates there should be a _default_graph_stack!
//if (!_default_graph_stack.is_cleared())
// throw new InvalidOperationException("Do not use tf.reset_default_graph() to clear " +
// "nested graphs. If you need a cleared graph, " +
// "exit the nesting and create a new graph.");
//_default_graph_stack.reset();
if (default_graph!=null)
default_graph.Dispose();
default_graph = tf.Graph();
}


public static Graph _get_graph_from_inputs(List<Tensor> op_input_list, Graph graph = null)
{
foreach(var op_input in op_input_list)
Expand Down
2 changes: 1 addition & 1 deletion test/TensorFlowNET.UnitTest/PythonTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ namespace TensorFlowNET.UnitTest
/// </summary>
public class PythonTest : Python
{
public void assertItemsEqual(ICollection expected, ICollection given)
public void assertItemsEqual(ICollection given, ICollection expected)
{
Assert.IsNotNull(expected);
Assert.IsNotNull(given);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
using Tensorflow;
using Tensorflow.Eager;

namespace TensorFlowNET.UnitTest
namespace TensorFlowNET.UnitTest.ops_test
{
/// <summary>
/// excerpt of tensorflow/python/framework/ops_test.py
Expand Down Expand Up @@ -157,8 +157,8 @@ public void TestNested()
});
});
});
assertItemsEqual(new[] { a_1.op, a_2.op, a_3.op, a_4.op }, b_1.op.control_inputs);
assertItemsEqual(b_1.op.control_inputs, b_2.op.control_inputs);
assertItemsEqual(b_1.op.control_inputs, new[] { a_1.op, a_2.op, a_3.op, a_4.op });
assertItemsEqual(b_2.op.control_inputs, b_1.op.control_inputs);
}

[TestMethod]
Expand Down Expand Up @@ -200,6 +200,7 @@ public void TestClear()
b_none2 = constant_op.constant(12.0);
});
});
// Note assertItemsEqual(given, expected), expected and given parameters should be swapped below
assertItemsEqual(new[] { a_3.op, a_4.op }, b_3_4.op.control_inputs);
assertItemsEqual(new[] { a_3.op }, b_3.op.control_inputs);
assertItemsEqual(new object[0], b_none.op.control_inputs);
Expand Down Expand Up @@ -256,6 +257,7 @@ public void TestComplex()
});
});

// Note assertItemsEqual(given, expected), expected and given parameters should be swapped below
assertItemsEqual(new[] {a_1.op}, b_1.op.control_inputs);
assertItemsEqual(new[] {a_1.op, a_2.op}, b_2.op.control_inputs);
assertItemsEqual(new[] { a_1.op, a_2.op}, b_3.op.control_inputs);
Expand Down
Loading