Skip to content

Commit 48c433f

Browse files
samjabrahamsdrpngx
authored andcommitted
Have write_graph return the output path of file (tensorflow#6851)
* Have `write_graph` return the output path of file Since the function is performing `os.path.join()` for the user, it makes sense to return the location of the final file, or else the user will have to call `os.path.join()` again on their own. * Adjust WriteGraphTest to check returned path * Assert path exists; "/".join -> os.path.join Theoretically more Windows compatible
1 parent cf9818e commit 48c433f

2 files changed

Lines changed: 16 additions & 4 deletions

File tree

tensorflow/python/framework/graph_io.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,9 @@ def write_graph(graph_or_graph_def, logdir, name, as_text=True):
5050
filesystems, such as Google Cloud Storage (GCS).
5151
name: Filename for the graph.
5252
as_text: If `True`, writes the graph as an ASCII proto.
53+
54+
Returns:
55+
The path of the output proto file.
5356
"""
5457
if isinstance(graph_or_graph_def, ops.Graph):
5558
graph_def = graph_or_graph_def.as_graph_def()
@@ -64,3 +67,4 @@ def write_graph(graph_or_graph_def, logdir, name, as_text=True):
6467
file_io.atomic_write_string_to_file(path, str(graph_def))
6568
else:
6669
file_io.atomic_write_string_to_file(path, graph_def.SerializeToString())
70+
return path

tensorflow/python/training/saver_test.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1837,14 +1837,22 @@ class WriteGraphTest(test.TestCase):
18371837
def testWriteGraph(self):
18381838
test_dir = _TestDir("write_graph_dir")
18391839
variables.Variable([[1, 2, 3], [4, 5, 6]], dtype=dtypes.float32, name="v0")
1840-
graph_io.write_graph(ops_lib.get_default_graph(),
1841-
"/".join([test_dir, "l1"]), "graph.pbtxt")
1840+
path = graph_io.write_graph(ops_lib.get_default_graph(),
1841+
os.path.join(test_dir, "l1"), "graph.pbtxt")
1842+
truth = os.path.join(test_dir, "l1", "graph.pbtxt")
1843+
self.assertEqual(path, truth)
1844+
self.assertTrue(os.path.exists(path))
1845+
18421846

18431847
def testRecursiveCreate(self):
18441848
test_dir = _TestDir("deep_dir")
18451849
variables.Variable([[1, 2, 3], [4, 5, 6]], dtype=dtypes.float32, name="v0")
1846-
graph_io.write_graph(ops_lib.get_default_graph().as_graph_def(),
1847-
"/".join([test_dir, "l1/l2/l3"]), "graph.pbtxt")
1850+
path = graph_io.write_graph(ops_lib.get_default_graph().as_graph_def(),
1851+
os.path.join(test_dir, "l1", "l2", "l3"),
1852+
"graph.pbtxt")
1853+
truth = os.path.join(test_dir, 'l1', 'l2', 'l3', "graph.pbtxt")
1854+
self.assertEqual(path, truth)
1855+
self.assertTrue(os.path.exists(path))
18481856

18491857

18501858
class SaverUtilsTest(test.TestCase):

0 commit comments

Comments
 (0)