Skip to content

Commit 9e04aee

Browse files
committed
ported the first cond test case that evaluates tensors
1 parent 2bf72e9 commit 9e04aee

File tree

2 files changed

+170
-26
lines changed

2 files changed

+170
-26
lines changed

test/TensorFlowNET.UnitTest/PythonTest.cs

Lines changed: 165 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@ public class PythonTest : Python
1818
{
1919
#region python compatibility layer
2020
protected PythonTest self { get => this; }
21-
protected object None {
21+
protected object None
22+
{
2223
get { return null; }
2324
}
2425
#endregion
@@ -43,7 +44,7 @@ public void assertItemsEqual(ICollection given, ICollection expected)
4344
assertItemsEqual((g[i] as NDArray).Array, (e[i] as NDArray).Array);
4445
else if (e[i] is ICollection && g[i] is ICollection)
4546
assertEqual(g[i], e[i]);
46-
else
47+
else
4748
Assert.AreEqual(e[i], g[i], $"Items differ at index {i}, expected {e[i]} but got {g[i]}");
4849
}
4950
}
@@ -102,28 +103,171 @@ protected object _eval_helper(Tensor[] tensors)
102103
{
103104
if (tensors == null)
104105
return null;
105-
//return nest.map_structure(self._eval_tensor, tensors);
106+
return nest.map_structure(self._eval_tensor, tensors);
106107
return null;
107108
}
108109

