Skip to content

Commit 5073769

Browse files
caisqtensorflower-gardener
authored andcommitted
tfdbg: two small bug fixes
1. Make the BaseDebugWrapperSession class capable of working as a context manager, as the non-debug Session. Also let the wrapper class support the close() method. 2. Handle the case in which a fetch is an object without the "name" attribute, e.g., a SparseTensor. Change: 141308246
1 parent c2bd403 commit 5073769

4 files changed

Lines changed: 45 additions & 2 deletions

File tree

tensorflow/python/debug/cli/cli_shared.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,20 @@
2424
from tensorflow.python.ops import variables
2525

2626

27+
def _get_fetch_name(fetch):
28+
"""Obtain the name or string representation of a fetch.
29+
30+
Args:
31+
fetch: The fetch in question.
32+
33+
Returns:
34+
If the attribute 'name' is available, return the name. Otherwise, return
35+
str(fetch).
36+
"""
37+
38+
return fetch.name if hasattr(fetch, "name") else str(fetch)
39+
40+
2741
def _get_fetch_names(fetches):
2842
"""Get a flattened list of the names in run() call fetches.
2943
@@ -46,7 +60,7 @@ def _get_fetch_names(fetches):
4660
else:
4761
# This ought to be a Tensor, an Operation or a Variable, for which the name
4862
# attribute should be available. (Bottom-out condition of the recursion.)
49-
lines.append(fetches.name)
63+
lines.append(_get_fetch_name(fetches))
5064

5165
return lines
5266

@@ -190,7 +204,7 @@ def get_run_short_description(run_call_count, fetches, feed_dict):
190204
description = "run #%d: " % run_call_count
191205

192206
if isinstance(fetches, (ops.Tensor, ops.Operation, variables.Variable)):
193-
description += "1 fetch (%s); " % fetches.name
207+
description += "1 fetch (%s); " % _get_fetch_name(fetches)
194208
else:
195209
# Could be (nested) list, tuple, dict or namedtuple.
196210
num_fetches = len(_get_fetch_names(fetches))

tensorflow/python/debug/cli/cli_shared_test.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from tensorflow.python.framework import constant_op
2424
from tensorflow.python.framework import errors
2525
from tensorflow.python.framework import ops
26+
from tensorflow.python.framework import sparse_tensor
2627
from tensorflow.python.framework import test_util
2728
from tensorflow.python.ops import variables
2829
from tensorflow.python.platform import googletest
@@ -35,6 +36,9 @@ def setUp(self):
3536
self.const_b = constant_op.constant(22.0, name="b")
3637
self.const_c = constant_op.constant(33.0, name="c")
3738

39+
self.sparse_d = sparse_tensor.SparseTensor(
40+
indices=[[0, 0], [1, 1]], values=[1.0, 2.0], dense_shape=[3, 3])
41+
3842
def tearDown(self):
3943
ops.reset_default_graph()
4044

@@ -66,6 +70,10 @@ def testSingleFetchNoFeeds(self):
6670
description = cli_shared.get_run_short_description(12, self.const_a, None)
6771
self.assertEqual("run #12: 1 fetch (a:0); 0 feeds", description)
6872

73+
def testSparseTensorAsFetchShouldHandleNoNameAttribute(self):
74+
run_start_intro = cli_shared.get_run_start_intro(1, self.sparse_d, None, {})
75+
self.assertEqual(str(self.sparse_d), run_start_intro.lines[4].strip())
76+
6977
def testTwoFetchesListNoFeeds(self):
7078
fetches = [self.const_a, self.const_b]
7179
run_start_intro = cli_shared.get_run_start_intro(1, fetches, None, {})

tensorflow/python/debug/wrappers/framework.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -501,5 +501,14 @@ def on_run_end(self, request):
501501
"""
502502
pass
503503

504+
def __enter__(self):
505+
return self._sess.__enter__()
506+
507+
def __exit__(self, exec_type, exec_value, exec_tb):
508+
self._sess.__exit__(exec_type, exec_value, exec_tb)
509+
510+
def close(self):
511+
self._sess.close()
512+
504513
# TODO(cais): Add _node_name_regex_whitelist and
505514
# _node_op_type_regex_whitelist.

tensorflow/python/debug/wrappers/framework_test.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,18 @@ def testErrorDuringRun(self):
299299
self.assertTrue(
300300
isinstance(self._observer["tf_error"], errors.InvalidArgumentError))
301301

302+
def testUsingWrappedSessionShouldWorkAsContextManager(self):
303+
wrapper = TestDebugWrapperSession(self._sess, self._dump_root,
304+
self._observer)
305+
306+
with wrapper as sess:
307+
sess.run(self._s)
308+
309+
def testWrapperShouldSupportSessionClose(self):
310+
wrapper = TestDebugWrapperSession(self._sess, self._dump_root,
311+
self._observer)
312+
wrapper.close()
313+
302314

303315
if __name__ == "__main__":
304316
googletest.main()

0 commit comments

Comments
 (0)