Skip to content

Commit efe5ef2

Browse files
committed
finally fixed TestCond test case
1 parent 40dccd3 commit efe5ef2

File tree

5 files changed

+46
-51
lines changed

5 files changed

+46
-51
lines changed

src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ namespace Tensorflow.Operations
1010
/// </summary>
1111
public class CondContext : ControlFlowContext
1212
{
13-
private string _name;
13+
1414

1515
/// <summary>
1616
/// The boolean tensor for the cond predicate

src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@ public ControlFlowContext()
3737
_context_stack = new Stack<IControlFlowContext>();
3838
}
3939

40-
public string name { get; set; }
40+
public string name { get => _name; }
41+
protected string _name;
4142

4243
public void __init__()
4344
{

src/TensorFlowNET.Core/Operations/Operation.cs

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -279,12 +279,36 @@ public override bool Equals(object obj)
279279
/// <param name="tensor"> the Tensor to be used as the input at the given index.</param>
280280
public void _update_input(int index, Tensor tensor)
281281
{
282-
throw new NotImplementedException("_update_input");
282+
var input = _tf_input(index);
283+
var output = tensor._as_tf_output();
284+
_assert_same_graph( tensor);
285+
// Reset cached inputs.
286+
_inputs=new InputList(new Tensor[]{ tensor }); // is this right? original code: self._inputs_val=None
283287
// TODO: implement below code dependencies
284-
//_assert_same_graph( tensor);
285-
//// Reset cached inputs.
286-
//_inputs_val = null;
287-
//c_api.UpdateEdge(_graph._c_graph, tensor._as_tf_output(), _tf_input(index));
288+
//c_api.UpdateEdge(_graph._c_graph, output, input);
289+
}
290+
291+
private void _assert_same_graph(Tensor tensor)
292+
{
293+
//TODO: implement
294+
}
295+
296+
/// <summary>
297+
/// Create and return a new TF_Output for output_idx'th output of this op.
298+
/// </summary>
299+
public TF_Output _tf_output(int output_idx)
300+
{
301+
var tf_output = new TF_Output(op, output_idx);
302+
return tf_output;
303+
}
304+
305+
/// <summary>
306+
/// Create and return a new TF_Input for input_idx'th input of this op.
307+
/// </summary>
308+
public TF_Input _tf_input(int input_idx)
309+
{
310+
var tf_input = new TF_Input(op, input_idx);
311+
return tf_input;
288312
}
289313
}
290314
}

src/TensorFlowNET.Core/Tensors/Tensor.cs

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -255,16 +255,17 @@ public Tensor this[int slice_spec]
255255

256256
public override string ToString()
257257
{
258-
if(NDims == 0)
259-
{
260-
switch (dtype)
261-
{
262-
case TF_DataType.TF_INT32:
263-
return Data<int>()[0].ToString();
264-
}
265-
}
266-
267-
return $"tf.Tensor '{name}' shape=({string.Join(",", shape)}) dtype={dtype.ToString()}";
258+
// this can throw IndexOutOfRangeException
259+
//if(NDims == 0)
260+
//{
261+
// switch (dtype)
262+
// {
263+
// case TF_DataType.TF_INT32:
264+
// return Data<int>()[0].ToString();
265+
// }
266+
//}
267+
268+
return $"tf.Tensor '{name}' shape=({string.Join(",", shape)}) dtype={dtype}";
268269
}
269270

270271
public void Dispose()

test/TensorFlowNET.UnitTest/ops_test/CreateOpFromTfOperationTest.cs

Lines changed: 3 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,6 @@ public void TestUniqueName()
6464
});
6565
}
6666

67-
[Ignore("Switch op gets not inserted correctly in the graph")]
6867
[TestMethod]
6968
public void TestCond()
7069
{
@@ -94,42 +93,12 @@ public void TestCond()
9493
//self.assertEqual(op.outputs, new object[0]);
9594
var op_input = op.inputs[0].op;
9695
self.assertEqual(op_input.type, "Switch");
97-
self.assertEqual(op_input.inputs[0], x);
96+
self.assertEqual(op_input.inputs[0].name, x.name);
9897
self.assertEqual(op.graph, g);
9998
self.assertIsNotNone(op._get_control_flow_context());
100-
self.assertEqual((op._get_control_flow_context() as ControlFlowContext).name, "cond/cond_text");
99+
var cond_text = op._get_control_flow_context() as ControlFlowContext;
100+
self.assertEqual(cond_text.name, "cond/cond_text");
101101
});
102-
/*
103-
@test_util.run_v1_only("b/120545219")
104-
def testCond(self):
105-
g = ops.Graph()
106-
with g.as_default():
107-
x = test_ops.int_output()
108-
109-
def true_fn():
110-
ops._create_c_op(ops.get_default_graph(),
111-
ops._NodeDef("IntInput", "cond/myop"), [x], [])
112-
new_ops = g._add_new_tf_operations()
113-
self.assertEqual(len(new_ops), 1)
114-
return x
115-
116-
control_flow_ops.cond(x < 10, true_fn, lambda: x)
117-
118-
op = g.get_operation_by_name("cond/myop")
119-
self.assertIsNotNone(op)
120-
self.assertEqual(op.name, "cond/myop")
121-
self.assertEqual(op.type, "IntInput")
122-
self.assertEqual(op.outputs, [])
123-
op_input = op.inputs[0].op
124-
self.assertEqual(op_input.type, "Switch")
125-
self.assertEqual(op_input.inputs[0], x)
126-
self.assertEqual(op.graph, g)
127-
# pylint: disable=protected-access
128-
self.assertIsNotNone(op._get_control_flow_context())
129-
self.assertEqual(op._get_control_flow_context().name,
130-
"cond/cond_text")
131-
# pylint: enable=protected-access
132-
*/
133102
}
134103

135104
[Ignore("Todo: Port")]

0 commit comments

Comments
 (0)