forked from tensorflow/tensorflow
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutil.py
More file actions
174 lines (152 loc) · 7.26 KB
/
Copy pathutil.py
File metadata and controls
174 lines (152 loc) · 7.26 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for extracting and writing checkpoint info`."""
from tensorflow.core.protobuf import trackable_object_graph_pb2
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.trackable import trackable_utils
from tensorflow.python.util import object_identity
def serialize_slot_variables(trackable_objects, node_ids, object_names):
"""Gather and name slot variables."""
non_slot_objects = list(trackable_objects)
slot_variables = object_identity.ObjectIdentityDictionary()
for trackable in non_slot_objects:
# TODO(b/110718070): Fix Keras imports.
# Note: dir() is used rather than hasattr() here to avoid triggering
# custom __getattr__ code, see b/152031870 for context.
if "get_slot_names" in dir(trackable):
slot_names = trackable.get_slot_names()
for slot_name in slot_names:
for original_variable_node_id, original_variable in enumerate(
non_slot_objects):
try:
slot_variable = trackable.get_slot(original_variable, slot_name)
except (AttributeError, KeyError):
slot_variable = None
if slot_variable is None:
continue
slot_variable._maybe_initialize_trackable() # pylint: disable=protected-access
if slot_variable._trackable_children(): # pylint: disable=protected-access
# TODO(allenl): Gather dependencies of slot variables.
raise NotImplementedError(
"Currently only variables with no dependencies can be saved as "
"slot variables. File a feature request if this limitation "
"bothers you.")
if slot_variable in node_ids:
raise NotImplementedError(
"A slot variable was re-used as a dependency of a Trackable "
f"object: {slot_variable}. This is not currently allowed. "
"File a feature request if this limitation bothers you.")
checkpoint_name = trackable_utils.slot_variable_key(
variable_path=object_names[original_variable],
optimizer_path=object_names[trackable],
slot_name=slot_name)
object_names[slot_variable] = checkpoint_name
slot_variable_node_id = len(trackable_objects)
node_ids[slot_variable] = slot_variable_node_id
trackable_objects.append(slot_variable)
slot_variable_proto = (
trackable_object_graph_pb2.TrackableObjectGraph.TrackableObject
.SlotVariableReference(
slot_name=slot_name,
original_variable_node_id=original_variable_node_id,
slot_variable_node_id=slot_variable_node_id))
slot_variables.setdefault(trackable, []).append(slot_variable_proto)
return slot_variables
def get_mapped_trackable(trackable, object_map):
"""Returns the mapped trackable if possible, otherwise returns trackable."""
if object_map is None:
return trackable
else:
return object_map.get(trackable, trackable)
def get_full_name(var):
"""Gets the full name of variable for name-based checkpoint compatiblity."""
# pylint: disable=protected-access
if (not (isinstance(var, variables.Variable) or
# Some objects do not subclass Variable but still act as one.
resource_variable_ops.is_resource_variable(var))):
return ""
if getattr(var, "_save_slice_info", None) is not None:
# Use getattr because `var._save_slice_info` may be set as `None`.
return var._save_slice_info.full_name
else:
return var._shared_name
# pylint: enable=protected-access
def add_checkpoint_values_check(object_graph_proto):
"""Determines which objects have checkpoint values and save this to the proto.
Args:
object_graph_proto: A `TrackableObjectGraph` proto.
"""
# Trackable -> set of all trackables that depend on it (the "parents").
# If a trackable has checkpoint values, then all of the parents can be
# marked as having checkpoint values.
parents = {}
checkpointed_trackables = object_identity.ObjectIdentitySet()
# First pass: build dictionary of parent objects and initial set of
# checkpointed trackables.
checkpointed_trackables = set()
for node_id, object_proto in enumerate(object_graph_proto.nodes):
if (object_proto.attributes or object_proto.slot_variables or
object_proto.HasField("registered_saver")):
checkpointed_trackables.add(node_id)
for child_proto in object_proto.children:
child = child_proto.node_id
if child not in parents:
parents[child] = set()
parents[child].add(node_id)
# Second pass: add all connected parents to set of checkpointed trackables.
to_visit = set()
to_visit.update(checkpointed_trackables)
while to_visit:
trackable = to_visit.pop()
if trackable not in parents:
# Some trackables may not have parents (e.g. slot variables).
continue
current_parents = parents.pop(trackable)
checkpointed_trackables.update(current_parents)
for parent in current_parents:
if parent in parents:
to_visit.add(parent)
for node_id, object_proto in enumerate(object_graph_proto.nodes):
object_proto.has_checkpoint_values.value = bool(
node_id in checkpointed_trackables)
def objects_ids_and_slot_variables_and_paths(graph_view):
"""Traverse the object graph and list all accessible objects.
Looks for `Trackable` objects which are dependencies of
`root_trackable`. Includes slot variables only if the variable they are
slotting for and the optimizer are dependencies of `root_trackable`
(i.e. if they would be saved with a checkpoint).
Args:
graph_view: A GraphView object.
Returns:
A tuple of (trackable objects, paths from root for each object,
object -> node id, slot variables, object_names)
"""
trackable_objects, node_paths = graph_view.breadth_first_traversal()
object_names = object_identity.ObjectIdentityDictionary()
for obj, path in node_paths.items():
object_names[obj] = trackable_utils.object_path_to_string(path)
node_ids = object_identity.ObjectIdentityDictionary()
for node_id, node in enumerate(trackable_objects):
node_ids[node] = node_id
slot_variables = serialize_slot_variables(
trackable_objects=trackable_objects,
node_ids=node_ids,
object_names=object_names)
return (trackable_objects, node_paths, node_ids, slot_variables, object_names)
def list_objects(graph_view):
"""Traverse the object graph and list all accessible objects."""
trackable_objects = objects_ids_and_slot_variables_and_paths(graph_view)[0]
return trackable_objects