Skip to content

Commit 04ffa46

Browse files
committed
fix CondContext.AddValue with ops.control_dependencies SciSharp#213
1 parent 9bd5f00 commit 04ffa46

File tree

3 files changed

+11
-10
lines changed

3 files changed

+11
-10
lines changed

src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -88,14 +88,14 @@ public override Tensor AddValue(Tensor val)
8888
_values.Add(result.name);
8989
_external_values[result.name] = result;
9090
}
91-
// TODO: how to do 'with' here??
92-
//with(ops.control_dependencies(null), ctrl =>
93-
//{
94-
var (r0, r1) = control_flow_ops._SwitchRefOrTensor(result, _pred);
95-
result = new[]{r0, r1}[_branch];
96-
if (_outer_context != null)
97-
_outer_context.AddInnerOp(result.op);
98-
//});
91+
92+
with(ops.control_dependencies(null), ctrl =>
93+
{
94+
var (r0, r1) = control_flow_ops._SwitchRefOrTensor(result, _pred);
95+
result = new[] { r0, r1 }[_branch];
96+
if (_outer_context != null)
97+
_outer_context.AddInnerOp(result.op);
98+
});
9999

100100
result.op.graph.prevent_fetching(result.op);
101101
result.op._set_control_flow_context(this);

src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ namespace Tensorflow.Operations
2222
/// 4. A ControlFlowContext has _context_stack.
2323
/// Pushed and popped by ctxt.Enter() and ctxt.Exit()
2424
/// </summary>
25-
public abstract class ControlFlowContext : IPython, IControlFlowContext
25+
public abstract class ControlFlowContext : Python, IPython, IControlFlowContext
2626
{
2727
/// <summary>
2828
/// The predicate tensor in this branch

src/TensorFlowNET.Core/Operations/check_ops.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@ public static Operation assert_equal(object t1, object t2, object[] data = null,
3535
};
3636
}
3737

38-
var condition = math_ops.reduce_all(gen_math_ops.equal(x, y));
38+
var eq = gen_math_ops.equal(x, y);
39+
var condition = math_ops.reduce_all(eq);
3940
var x_static = tensor_util.constant_value(x);
4041
var y_static = tensor_util.constant_value(y);
4142
return control_flow_ops.Assert(condition, data);

0 commit comments

Comments
 (0)