Skip to content

Commit 743c385

Browse files
zhaojuanmaopytorchmergebot
authored andcommitted
refactor show_traces in memory_tracker (#90145)
refactor show_tracers in memory_tracker to make it plot multiple figures and also can load serialized stats and then plot figures Pull Request resolved: #90145 Approved by: https://github.com/rohan-varma
1 parent b6bb726 commit 743c385

File tree

1 file changed

+48
-29
lines changed

1 file changed

+48
-29
lines changed

torch/distributed/_tools/memory_tracker.py

Lines changed: 48 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -146,36 +146,55 @@ def summary(self, top: int = 20) -> None:
146146
print("------------------------------------------------")
147147

148148
@no_type_check
149-
def show_traces(self) -> None:
150-
"""
151-
Show the traces of ``memory_allocated``, ``memory_active`` and ``memory_reserved`` at
152-
operator level and the marker 'fw_bw_boundary' at the boundary of forward pass
153-
and backward pass.
154-
"""
149+
def show_traces(self, path: str = "") -> None:
150+
from itertools import chain
151+
155152
import matplotlib.pyplot as plt
156153

157-
y_1 = [mb for (name, mb) in self.memories_allocated.values()]
158-
y_2 = [mb for (name, mb) in self.memories_active.values()]
159-
y_3 = [mb for (name, mb) in self.memories_reserved.values()]
160-
min_val = min(y_1 + y_2 + y_3)
161-
max_val = max(y_1 + y_2 + y_3)
162-
x = list(i for i in range(len(y_1)))
163-
fig = plt.figure(figsize=(16, 8))
164-
plt.plot(x, list(y_1), label="memory_allocated")
165-
plt.plot(x, list(y_2), label="memory_active")
166-
plt.plot(x, list(y_3), label="memory_reserved")
167-
plt.xlabel("# Operator Calls")
168-
plt.ylabel("Memory (MB)")
169-
for marker_name, marker in self._markers.items():
170-
if marker_name == "fw_bw_boundary":
171-
plt.plot(
172-
[marker, marker], [min_val, max_val], "r", lw=2, label=marker_name
173-
)
174-
else:
175-
plt.plot(
176-
[marker, marker], [min_val, max_val], "k-", lw=2, label=marker_name
177-
)
178-
plt.legend()
154+
def _plot_figure(x, y_values, labels):
155+
min_val = min(list(chain(*y_values))) * 0.999
156+
max_val = max(list(chain(*y_values))) * 1.001
157+
plt.figure()
158+
for y, label in zip(y_values, labels):
159+
plt.plot(x, y, label=label)
160+
plt.xlabel("# Operator Calls")
161+
plt.ylabel("Memory (MB)")
162+
plt.legend()
163+
for marker_name, marker in self._markers.items():
164+
if marker_name == "fw_bw_boundary":
165+
plt.plot(
166+
[marker, marker],
167+
[min_val, max_val],
168+
"r",
169+
lw=2,
170+
label=marker_name,
171+
)
172+
else:
173+
plt.plot(
174+
[marker, marker],
175+
[min_val, max_val],
176+
"k-",
177+
lw=2,
178+
label=marker_name,
179+
)
180+
181+
if path != "":
182+
self.load(path)
183+
184+
y_1 = [gb for (name, gb) in self.memories_allocated.values()]
185+
y_2 = [gb for (name, gb) in self.memories_active.values()]
186+
y_3 = [gb for (name, gb) in self.memories_reserved.values()]
187+
x = list(range(len(y_1)))
188+
# Split figures when there is big difference between
189+
# "reserved_memory" and "allocated_memory" or "active_memory".
190+
_plot_figure(
191+
x,
192+
[list(y_1), list(y_2), list(y_3)],
193+
["allocated_memory", "active_memory", "reserved_memory"],
194+
)
195+
_plot_figure(x, [list(y_1)], ["allocated_memory"])
196+
_plot_figure(x, [list(y_2)], ["active_memory"])
197+
_plot_figure(x, [list(y_3)], ["reserved_memory"])
179198

180199
def save_stats(self, path: str) -> None:
181200
"""
@@ -190,7 +209,7 @@ def save_stats(self, path: str) -> None:
190209
}
191210

192211
with open(path, "wb") as f:
193-
pickle.dump(stats, f)
212+
pickle.dump(stats, f, pickle.HIGHEST_PROTOCOL)
194213

195214
def load(self, path: str) -> None:
196215
"""

0 commit comments

Comments
 (0)