@@ -18,7 +18,8 @@ public class PythonTest : Python
1818 {
1919 #region python compatibility layer
2020 protected PythonTest self { get => this ; }
21- protected object None {
21+ protected object None
22+ {
2223 get { return null ; }
2324 }
2425 #endregion
@@ -43,7 +44,7 @@ public void assertItemsEqual(ICollection given, ICollection expected)
4344 assertItemsEqual ( ( g [ i ] as NDArray ) . Array , ( e [ i ] as NDArray ) . Array ) ;
4445 else if ( e [ i ] is ICollection && g [ i ] is ICollection )
4546 assertEqual ( g [ i ] , e [ i ] ) ;
46- else
47+ else
4748 Assert . AreEqual ( e [ i ] , g [ i ] , $ "Items differ at index { i } , expected { e [ i ] } but got { g [ i ] } ") ;
4849 }
4950 }
@@ -102,28 +103,171 @@ protected object _eval_helper(Tensor[] tensors)
102103 {
103104 if ( tensors == null )
104105 return null ;
105- // return nest.map_structure(self._eval_tensor, tensors);
106+ return nest . map_structure ( self . _eval_tensor , tensors ) ;
106107 return null ;
107108 }
108109
109- //def evaluate(self, tensors) :
110- // """Evaluates tensors and returns numpy values.
111-
112- // Args:
113- // tensors: A Tensor or a nested list/tuple of Tensors.
114-
115- // Returns:
116- // tensors numpy values.
117- // """
118- // if context.executing_eagerly():
119- // return self._eval_helper(tensors)
120- // else:
121- // sess = ops.get_default_session()
122- // if sess is None:
123- // with self.test_session() as sess:
124- // return sess.run(tensors)
125- // else:
126- // return sess.run(tensors)
110+ protected object _eval_tensor ( object tensor )
111+ {
112+ if ( tensor == None )
113+ return None ;
114+ //else if (callable(tensor))
115+ // return self._eval_helper(tensor())
116+ else
117+ {
118+ try
119+ {
120+ //TODO:
121+ // if sparse_tensor.is_sparse(tensor):
122+ // return sparse_tensor.SparseTensorValue(tensor.indices, tensor.values,
123+ // tensor.dense_shape)
124+ //return (tensor as Tensor).numpy();
125+ }
126+ catch ( Exception e )
127+ {
128+ throw new ValueError ( "Unsupported type: " + tensor . GetType ( ) ) ;
129+ }
130+ return null ;
131+ }
132+ }
133+
134+ /// <summary>
135+ /// Evaluates tensors and returns numpy values.
136+ /// <param name="tensors">A Tensor or a nested list/tuple of Tensors.</param>
137+ /// </summary>
138+ /// <returns> tensors numpy values.</returns>
139+ public object evaluate ( params Tensor [ ] tensors )
140+ {
141+ // if context.executing_eagerly():
142+ // return self._eval_helper(tensors)
143+ // else:
144+ {
145+ var sess = ops . get_default_session ( ) ;
146+ if ( sess == None )
147+ with ( self . session ( ) , s => sess = s ) ;
148+ return sess . run ( tensors ) ;
149+ }
150+ }
151+
152+ //Returns a TensorFlow Session for use in executing tests.
153+ public Session session ( Graph graph = null , object config = null , bool use_gpu = false , bool force_gpu = false )
154+ {
155+ //Note that this will set this session and the graph as global defaults.
156+
157+ //Use the `use_gpu` and `force_gpu` options to control where ops are run.If
158+ //`force_gpu` is True, all ops are pinned to `/device:GPU:0`. Otherwise, if
159+ //`use_gpu` is True, TensorFlow tries to run as many ops on the GPU as
160+ //possible.If both `force_gpu and `use_gpu` are False, all ops are pinned to
161+ //the CPU.
162+
163+ //Example:
164+ //```python
165+ //class MyOperatorTest(test_util.TensorFlowTestCase):
166+ // def testMyOperator(self):
167+ // with self.session(use_gpu= True):
168+ // valid_input = [1.0, 2.0, 3.0, 4.0, 5.0]
169+ // result = MyOperator(valid_input).eval()
170+ // self.assertEqual(result, [1.0, 2.0, 3.0, 5.0, 8.0]
171+ // invalid_input = [-1.0, 2.0, 7.0]
172+ // with self.assertRaisesOpError("negative input not supported"):
173+ // MyOperator(invalid_input).eval()
174+ //```
175+
176+ //Args:
177+ // graph: Optional graph to use during the returned session.
178+ // config: An optional config_pb2.ConfigProto to use to configure the
179+ // session.
180+ // use_gpu: If True, attempt to run as many ops as possible on GPU.
181+ // force_gpu: If True, pin all ops to `/device:GPU:0`.
182+
183+ //Yields:
184+ // A Session object that should be used as a context manager to surround
185+ // the graph building and execution code in a test case.
186+
187+ Session s = null ;
188+ //if (context.executing_eagerly())
189+ // yield None
190+ //else
191+ {
192+ with < Session > ( self . _create_session ( graph , config , force_gpu ) , sess =>
193+ {
194+ with ( self . _constrain_devices_and_set_default ( sess , use_gpu , force_gpu ) , ( x ) =>
195+ {
196+ s = sess ;
197+ } ) ;
198+ } ) ;
199+ }
200+ return s ;
201+ }
202+
203+ private IPython _constrain_devices_and_set_default ( Session sess , bool useGpu , bool forceGpu )
204+ {
205+ //def _constrain_devices_and_set_default(self, sess, use_gpu, force_gpu):
206+ //"""Set the session and its graph to global default and constrain devices."""
207+ //if context.executing_eagerly():
208+ // yield None
209+ //else:
210+ // with sess.graph.as_default(), sess.as_default():
211+ // if force_gpu:
212+ // # Use the name of an actual device if one is detected, or
213+ // # '/device:GPU:0' otherwise
214+ // gpu_name = gpu_device_name()
215+ // if not gpu_name:
216+ // gpu_name = "/device:GPU:0"
217+ // with sess.graph.device(gpu_name):
218+ // yield sess
219+ // elif use_gpu:
220+ // yield sess
221+ // else:
222+ // with sess.graph.device("/device:CPU:0"):
223+ // yield sess
224+ return sess ;
225+ }
226+
227+ // See session() for details.
228+ private Session _create_session ( Graph graph , object cfg , bool forceGpu )
229+ {
230+ var prepare_config = new Func < object , object > ( ( config ) =>
231+ {
232+ // """Returns a config for sessions.
233+ // Args:
234+ // config: An optional config_pb2.ConfigProto to use to configure the
235+ // session.
236+ // Returns:
237+ // A config_pb2.ConfigProto object.
238+
239+ //TODO: config
240+
241+ // # use_gpu=False. Currently many tests rely on the fact that any device
242+ // # will be used even when a specific device is supposed to be used.
243+ // allow_soft_placement = not force_gpu
244+ // if config is None:
245+ // config = config_pb2.ConfigProto()
246+ // config.allow_soft_placement = allow_soft_placement
247+ // config.gpu_options.per_process_gpu_memory_fraction = 0.3
248+ // elif not allow_soft_placement and config.allow_soft_placement:
249+ // config_copy = config_pb2.ConfigProto()
250+ // config_copy.CopyFrom(config)
251+ // config = config_copy
252+ // config.allow_soft_placement = False
253+ // # Don't perform optimizations for tests so we don't inadvertently run
254+ // # gpu ops on cpu
255+ // config.graph_options.optimizer_options.opt_level = -1
256+ // # Disable Grappler constant folding since some tests & benchmarks
257+ // # use constant input and become meaningless after constant folding.
258+ // # DO NOT DISABLE GRAPPLER OPTIMIZERS WITHOUT CONSULTING WITH THE
259+ // # GRAPPLER TEAM.
260+ // config.graph_options.rewrite_options.constant_folding = (
261+ // rewriter_config_pb2.RewriterConfig.OFF)
262+ // config.graph_options.rewrite_options.pin_to_host_optimization = (
263+ // rewriter_config_pb2.RewriterConfig.OFF)
264+ return config ;
265+ } ) ;
266+ //TODO: use this instead of normal session
267+ //return new ErrorLoggingSession(graph = graph, config = prepare_config(config))
268+ return new Session ( graph : graph ) ; //, config = prepare_config(config))
269+ }
270+
127271 #endregion
128272
129273
0 commit comments