
Commit ab01a0d

eellison authored and pytorchmergebot committed
Add memory estimator (#164738)
Original work by @ShatianWang, with lints applied. I am going to make a few changes and add tests in subsequent PRs, but I want to preserve the original commit first.

Pull Request resolved: #164738
Approved by: https://github.com/IvanKobzarev
ghstack dependencies: #164568, #164569, #164581
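For context, a minimal usage sketch of build_memory_profile on a toy symbolically traced graph. The traced function, the constant per-tensor size_of, and the is_releasable policy here are illustrative assumptions and not part of this commit; the real entry point, get_peak_memory, wires in _size_of and a primals-aware policy as shown in the file below.

import torch
import torch.fx as fx

# Toy graph: two elementwise ops on a 1024-element float32 tensor.
def f(x):
    y = torch.relu(x)
    return torch.add(y, y)

gm = fx.symbolic_trace(f)

# Stand-in callbacks, assumed for illustration only: every tensor is treated
# as 1024 * 4 bytes, and only non-placeholder storages may be released.
def size_of(n: fx.Node) -> int:
    return 1024 * 4

def is_releasable(n: fx.Node) -> bool:
    return n.op != "placeholder"

# With build_memory_profile (from the file added below) in scope:
profile = build_memory_profile(gm.graph, size_of, is_releasable)
print(profile)       # memory after each allocation/release step, in bytes
print(max(profile))  # estimated peak memory for this graph

Each op contributes two entries to the profile: one after its output is allocated and one after storages whose last use it was are released.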
1 parent 801e282 commit ab01a0d

File tree

1 file changed: +386 -0

Lines changed: 386 additions & 0 deletions
@@ -0,0 +1,386 @@
import logging
import operator
from typing import Any, Callable

import torch.fx as fx
from torch._functorch.partitioners import _size_of, get_default_op_list
from torch.utils._ordered_set import OrderedSet


log = logging.getLogger(__name__)


def build_memory_profile(
    graph: fx.Graph,
    size_of: Callable[[fx.Node], int],
    is_releasable: Callable[[fx.Node], bool],
) -> list[int]:
    """
    Estimate the memory profile of an input FX graph.

    Args:
    - graph (fx.Graph): The input FX graph for which the memory profile
      is to be estimated.
    - size_of (Callable[[fx.Node], int]): A function that returns
      the size of a given node.
    - is_releasable (Callable[[fx.Node], bool]): A function that
      determines whether a node's memory can be released (e.g. primal
      nodes cannot be released).

    Returns:
    - list[int]: A list representing the memory profile over the execution
      of the graph, where each entry corresponds to the memory usage at
      a particular point in the execution.
    """

    nodes = list(graph.nodes)
    op_types = get_default_op_list()

    class AliasInfo:
        """
        Class for storing and accessing alias information of an FX graph.

        Attributes:
        - view_to_source: Maps view nodes to their source nodes
        - getitem_to_source: Maps getitem nodes to (source_node, key) tuples
        - source_to_getitems: Maps source nodes to dictionaries of
          {key: getitem_node, "unclaimed": None}
        - source_to_unclaimed_size: Maps source nodes to their storage size
          not claimed by any getitem nodes
        """

        def __init__(self, nodes: list[fx.Node]):
            """
            Initialize the AliasInfo class with a list of FX graph nodes.

            Args:
            - nodes (list[fx.Node]): A list of nodes from an FX graph,
              ordered in execution order.

            The constructor analyzes the relationships between nodes in the
            FX graph to populate alias information. It identifies two types
            of alias nodes: getitem and view. It maps each view to its
            source, and each getitem to its source and key. It also
            populates mappings from source nodes to their getitems and
            calculates unclaimed storage sizes.
            """
            # For each view, we map it to its source.
            # Note that we treat getitems of a view (e.g. aten.split) as views.
            self.view_to_source: dict[fx.Node, fx.Node] = {}

            # For each remaining getitem, we map it to its source and key.
            self.getitem_to_source: dict[fx.Node, tuple[fx.Node, Any]] = {}

            # For each non-view source_node of getitems, we map it to a dictionary
            # of the form {key: getitem_node, ..., "unclaimed": None}, where
            # "unclaimed" is a dummy key that represents all elements in the
            # source_node that are not claimed by any getitems.
            self.source_to_getitems: dict[fx.Node, dict[Any, fx.Node | None]] = {}

            # For each non-view source_node of getitems with at least one
            # unclaimed element, we map it to its unclaimed storage size.
            self.source_to_unclaimed_size: dict[fx.Node, int] = {}

            for node in nodes:
                is_view = op_types.is_view(node)
                is_getitem = node.target is operator.getitem
                if not (is_view or is_getitem):
                    continue
                assert not (is_view and is_getitem)
                assert node.args and isinstance(node.args[0], fx.Node)
                source = node.args[0]
                if is_view:
                    assert not isinstance(source.meta["val"], list | tuple | dict)
                    if source in self.view_to_source:
                        source = self.view_to_source[source]
                    self.view_to_source[node] = source
                if is_getitem:
                    assert isinstance(source.meta["val"], list | tuple | dict)
                    # Source of getitem can be a view (e.g. aten.split).
                    if source in self.view_to_source:
                        source = self.view_to_source[source]
                        # In this case, the getitem node should be treated
                        # the same way as a regular view.
                        self.view_to_source[node] = source
                        continue
                    # Source of getitem cannot be a getitem.
                    assert source not in self.getitem_to_source

                    # There must be a second argument that specifies the key.
                    assert len(node.args) >= 2
                    key = node.args[1]
                    self.getitem_to_source[node] = (source, key)

                    # Populate source_to_getitems.
                    if source not in self.source_to_getitems:
                        self.source_to_getitems[source] = {"unclaimed": None}
                    assert key not in self.source_to_getitems[source]
                    self.source_to_getitems[source][key] = node  # type: ignore[index]

            for source, getitem_map in self.source_to_getitems.items():
                unclaimed_source_size = size_of(source)
                for key, getitem_node in getitem_map.items():
                    if key != "unclaimed" and getitem_node is not None:
                        unclaimed_source_size -= size_of(getitem_node)
                assert unclaimed_source_size >= 0
                if unclaimed_source_size > 0:
                    self.source_to_unclaimed_size[source] = unclaimed_source_size

        def is_view(self, node: fx.Node) -> bool:
            return node in self.view_to_source

        def is_getitem(self, node: fx.Node) -> bool:
            return node in self.getitem_to_source

        def get_source(self, node: fx.Node) -> fx.Node | tuple[fx.Node, Any]:
            if self.is_view(node):
                return self.view_to_source[node]
            if self.is_getitem(node):
                return self.getitem_to_source[node]
            return node

        def is_source_of_getitems(self, node: fx.Node) -> bool:
            return node in self.source_to_getitems

        def get_storage_keys(self, source_node: fx.Node) -> list[Any]:
            assert source_node in self.source_to_getitems
            return list(self.source_to_getitems[source_node].keys())

        def get_unclaimed_storage_size(self, source_node: fx.Node) -> int:
            return self.source_to_unclaimed_size.get(source_node, 0)

        def get_getitem_by_key(self, source: fx.Node, key: Any) -> fx.Node | None:
            assert source in self.source_to_getitems
            assert key in self.source_to_getitems[source]
            return self.source_to_getitems[source][key]

    def _get_last_usage(
        nodes: list[fx.Node], alias_info: AliasInfo
    ) -> dict[fx.Node, list[tuple[fx.Node, Any]]]:
        """
        Determine the last usage point of each storage. This information is
        used to identify when storages can be safely released.

        Args:
        - nodes (list[fx.Node]): A list of nodes from the FX graph, ordered
          in execution order.
        - alias_info (AliasInfo): An instance of AliasInfo containing aliasing
          relationships between nodes in the graph.

        Returns:
        - dict[fx.Node, list[tuple[fx.Node, Any]]]: A mapping from each node
          to a list of storages (represented as tuples of source node and key)
          that are last used by that node. This helps in identifying which
          storages can be released after the node's execution.
        """
        storage_to_last_user: dict[tuple[fx.Node, Any], fx.Node] = {}
        node_to_last_used_storages: dict[fx.Node, list[tuple[fx.Node, Any]]] = {}

        def register_last_uses(use: fx.Node, user: fx.Node) -> None:
            keys: list[Any] = []
            if alias_info.is_view(use):
                # When use is a view (or getitem of a view),
                # user is essentially using the storage allocated at the
                # creation of the source of use.
                use = alias_info.get_source(use)  # type: ignore[assignment]

            if alias_info.is_source_of_getitems(use):  # type: ignore[arg-type]
                # When use is a source of getitems, user is using all separate
                # storages of use.
                keys.extend(alias_info.get_storage_keys(use))  # type: ignore[arg-type]
            elif alias_info.is_getitem(use):  # type: ignore[arg-type]
                # When use is a getitem, user is essentially using a separate
                # storage of the source of use specified by key.
                use, key = alias_info.get_source(use)  # type: ignore[assignment,misc]
                keys.append(key)
            else:
                keys.append(None)

            assert keys

            for key in keys:
                if (use, key) not in storage_to_last_user:  # type: ignore[comparison-overlap]
                    storage_to_last_user[(use, key)] = user  # type: ignore[index]
                    node_to_last_used_storages.setdefault(user, []).append((use, key))  # type: ignore[arg-type]

        for node in reversed(nodes):
            fx.node.map_arg(node.args, lambda n: register_last_uses(n, node))
            fx.node.map_arg(node.kwargs, lambda n: register_last_uses(n, node))

        return node_to_last_used_storages

    alias_info = AliasInfo(nodes)
    node_to_last_used_storages = _get_last_usage(nodes, alias_info)

    # Initialize memory profile
    memory_profile = [0]

    # Process the graph
    for node in nodes:
        if node.op == "placeholder":
            out_mem = size_of(node)
            memory_profile[0] += out_mem
        elif node.op == "output":
            pass
        elif (
            node.op == "call_function"
            or node.op == "call_module"
            or node.op == "call_method"
        ):
            # Aliases don't allocate new memory
            if alias_info.is_view(node) or alias_info.is_getitem(node):
                memory_profile.append(memory_profile[-1])
            else:
                out_mem = size_of(node)
                memory_profile.append(memory_profile[-1] + out_mem)

            # Process storages that are no longer needed after this operation
            storages_to_release = [
                (use, key)
                for use, key in node_to_last_used_storages.get(node, [])
                if is_releasable(use)
            ]
            freed_memory = 0
            for node_to_release, key in storages_to_release:
                released_memory_size = 0
                if key is None:
                    released_memory_size = size_of(node_to_release)
                elif key == "unclaimed":
                    released_memory_size = alias_info.get_unclaimed_storage_size(
                        node_to_release
                    )
                else:
                    getitem_node = alias_info.get_getitem_by_key(node_to_release, key)
                    if getitem_node is not None:
                        released_memory_size = size_of(getitem_node)
                freed_memory += released_memory_size

            assert freed_memory >= 0
            memory_profile.append(memory_profile[-1] - freed_memory)
    return memory_profile


def get_fwd_bwd_interactions(
    fwd_graph: fx.Graph,
    bwd_graph: fx.Graph,
    size_of: Callable[[fx.Node], int],
) -> tuple[int, OrderedSet[str]]:
    """
    Analyze the interactions between the forward (fwd) and backward (bwd)
    graphs to determine memory usage characteristics.

    Args:
    - fwd_graph (fx.Graph): The forward graph representing the forward pass.
    - bwd_graph (fx.Graph): The backward graph representing the backward pass.
    - size_of (Callable[[fx.Node], int]): A function that returns the size
      of a given node.

    Returns:
    - tuple[int, OrderedSet[str]]: A tuple containing:
      1. The baseline memory usage during the backward pass, accounting for
         nodes that persist from the forward pass (i.e., in fwd output but
         not in bwd input).
      2. A set of node names whose storage cannot be released during the bwd
         pass. These include nodes that are views of primals or in bwd input
         but not in fwd output.
    """

    def get_nodes_in_output(graph: fx.Graph) -> OrderedSet[fx.Node]:
        """
        Get the nodes in the output of a graph.

        Args:
        - graph (fx.Graph): The input graph.

        Returns:
        - OrderedSet[fx.Node]: The set of nodes in the output of the graph.
        """
        output_node = list(graph.nodes)[-1]
        assert output_node.op == "output"
        nodes_in_output: OrderedSet[fx.Node] = OrderedSet()

        def add_node(node: fx.Node) -> None:
            nodes_in_output.add(node)

        # Using map_arg since output_node.args[0] can be of different types
        # e.g. tuple, list, dict, fx.Node, etc.
        fx.node.map_arg(output_node.args[0], lambda n: add_node(n))
        return nodes_in_output

    op_types = get_default_op_list()

    bwd_baseline_memory = 0
    # placeholder nodes besides primals of the bwd_graph that should also
    # not be deleted during memory profile estimation of the bwd_graph
    do_not_delete: OrderedSet[str] = OrderedSet()

    fwd_outputs = {}
    for node in get_nodes_in_output(fwd_graph):
        is_view_of_primal = False
        if op_types.is_view(node):
            source = node.args[0]
            if isinstance(source, fx.Node) and source.name.startswith("primals"):
                is_view_of_primal = True
        fwd_outputs[node.name] = (size_of(node), is_view_of_primal)
    bwd_inputs: OrderedSet[str] = OrderedSet()
    for node in bwd_graph.nodes:
        if node.op == "placeholder":
            bwd_inputs.add(node.name)
            if node.name.startswith("view"):
                # if node is a view, then it has to be in fwd_outputs
                assert node.name in fwd_outputs
                _, is_view_of_primal = fwd_outputs[node.name]
                if is_view_of_primal:
                    # Add node to do_not_delete because it is a view of a primal
                    do_not_delete.add(node.name)

            # if node is not in fwd_outputs, then add it to do_not_delete
            if node.name not in fwd_outputs:
                do_not_delete.add(node.name)

    # nodes that are in fwd_outputs but not in bwd_inputs take memory storage
    # throughout the bwd pass
    for name, (size, _) in fwd_outputs.items():
        if name not in bwd_inputs:
            bwd_baseline_memory += size

    return bwd_baseline_memory, do_not_delete


def get_peak_memory(
    fwd_graph: fx.Graph,
    bwd_graph: fx.Graph,
) -> int:
    def _safe_size_of(n: fx.Node) -> int:
        try:
            return _size_of(n)
        except Exception:
            log.warning("Failed size_of(%s). Returning 0 instead.", n)
            return 0

    def _is_releasable(n: fx.Node) -> bool:
        # Storages of primals cannot be released during fwd or bwd pass.
        return not n.name.startswith("primals")

    fwd_peak_memory = max(
        build_memory_profile(fwd_graph, _safe_size_of, _is_releasable)
    )

    bwd_baseline_memory, bwd_do_not_delete = get_fwd_bwd_interactions(
        fwd_graph, bwd_graph, _safe_size_of
    )

    def _is_bwd_releasable(n: fx.Node) -> bool:
        # Storages of nodes in bwd_do_not_delete cannot be released
        # during the bwd pass.
        return _is_releasable(n) and n.name not in bwd_do_not_delete

    bwd_peak_memory = bwd_baseline_memory + max(
        build_memory_profile(bwd_graph, _safe_size_of, _is_bwd_releasable)
    )
    return max(
        fwd_peak_memory,
        bwd_peak_memory,
    )
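
For a profile based on real tensor sizes rather than a stand-in size_of, the graph's nodes need meta["val"] populated, which is what _size_of consumes. Below is a hedged sketch, assuming build_memory_profile from the file above is in scope; the traced function and the releasability policy are illustrative, and peak estimation over separate fwd/bwd graphs would go through get_peak_memory as defined above.

import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._functorch.partitioners import _size_of

def f(x):
    return torch.relu(x).sum()

# Fake-tensor tracing yields an aten-level graph whose nodes carry
# meta["val"], the metadata that _size_of reads.
gm = make_fx(f, tracing_mode="fake")(torch.randn(1024))

def safe_size_of(n):
    # Mirror the _safe_size_of fallback used by get_peak_memory above.
    try:
        return _size_of(n)
    except Exception:
        return 0

# Assumed policy for this sketch: only placeholder storages are pinned.
profile = build_memory_profile(gm.graph, safe_size_of, lambda n: n.op != "placeholder")
print(max(profile))  # estimated peak memory, in bytes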

0 commit comments