109-
//def evaluate(self, tensors) :
110-
// """Evaluates tensors and returns numpy values.
111-
112-
// Args:
113-
// tensors: A Tensor or a nested list/tuple of Tensors.
114-
115-
// Returns:
116-
// tensors numpy values.
117-
// """
118-
// if context.executing_eagerly():
119-
// return self._eval_helper(tensors)
120-
// else:
121-
// sess = ops.get_default_session()
122-
// if sess is None:
123-
// with self.test_session() as sess:
124-
// return sess.run(tensors)
125-
// else:
126-
// return sess.run(tensors)
110+
protected object _eval_tensor(object tensor)
111+
{
112+
if (tensor == None)
113+
return None;
114+
//else if (callable(tensor))
115+
// return self._eval_helper(tensor())
116+
else
117+
{
118+
try
119+
{
120+
//TODO:
121+
// if sparse_tensor.is_sparse(tensor):
122+
// return sparse_tensor.SparseTensorValue(tensor.indices, tensor.values,
123+
// tensor.dense_shape)
124+
//return (tensor as Tensor).numpy();
125+
}
126+
catch (Exception e)
127+
{
128+
throw new ValueError("Unsupported type: " + tensor.GetType());
129+
}
130+
return null;
131+
}
132+
}
133+
134+
/// <summary>
135+
/// Evaluates tensors and returns numpy values.
136+
/// <param name="tensors">A Tensor or a nested list/tuple of Tensors.</param>
137+
/// </summary>
138+
/// <returns> tensors numpy values.</returns>
139+
public object evaluate(params Tensor[] tensors)
140+
{
141+
// if context.executing_eagerly():
142+
// return self._eval_helper(tensors)
143+
// else:
144+
{
145+
var sess = ops.get_default_session();
146+
if (sess == None)
147+
with(self.session(), s => sess = s);
148+
return sess.run(tensors);
149+
}
150+
}
151+
152+
//Returns a TensorFlow Session for use in executing tests.
153+
public Session session(Graph graph = null, object config = null, bool use_gpu = false, bool force_gpu = false)
154+
{
155+
//Note that this will set this session and the graph as global defaults.
156+
157+
//Use the `use_gpu` and `force_gpu` options to control where ops are run.If
158+
//`force_gpu` is True, all ops are pinned to `/device:GPU:0`. Otherwise, if
159+
//`use_gpu` is True, TensorFlow tries to run as many ops on the GPU as
160+
//possible.If both `force_gpu and `use_gpu` are False, all ops are pinned to
161+
//the CPU.
162+
163+
//Example:
164+
//```python
165+
//class MyOperatorTest(test_util.TensorFlowTestCase):
166+
// def testMyOperator(self):
167+
// with self.session(use_gpu= True):
168+
// valid_input = [1.0, 2.0, 3.0, 4.0, 5.0]
169+
// result = MyOperator(valid_input).eval()
170+
// self.assertEqual(result, [1.0, 2.0, 3.0, 5.0, 8.0]
171+
// invalid_input = [-1.0, 2.0, 7.0]
172+
// with self.assertRaisesOpError("negative input not supported"):
173+
// MyOperator(invalid_input).eval()
174+
//```
175+
176+
//Args:
177+
// graph: Optional graph to use during the returned session.
178+
// config: An optional config_pb2.ConfigProto to use to configure the
179+
// session.
180+
// use_gpu: If True, attempt to run as many ops as possible on GPU.
181+
// force_gpu: If True, pin all ops to `/device:GPU:0`.
182+
183+
//Yields:
184+
// A Session object that should be used as a context manager to surround
185+
// the graph building and execution code in a test case.
186+
187+
Session s = null;
188+
//if (context.executing_eagerly())
189+
// yield None
190+
//else
191+
{
192+
with<Session>(self._create_session(graph, config, force_gpu), sess =>
193+
{
194+
with(self._constrain_devices_and_set_default(sess, use_gpu, force_gpu), (x) =>
195+
{
196+
s = sess;
197+
});
198+
});
199+
}
200+
return s;
201+
}
202+
203+
private IPython _constrain_devices_and_set_default(Session sess, bool useGpu, bool forceGpu)
204+
{
205+
//def _constrain_devices_and_set_default(self, sess, use_gpu, force_gpu):
206+
//"""Set the session and its graph to global default and constrain devices."""
207+
//if context.executing_eagerly():
208+
// yield None
209+
//else:
210+
// with sess.graph.as_default(), sess.as_default():
211+
// if force_gpu:
212+
// # Use the name of an actual device if one is detected, or
213+
// # '/device:GPU:0' otherwise
214+
// gpu_name = gpu_device_name()
215+
// if not gpu_name:
216+
// gpu_name = "/device:GPU:0"
217+
// with sess.graph.device(gpu_name):
218+
// yield sess
219+
// elif use_gpu:
220+
// yield sess
221+
// else:
222+
// with sess.graph.device("/device:CPU:0"):
223+
// yield sess
224+
return sess;
225+
}
226+
227+
// See session() for details.
228+
private Session _create_session(Graph graph, object cfg, bool forceGpu)
229+
{
230+
var prepare_config = new Func<object, object>((config) =>
231+
{
232+
// """Returns a config for sessions.
233+
// Args:
234+
// config: An optional config_pb2.ConfigProto to use to configure the
235+
// session.
236+
// Returns:
237+
// A config_pb2.ConfigProto object.
238+
239+
//TODO: config
240+
241+
// # use_gpu=False. Currently many tests rely on the fact that any device
242+
// # will be used even when a specific device is supposed to be used.
243+
// allow_soft_placement = not force_gpu
244+
// if config is None:
245+
// config = config_pb2.ConfigProto()
246+
// config.allow_soft_placement = allow_soft_placement
247+
// config.gpu_options.per_process_gpu_memory_fraction = 0.3
248+
// elif not allow_soft_placement and config.allow_soft_placement:
249+
// config_copy = config_pb2.ConfigProto()
250+
// config_copy.CopyFrom(config)
251+
// config = config_copy
252+
// config.allow_soft_placement = False
253+
// # Don't perform optimizations for tests so we don't inadvertently run
254+
// # gpu ops on cpu
255+
// config.graph_options.optimizer_options.opt_level = -1
256+
// # Disable Grappler constant folding since some tests & benchmarks
257+
// # use constant input and become meaningless after constant folding.
258+
// # DO NOT DISABLE GRAPPLER OPTIMIZERS WITHOUT CONSULTING WITH THE
259+
// # GRAPPLER TEAM.
260+
// config.graph_options.rewrite_options.constant_folding = (
261+
// rewriter_config_pb2.RewriterConfig.OFF)
262+
// config.graph_options.rewrite_options.pin_to_host_optimization = (
263+
// rewriter_config_pb2.RewriterConfig.OFF)
264+
return config;
265+
});
266+
//TODO: use this instead of normal session
267+
//return new ErrorLoggingSession(graph = graph, config = prepare_config(config))
268+
return new Session(graph: graph);//, config = prepare_config(config))
269+
}
270+
127271
#endregion
128272

129273

test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,14 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test
1010
public class CondTestCases : PythonTest
1111
{
1212

13-
[Ignore("Todo")]
1413
[TestMethod]
1514
public void testCondTrue()
1615
{
17-
//var x = constant_op.constant(2);
18-
//var y = constant_op.constant(5);
19-
// var z = control_flow_ops.cond(math_ops.less(x,y), ()=> math_ops.multiply(x, 17), ()=> math_ops.add(y, 23))
20-
//self.assertEquals(self.evaluate(z), 34);
16+
var x = tf.constant(2);
17+
var y = tf.constant(5);
18+
var z = control_flow_ops.cond(tf.less(x, y), () => tf.multiply(x, tf.constant(17)),
19+
() => tf.add(y, tf.constant(23)));
20+
self.assertEquals(self.evaluate(z), 34);
2121
}
2222

2323
[Ignore("Todo")]

0 commit comments

Comments
 (0)