Skip to content

Commit 00c8298

Browse files
basilwongpytorchmergebot
authored andcommitted
Log Full Knapsack Problem Information (#140757)
Summary: When AOT_PARTITIONER_DEBUG is set to 1 and debug logging is turned on we can now log the full input and output for each knapsack problem. Differential Revision: D65633086 Pull Request resolved: #140757 Approved by: https://github.com/jansel
1 parent 408ad45 commit 00c8298

File tree

1 file changed

+38
-0
lines changed

1 file changed

+38
-0
lines changed

torch/_functorch/partitioners.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1703,6 +1703,44 @@ def get_saved_values_knapsack(memory_budget, node_info, joint_graph):
17031703
node_info,
17041704
all_recomputable_banned_nodes,
17051705
)
1706+
if AOT_PARTITIONER_DEBUG:
1707+
max_runtime = max(
1708+
runtimes_banned_nodes
1709+
) # For normalizing runtimes in logs
1710+
input_summary = [
1711+
f"\n\t\t\t{index}, {memory}, {runtime / max_runtime}, {node.op}, {node.target}, {node.meta}, {node.args}"
1712+
for index, (memory, runtime, node) in enumerate(
1713+
zip(
1714+
memories_banned_nodes,
1715+
runtimes_banned_nodes,
1716+
all_recomputable_banned_nodes,
1717+
)
1718+
)
1719+
]
1720+
joint_graph_nodes = [node.name for node in joint_graph.nodes]
1721+
joint_graph_edges = [
1722+
(inp.name, node.name)
1723+
for node in joint_graph.nodes
1724+
for inp in node.all_input_nodes
1725+
]
1726+
knapsack_summary = f"""
1727+
Activation Checkpointing - Knapsack Problem Summary:
1728+
Input:
1729+
Solver: {config.activation_memory_budget_solver}
1730+
Max Memory: {max(config.activation_memory_budget, 0)}
1731+
Graph Nodes: {joint_graph_nodes}
1732+
Graph Edges: {joint_graph_edges}
1733+
(Index, Memory, Runtime, Node.Op, Node.Target, Metadata): {"".join(input_summary)}
1734+
Output:
1735+
Expected Runtime: {expected_runtime}
1736+
Saved Nodes: {saved_node_idxs}
1737+
Recomputable Nodes: {recomputable_node_idxs}
1738+
"""
1739+
torch._logging.trace_structured(
1740+
name="artifact",
1741+
payload_fn=lambda: knapsack_summary,
1742+
)
1743+
log.info(knapsack_summary)
17061744
dont_ban = set()
17071745
for idx in recomputable_node_idxs:
17081746
# if idx in all_recomputable_banned_nodes:

0 commit comments

Comments
 (0)