Skip to content

Commit c997c72

Browse files
committed
more cond test cases: testCondTrue and testCondFalse
1 parent 949ab3e commit c997c72

File tree

3 files changed

+78
-58
lines changed

3 files changed

+78
-58
lines changed

src/TensorFlowNET.Core/Sessions/Session.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,12 @@ public Session(Graph g, SessionOptions opts = null, Status s = null)
3838
Status.Check(true);
3939
}
4040

41+
public Session as_default()
42+
{
43+
tf.defaultSession = this;
44+
return this;
45+
}
46+
4147
public static Session LoadFromSavedModel(string path)
4248
{
4349
var graph = c_api.TF_NewGraph();

test/TensorFlowNET.UnitTest/PythonTest.cs

Lines changed: 54 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -132,23 +132,66 @@ protected object _eval_tensor(object tensor)
132132
}
133133

134134
/// <summary>
135-
/// Evaluates tensors and returns numpy values.
135+
/// Evaluates tensors and returns a dictionary of {name:result, ...}.
136136
/// <param name="tensors">A Tensor or a nested list/tuple of Tensors.</param>
137137
/// </summary>
138-
/// <returns> tensors numpy values.</returns>
139-
public object evaluate(params Tensor[] tensors)
138+
public Dictionary<string, NDArray> evaluate(params Tensor[] tensors)
140139
{
140+
var results = new Dictionary<string, NDArray>();
141141
// if context.executing_eagerly():
142142
// return self._eval_helper(tensors)
143143
// else:
144144
{
145145
var sess = ops.get_default_session();
146-
if (sess == None)
147-
with(self.session(), s => sess = s);
148-
return sess.run(tensors);
146+
if (sess == null)
147+
sess = self.session();
148+
149+
with<Session>(sess, s =>
150+
{
151+
foreach (var t in tensors)
152+
results[t.name] = t.eval();
153+
});
154+
return results;
155+
}
156+
}
157+
158+
public NDArray evaluate(Tensor tensor)
159+
{
160+
NDArray result = null;
161+
// if context.executing_eagerly():
162+
// return self._eval_helper(tensors)
163+
// else:
164+
{
165+
var sess = ops.get_default_session();
166+
if (sess == null)
167+
sess = self.session();
168+
with<Session>(sess, s =>
169+
{
170+
result = tensor.eval();
171+
});
172+
return result;
173+
}
174+
}
175+
176+
public object eval_scalar(Tensor tensor)
177+
{
178+
NDArray result = null;
179+
// if context.executing_eagerly():
180+
// return self._eval_helper(tensors)
181+
// else:
182+
{
183+
var sess = ops.get_default_session();
184+
if (sess == null)
185+
sess = self.session();
186+
with<Session>(sess, s =>
187+
{
188+
result = tensor.eval();
189+
});
190+
return result.Array.GetValue(0);
149191
}
150192
}
151193

194+
152195
//Returns a TensorFlow Session for use in executing tests.
153196
public Session session(Graph graph = null, object config = null, bool use_gpu = false, bool force_gpu = false)
154197
{
@@ -188,16 +231,11 @@ public Session session(Graph graph = null, object config = null, bool use_gpu =
188231
//if (context.executing_eagerly())
189232
// yield None
190233
//else
191-
{
192-
with<Session>(self._create_session(graph, config, force_gpu), sess =>
193-
{
194-
with(self._constrain_devices_and_set_default(sess, use_gpu, force_gpu), (x) =>
195-
{
196-
s = sess;
197-
});
198-
});
199-
}
200-
return s;
234+
//{
235+
s = self._create_session(graph, config, force_gpu);
236+
self._constrain_devices_and_set_default(s, use_gpu, force_gpu);
237+
//}
238+
return s.as_default();
201239
}
202240

203241
private IPython _constrain_devices_and_set_default(Session sess, bool useGpu, bool forceGpu)

test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs

Lines changed: 18 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -13,53 +13,29 @@ public class CondTestCases : PythonTest
1313
[TestMethod]
1414
public void testCondTrue()
1515
{
16-
var x = tf.constant(2);
17-
var y = tf.constant(5);
18-
var z = control_flow_ops.cond(tf.less(x, y), () => tf.multiply(x, tf.constant(17)),
19-
() => tf.add(y, tf.constant(23)));
20-
self.assertEquals(self.evaluate(z), 34);
16+
with(tf.Graph().as_default(), g =>
17+
{
18+
var x = tf.constant(2);
19+
var y = tf.constant(5);
20+
var z = control_flow_ops.cond(tf.less(x, y), () => tf.multiply(x, tf.constant(17)),
21+
() => tf.add(y, tf.constant(23)));
22+
//tf.train.export_meta_graph(@"D:\dev\tensorboard\logdir\sharp.meta", as_text: false);
23+
self.assertEquals(eval_scalar(z), 34);
24+
});
2125
}
2226

23-
[Ignore("Todo")]
27+
[Ignore("This Test Fails due to missing edges in the graph!")]
2428
[TestMethod]
2529
public void testCondFalse()
2630
{
27-
// def testCondFalse(self):
28-
// x = constant_op.constant(2)
29-
// y = constant_op.constant(1)
30-
// z = control_flow_ops.cond(
31-
// math_ops.less(
32-
// x,
33-
// y), lambda: math_ops.multiply(x, 17), lambda: math_ops.add(y, 23))
34-
// self.assertEquals(self.evaluate(z), 24)
35-
}
36-
37-
[Ignore("Todo")]
38-
[TestMethod]
39-
public void testCondTrueLegacy()
40-
{
41-
// def testCondTrueLegacy(self):
42-
// x = constant_op.constant(2)
43-
// y = constant_op.constant(5)
44-
// z = control_flow_ops.cond(
45-
// math_ops.less(x, y),
46-
// fn1=lambda: math_ops.multiply(x, 17),
47-
// fn2=lambda: math_ops.add(y, 23))
48-
// self.assertEquals(self.evaluate(z), 34)
49-
}
50-
51-
[Ignore("Todo")]
52-
[TestMethod]
53-
public void testCondFalseLegacy()
54-
{
55-
// def testCondFalseLegacy(self):
56-
// x = constant_op.constant(2)
57-
// y = constant_op.constant(1)
58-
// z = control_flow_ops.cond(
59-
// math_ops.less(x, y),
60-
// fn1=lambda: math_ops.multiply(x, 17),
61-
// fn2=lambda: math_ops.add(y, 23))
62-
// self.assertEquals(self.evaluate(z), 24)
31+
with(tf.Graph().as_default(), g =>
32+
{
33+
var x = tf.constant(2);
34+
var y = tf.constant(1);
35+
var z = control_flow_ops.cond(tf.less(x, y), () => tf.multiply(x, tf.constant(17)),
36+
() => tf.add(y, tf.constant(23)));
37+
self.assertEquals(eval_scalar(z), 24);
38+
});
6339
}
6440

6541
[Ignore("Todo")]

0 commit comments

Comments
 (0)