|
| 1 | +import logging |
| 2 | +import operator |
| 3 | +from typing import Any, Callable |
| 4 | + |
| 5 | +import torch.fx as fx |
| 6 | +from torch._functorch.partitioners import _size_of, get_default_op_list |
| 7 | +from torch.utils._ordered_set import OrderedSet |
| 8 | + |
| 9 | + |
| 10 | +log = logging.getLogger(__name__) |
| 11 | + |
| 12 | + |
| 13 | +def build_memory_profile( |
| 14 | + graph: fx.Graph, |
| 15 | + size_of: Callable[[fx.Node], int], |
| 16 | + is_releasable: Callable[[fx.Node], bool], |
| 17 | +) -> list[int]: |
| 18 | + """ |
| 19 | + Function to estimate the memory profile of an input FX graph. |
| 20 | +
|
| 21 | + Args: |
| 22 | + - graph (fx.Graph): The input FX graph for which the memory profile |
| 23 | + is to be estimated. |
| 24 | + - size_of (Callable[[fx.Node], int]): A function that returns |
| 25 | + the size of a given node. |
| 26 | + - is_releasable (Callable[[fx.Node], bool]): A function that |
| 27 | + determines if a node's memory can be released (e.g. primal nodes |
| 28 | + cannot be released). |
| 29 | +
|
| 30 | + Returns: |
| 31 | + - List[int]: A list representing the memory profile over the execution |
| 32 | + of the graph, where each entry corresponds to the memory usage at |
| 33 | + a particular point in the execution. |
| 34 | + """ |
| 35 | + |
| 36 | + nodes = list(graph.nodes) |
| 37 | + op_types = get_default_op_list() |
| 38 | + |
| 39 | + class AliasInfo: |
| 40 | + """ |
| 41 | + Class for storing and accessing alias information of a FX graph. |
| 42 | +
|
| 43 | + Attributes: |
| 44 | + - view_to_source: Maps view nodes to their source nodes |
| 45 | + - getitem_to_source: Maps getitem nodes to (source_node, key) tuples |
| 46 | + - source_to_getitems: Maps source nodes to dictionaries of |
| 47 | + {key: getitem_node, "unclaimed": None} |
| 48 | + - source_to_unclaimed_size: Maps source nodes to their storage size |
| 49 | + unclaimed by any getitem_nodes |
| 50 | + """ |
| 51 | + |
| 52 | + def __init__(self, nodes: list[fx.Node]): |
| 53 | + """ |
| 54 | + Initialize the AliasInfo class with a list of FX graph nodes. |
| 55 | +
|
| 56 | + Args: |
| 57 | + - nodes (list[fx.Node]): A list of nodes from an FX graph, |
| 58 | + ordered in execution order. |
| 59 | +
|
| 60 | + The constructor analyzes the relationships between nodes in the FX graph |
| 61 | + to populate alias information. It identifies two types of alias nodes: |
| 62 | + getitem and view. For each view, it maps it to its source. For each |
| 63 | + getitem, it maps it to its source and key. It also populates mappings |
| 64 | + for source nodes to their getitems and calculates unclaimed storage sizes. |
| 65 | +
|
| 66 | + """ |
| 67 | + # For each view, we map it to its source. |
| 68 | + # Note that we treat getitems of a view (e.g. aten.split) as views. |
| 69 | + self.view_to_source: dict[fx.Node, fx.Node] = {} |
| 70 | + |
| 71 | + # For each remaining getitem, we map it to its source and key. |
| 72 | + self.getitem_to_source: dict[fx.Node, tuple[fx.Node, Any]] = {} |
| 73 | + |
| 74 | + # For each none-view source_node of getitems, we map it to a dictionary |
| 75 | + # in the form of {key: getitem_node, ..., "unclaimed": None}, where |
| 76 | + # "unclaimed" is a dummy key that represents all elements in the |
| 77 | + # source_node that is not claimed by any getitems. |
| 78 | + self.source_to_getitems: dict[fx.Node, dict[Any, fx.Node | None]] = {} |
| 79 | + |
| 80 | + # For each none-view source_node of getitems with at least one unclaimed |
| 81 | + # elements, we map it to its unclaimed storage size. |
| 82 | + self.source_to_unclaimed_size: dict[fx.Node, int] = {} |
| 83 | + |
| 84 | + for node in nodes: |
| 85 | + is_view = op_types.is_view(node) |
| 86 | + is_getitem = node.target is operator.getitem |
| 87 | + if not (is_view or is_getitem): |
| 88 | + continue |
| 89 | + assert not (is_view and is_getitem) |
| 90 | + assert node.args and isinstance(node.args[0], fx.Node) |
| 91 | + source = node.args[0] |
| 92 | + if is_view: |
| 93 | + assert not isinstance(source.meta["val"], list | tuple | dict) |
| 94 | + if source in self.view_to_source: |
| 95 | + source = self.view_to_source[source] |
| 96 | + self.view_to_source[node] = source |
| 97 | + if is_getitem: |
| 98 | + assert isinstance(source.meta["val"], list | tuple | dict) |
| 99 | + # Source of getitem can be a view (e.g. aten.split). |
| 100 | + if source in self.view_to_source: |
| 101 | + if source in self.view_to_source: |
| 102 | + source = self.view_to_source[source] |
| 103 | + # In this case, the getitem node should be treated |
| 104 | + # the same way as a regular view. |
| 105 | + self.view_to_source[node] = source |
| 106 | + continue |
| 107 | + # Source of getitem cannot be a getitem. |
| 108 | + assert source not in self.getitem_to_source |
| 109 | + |
| 110 | + # There must be a second argument that specifies the key. |
| 111 | + assert len(node.args) >= 2 |
| 112 | + key = node.args[1] |
| 113 | + self.getitem_to_source[node] = (source, key) |
| 114 | + |
| 115 | + # Populate source_to_getitems. |
| 116 | + if source not in self.source_to_getitems: |
| 117 | + self.source_to_getitems[source] = {"unclaimed": None} |
| 118 | + assert key not in self.source_to_getitems[source] |
| 119 | + self.source_to_getitems[source][key] = node # type: ignore[index] |
| 120 | + |
| 121 | + for source, getitem_map in self.source_to_getitems.items(): |
| 122 | + unclaimed_source_size = size_of(source) |
| 123 | + for key, getitem_node in getitem_map.items(): |
| 124 | + if key != "unclaimed" and getitem_node is not None: |
| 125 | + unclaimed_source_size -= size_of(getitem_node) |
| 126 | + assert unclaimed_source_size >= 0 |
| 127 | + if unclaimed_source_size > 0: |
| 128 | + self.source_to_unclaimed_size[source] = unclaimed_source_size |
| 129 | + |
| 130 | + def is_view(self, node: fx.Node) -> bool: |
| 131 | + return node in self.view_to_source |
| 132 | + |
| 133 | + def is_getitem(self, node: fx.Node) -> bool: |
| 134 | + return node in self.getitem_to_source |
| 135 | + |
| 136 | + def get_source(self, node: fx.Node) -> fx.Node | tuple[fx.Node, Any]: |
| 137 | + if self.is_view(node): |
| 138 | + return self.view_to_source[node] |
| 139 | + if self.is_getitem(node): |
| 140 | + return self.getitem_to_source[node] |
| 141 | + return node |
| 142 | + |
| 143 | + def is_source_of_getitems(self, node: fx.Node) -> bool: |
| 144 | + return node in self.source_to_getitems |
| 145 | + |
| 146 | + def get_storage_keys(self, source_node: fx.Node) -> list[Any]: |
| 147 | + assert source_node in self.source_to_getitems |
| 148 | + return list(self.source_to_getitems[source_node].keys()) |
| 149 | + |
| 150 | + def get_unclaimed_storage_size(self, source_node: fx.Node) -> int: |
| 151 | + return self.source_to_unclaimed_size.get(source_node, 0) |
| 152 | + |
| 153 | + def get_getitem_by_key(self, source: fx.Node, key: Any) -> fx.Node | None: |
| 154 | + assert source in self.source_to_getitems |
| 155 | + assert key in self.source_to_getitems[source] |
| 156 | + return self.source_to_getitems[source][key] |
| 157 | + |
| 158 | + def _get_last_usage( |
| 159 | + nodes: list[fx.Node], alias_info: AliasInfo |
| 160 | + ) -> dict[fx.Node, list[tuple[fx.Node, Any]]]: |
| 161 | + """ |
| 162 | + Determine the last usage point of each storage. This information is used to |
| 163 | + identify when storages can be safely released. |
| 164 | +
|
| 165 | + Args: |
| 166 | + - nodes (list[fx.Node]): A list of nodes from the FX graph, ordered |
| 167 | + in execution order. |
| 168 | + - alias_info (AliasInfo): An instance of AliasInfo containing aliasing |
| 169 | + relationships between nodes in the graph. |
| 170 | +
|
| 171 | + Returns: |
| 172 | + - Dict[fx.Node, list[tuple[fx.Node, Optional[Any]]]]: A mapping |
| 173 | + from each node to a list of storages (represented as tuples of source node |
| 174 | + and key) that are last used by that node. This helps in identifying which |
| 175 | + storages can be released after the node's execution. |
| 176 | +
|
| 177 | + """ |
| 178 | + storage_to_last_user: dict[tuple[fx.Node, Any], fx.Node] = {} |
| 179 | + node_to_last_used_storages: dict[fx.Node, list[tuple[fx.Node, Any]]] = {} |
| 180 | + |
| 181 | + def register_last_uses(use: fx.Node, user: fx.Node) -> None: |
| 182 | + keys: list[Any] = [] |
| 183 | + if alias_info.is_view(use): |
| 184 | + # When use is a view (or getitem of a view), |
| 185 | + # user is essentially using the storage allocated at the |
| 186 | + # creation of the source of use. |
| 187 | + use = alias_info.get_source(use) # type: ignore[assignment] |
| 188 | + |
| 189 | + if alias_info.is_source_of_getitems(use): # type: ignore[arg-type] |
| 190 | + # When use is a source of getitems, user is using all separate |
| 191 | + # storages of use. |
| 192 | + keys.extend(alias_info.get_storage_keys(use)) # type: ignore[arg-type] |
| 193 | + elif alias_info.is_getitem(use): # type: ignore[arg-type] |
| 194 | + # When use is a getitem, user is essentially using a separate |
| 195 | + # storage of the source of use specified by key. |
| 196 | + use, key = alias_info.get_source(use) # type: ignore[assignment,misc] |
| 197 | + keys.append(key) |
| 198 | + else: |
| 199 | + keys.append(None) |
| 200 | + |
| 201 | + assert keys |
| 202 | + |
| 203 | + for key in keys: |
| 204 | + if (use, key) not in storage_to_last_user: # type: ignore[comparison-overlap] |
| 205 | + storage_to_last_user[(use, key)] = user # type: ignore[index] |
| 206 | + node_to_last_used_storages.setdefault(user, []).append((use, key)) # type: ignore[arg-type] |
| 207 | + |
| 208 | + for node in reversed(nodes): |
| 209 | + fx.node.map_arg(node.args, lambda n: register_last_uses(n, node)) |
| 210 | + fx.node.map_arg(node.kwargs, lambda n: register_last_uses(n, node)) |
| 211 | + |
| 212 | + return node_to_last_used_storages |
| 213 | + |
| 214 | + alias_info = AliasInfo(nodes) |
| 215 | + node_to_last_used_storages = _get_last_usage(nodes, alias_info) |
| 216 | + |
| 217 | + # Initialize memory profile |
| 218 | + memory_profile = [0] |
| 219 | + |
| 220 | + # Process the graph |
| 221 | + for node in nodes: |
| 222 | + if node.op == "placeholder": |
| 223 | + out_mem = size_of(node) |
| 224 | + memory_profile[0] += out_mem |
| 225 | + elif node.op == "output": |
| 226 | + pass |
| 227 | + elif ( |
| 228 | + node.op == "call_function" |
| 229 | + or node.op == "call_module" |
| 230 | + or node.op == "call_method" |
| 231 | + ): |
| 232 | + # Aliases don't allocate new memory |
| 233 | + if alias_info.is_view(node) or alias_info.is_getitem(node): |
| 234 | + memory_profile.append(memory_profile[-1]) |
| 235 | + else: |
| 236 | + out_mem = size_of(node) |
| 237 | + memory_profile.append(memory_profile[-1] + out_mem) |
| 238 | + |
| 239 | + # Process storages that are no longer needed after this operation |
| 240 | + storages_to_release = [ |
| 241 | + (use, key) |
| 242 | + for use, key in node_to_last_used_storages.get(node, []) |
| 243 | + if is_releasable(use) |
| 244 | + ] |
| 245 | + freed_memory = 0 |
| 246 | + for node_to_release, key in storages_to_release: |
| 247 | + released_memory_size = 0 |
| 248 | + if key is None: |
| 249 | + released_memory_size = size_of(node_to_release) |
| 250 | + elif key == "unclaimed": |
| 251 | + released_memory_size = alias_info.get_unclaimed_storage_size( |
| 252 | + node_to_release |
| 253 | + ) |
| 254 | + else: |
| 255 | + getitem_node = alias_info.get_getitem_by_key(node_to_release, key) |
| 256 | + if getitem_node is not None: |
| 257 | + released_memory_size = size_of(getitem_node) |
| 258 | + freed_memory += released_memory_size |
| 259 | + |
| 260 | + assert freed_memory >= 0 |
| 261 | + memory_profile.append(memory_profile[-1] - freed_memory) |
| 262 | + return memory_profile |
| 263 | + |
| 264 | + |
| 265 | +def get_fwd_bwd_interactions( |
| 266 | + fwd_graph: fx.Graph, |
| 267 | + bwd_graph: fx.Graph, |
| 268 | + size_of: Callable[[fx.Node], int], |
| 269 | +) -> tuple[int, OrderedSet[str]]: |
| 270 | + """ |
| 271 | + Analyze the interactions between the forward (fwd) and backward (bwd) graphs |
| 272 | + to determine memory usage characteristics. |
| 273 | +
|
| 274 | + Args: |
| 275 | + - fwd_graph (fx.Graph): The forward graph representing the forward pass. |
| 276 | + - bwd_graph (fx.Graph): The backward graph representing the backward pass. |
| 277 | + - size_of (Callable[[fx.Node], int]): A function that returns the size |
| 278 | + of a given node. |
| 279 | +
|
| 280 | + Returns: |
| 281 | + - tuple[int, Set[fx.Node]]: A tuple containing: |
| 282 | + 1. The baseline memory usage during the backward pass, accounting for |
| 283 | + nodes that persist from the forward pass (i.e., in fwd output but |
| 284 | + not in bwd input). |
| 285 | + 2. A set of nodes whose storage cannot be released during the bwd pass. |
| 286 | + These include nodes that are views of primals or in bwd input |
| 287 | + but not in fwd output. |
| 288 | + """ |
| 289 | + |
| 290 | + def get_nodes_in_output(graph: fx.Graph) -> OrderedSet[fx.Node]: |
| 291 | + """ |
| 292 | + Get the nodes in the output of a graph. |
| 293 | +
|
| 294 | + Args: |
| 295 | + - graph (fx.Graph): The input graph. |
| 296 | +
|
| 297 | + Returns: |
| 298 | + - list[fx.Node]: A list of nodes in the output of the graph. |
| 299 | + """ |
| 300 | + output_node = list(graph.nodes)[-1] |
| 301 | + assert output_node.op == "output" |
| 302 | + nodes_in_output: OrderedSet[fx.Node] = OrderedSet() |
| 303 | + |
| 304 | + def add_node(node: fx.Node) -> None: |
| 305 | + nodes_in_output.add(node) |
| 306 | + |
| 307 | + # Using map_arg since output_node.args[0] can be of different types |
| 308 | + # e.g. tuple, list, dict, fx.Node, etc. |
| 309 | + fx.node.map_arg(output_node.args[0], lambda n: add_node(n)) |
| 310 | + return nodes_in_output |
| 311 | + |
| 312 | + op_types = get_default_op_list() |
| 313 | + |
| 314 | + bwd_baseline_memory = 0 |
| 315 | + # placeholder nodes besides primals of the bwd_graph that should also |
| 316 | + # not be deleted during memory profile estimation of the bwd_graph |
| 317 | + do_not_delete: OrderedSet[str] = OrderedSet() |
| 318 | + |
| 319 | + fwd_outputs = {} |
| 320 | + for node in get_nodes_in_output(fwd_graph): |
| 321 | + is_view_of_primal = False |
| 322 | + if op_types.is_view(node): |
| 323 | + source = node.args[0] |
| 324 | + if isinstance(source, fx.Node) and source.name.startswith("primals"): |
| 325 | + is_view_of_primal = True |
| 326 | + fwd_outputs[node.name] = (size_of(node), is_view_of_primal) |
| 327 | + bwd_inputs: OrderedSet[str] = OrderedSet() |
| 328 | + for node in bwd_graph.nodes: |
| 329 | + if node.op == "placeholder": |
| 330 | + bwd_inputs.add(node.name) |
| 331 | + if node.name.startswith("view"): |
| 332 | + # if node is a view, then it has to be in fwd_outputs |
| 333 | + assert node.name in fwd_outputs |
| 334 | + _, is_view_of_primal = fwd_outputs[node.name] |
| 335 | + if is_view_of_primal: |
| 336 | + # Add node to do_not_delete because it is a view of a primal |
| 337 | + do_not_delete.add(node.name) |
| 338 | + |
| 339 | + # if node is not in fwd_outputs, then add it to do_not_delete |
| 340 | + if node.name not in fwd_outputs: |
| 341 | + do_not_delete.add(node.name) |
| 342 | + |
| 343 | + # nodes that are in fwd_outputs but not in bwd_inputs take memory storage |
| 344 | + # throughout the bwd pass |
| 345 | + for name, (size, _) in fwd_outputs.items(): |
| 346 | + if name not in bwd_inputs: |
| 347 | + bwd_baseline_memory += size |
| 348 | + |
| 349 | + return bwd_baseline_memory, do_not_delete |
| 350 | + |
| 351 | + |
| 352 | +def get_peak_memory( |
| 353 | + fwd_graph: fx.Graph, |
| 354 | + bwd_graph: fx.Graph, |
| 355 | +) -> int: |
| 356 | + def _safe_size_of(n: fx.Node) -> int: |
| 357 | + try: |
| 358 | + return _size_of(n) |
| 359 | + except Exception: |
| 360 | + log.warning("Failed size_of(%s). Returning 0 instead.", n) |
| 361 | + return 0 |
| 362 | + |
| 363 | + def _is_releasable(n: fx.Node) -> bool: |
| 364 | + # Storages of primals cannot be released during fwd or bwd pass. |
| 365 | + return not n.name.startswith("primals") |
| 366 | + |
| 367 | + fwd_peak_memory = max( |
| 368 | + build_memory_profile(fwd_graph, _safe_size_of, _is_releasable) |
| 369 | + ) |
| 370 | + |
| 371 | + bwd_baseline_memory, bwd_do_not_delete = get_fwd_bwd_interactions( |
| 372 | + fwd_graph, bwd_graph, _safe_size_of |
| 373 | + ) |
| 374 | + |
| 375 | + def _is_bwd_releasable(n: fx.Node) -> bool: |
| 376 | + # Storages of nodes in bwd_do_not_delete cannot be released |
| 377 | + # during the bwd pass. |
| 378 | + return _is_releasable(n) and n.name not in bwd_do_not_delete |
| 379 | + |
| 380 | + bwd_peak_memory = bwd_baseline_memory + max( |
| 381 | + build_memory_profile(bwd_graph, _safe_size_of, _is_bwd_releasable) |
| 382 | + ) |
| 383 | + return max( |
| 384 | + fwd_peak_memory, |
| 385 | + bwd_peak_memory, |
| 386 | + ) |
0 commit comments