
Commit cef13eb

Taylor Robie authored and pytorchmergebot committed
[Profiler] Memory profiler part 1: Gradient identification (#86802)
There are multiple ways to identify that a Tensor is a gradient, a subset of which also provide additional context. To start off, I've made a utility to handle that determination.

Differential Revision: [D39920730](https://our.internmc.facebook.com/intern/diff/D39920730/)
Pull Request resolved: #86802
Approved by: https://github.com/chaekit
1 parent c0e6b43 commit cef13eb
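
The helper lives in the private torch/profiler/_memory_profiler.py module added here. Below is a minimal sketch of how it is exercised, modeled on the new tests in this diff; the traversal helpers and the (parameter_key, gradient_key) return shape are taken from the test code, not from any documented public API.

import torch
from torch.profiler import _memory_profiler, _utils

model = torch.nn.Linear(2, 1)
with torch.profiler.profile(
    record_shapes=True, profile_memory=True, with_stack=True
) as prof:
    model(torch.ones((2, 2))).sum().backward()

# Walk the experimental event tree and ask the helper which tensors it
# recognizes as gradients. Each hit is a (parameter_key, gradient_key)
# pair; the parameter key may be None when only op-level inspection is
# available (e.g. TorchOp events).
tree = prof.profiler.kineto_results.experimental_event_tree()
for node in _utils.traverse_dfs(tree):
    for param_key, grad_key in _memory_profiler.extract_gradients(node):
        print(node.name, node.tag, grad_key.storage.ptr)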

File tree: 5 files changed, +400 -9 lines changed

mypy-strict.ini

Lines changed: 1 addition & 0 deletions
@@ -40,6 +40,7 @@ files =
     .github,
     benchmarks/instruction_counts,
     tools,
+    torch/profiler/_memory_profiler.py,
     torch/utils/_pytree.py,
     torch/utils/benchmark/utils/common.py,
     torch/utils/benchmark/utils/timer.py,
Lines changed: 224 additions & 0 deletions
@@ -0,0 +1,224 @@ (new file, all additions)

# Owner(s): ["oncall: profiler"]
import functools
from typing import Iterator, Optional

import torch
from torch._C._profiler import _EventType
from torch.profiler import _memory_profiler, _utils
from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase


profile = functools.partial(
    torch.profiler.profile, record_shapes=True, profile_memory=True, with_stack=True
)


class ScaleLayer(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.scale = torch.nn.Parameter(torch.rand(()), requires_grad=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.scale


@skipIfTorchDynamo("TorchDynamo changes Python calls that memory profiling relies on.")
class TestIdentifyGradients(TestCase):
    def gradient_detected(
        self,
        prof: torch.profiler.profile,
        ctx: _EventType,
        grad_tensor: torch.Tensor,
        parameter: Optional[torch.Tensor] = None,
    ) -> None:

        # This is not an exhaustive check, but for the purpose of unit testing
        # it is sufficient.
        def key_matches_tensor(key, tensor) -> bool:
            # Vacuous case.
            if tensor is None:
                return True

            if key is None:
                return False

            return tensor.storage().data_ptr() == key.storage.ptr

        tree = prof.profiler.kineto_results.experimental_event_tree()
        for node in _utils.traverse_dfs(tree):
            for p_key, p_grad_key in _memory_profiler.extract_gradients(node):
                if node.tag == ctx and key_matches_tensor(p_grad_key, grad_tensor):
                    if parameter is None:
                        return True  # Don't need to check parameter; we're done.

                    elif p_key is not None:
                        # For a complex workflow a gradient could correspond to
                        # different parameters at different points in a trace.
                        # However this will not happen in the relatively simple
                        # cases tested here, so if `extract_gradients` identifies
                        # the parameter corresponding to a particular gradient it
                        # must be the one we expect.
                        self.assertTrue(key_matches_tensor(p_key, parameter))
                        return True

        return False

    def assertGradientDetected(self, name: str, *args, **kwargs) -> None:
        self.assertTrue(
            self.gradient_detected(*args, **kwargs),
            f"Failed to identify gradient `{name}` from profile.",
        )

    def assertOnlyGradients(
        self, prof: torch.profiler.profile, tensors: Iterator[torch.Tensor]
    ) -> None:
        allowed_set = {t.storage().data_ptr() for t in tensors}

        tree = prof.profiler.kineto_results.experimental_event_tree()
        for node in _utils.traverse_dfs(tree):
            for _, p_grad_key in _memory_profiler.extract_gradients(node):
                self.assertTrue(
                    p_grad_key.storage.ptr in allowed_set,
                    f"Tensor wrongly marked as gradient: {node.name}: {p_grad_key}",
                )

    def test_extract_gradients_low_level(self) -> None:
        x = torch.ones((1,))
        w0 = torch.ones((1,), requires_grad=True)
        w1 = torch.ones((1,), requires_grad=True)

        def check(cold_start: bool):
            self.assertEqual(w0.grad is None, cold_start)
            self.assertEqual(w1.grad is None, cold_start)
            with profile() as prof:
                z = x.expand(4) * w0
                (z * w1).sum().backward()

            # Gradient detection through op inspection does not provide a
            # reference to the parameter corresponding to the gradient.
            self.assertGradientDetected("w0", prof, _EventType.TorchOp, w0.grad)
            self.assertGradientDetected("w1", prof, _EventType.TorchOp, w1.grad)
            self.assertOnlyGradients(prof, (w0.grad, w1.grad))

        check(cold_start=True)
        check(cold_start=False)

    def test_extract_gradients_from_module(self) -> None:
        model = torch.nn.Sequential(torch.nn.Linear(2, 1), ScaleLayer())
        named_parameters = {name: p for name, p in model.named_parameters()}
        self.assertEqual(len(named_parameters), 3)

        def assert_only_gradients(prof: torch.profiler.profile):
            gradients = tuple(i.grad for i in named_parameters.values())
            self.assertFalse(any(i is None for i in gradients))
            self.assertOnlyGradients(prof, gradients)

        def check(cold_start: bool):
            x = torch.ones((2, 2))
            with profile() as prof:
                model(x).sum().backward()

            for name, p in named_parameters.items():
                # The first time we run a module none of the `.grad` fields
                # have been initialized. This is fine; in that case we can
                # detect everything we need in the profiled section.
                self.assertNotEqual(
                    self.gradient_detected(prof, _EventType.PyCall, p.grad, p),
                    cold_start,
                    name,
                )

                # Op based detection should still identify the gradients.
                self.assertGradientDetected(name, prof, _EventType.TorchOp, p.grad)
            assert_only_gradients(prof)

            # We can detect gradients even when `.backward()` is not called.
            with profile() as prof:
                model(torch.ones((2, 2)))

            for name, p in named_parameters.items():
                self.assertGradientDetected(name, prof, _EventType.PyCall, p.grad, p)
                self.assertFalse(
                    self.gradient_detected(prof, _EventType.TorchOp, p.grad), name
                )
            assert_only_gradients(prof)

        check(cold_start=True)
        check(cold_start=False)

    def _test_extract_gradients_from_optimizer(self, set_to_none: bool) -> None:

        x = torch.ones((1,))
        w0 = torch.ones((1,), requires_grad=True)
        w1 = torch.ones((1,), requires_grad=True)
        optimizer = torch.optim.SGD((w0, w1), lr=0.1, momentum=0.9)

        def check(cold_start: bool):
            self.assertEqual(w0.grad is None, cold_start)
            self.assertEqual(w1.grad is None, cold_start)
            with profile() as prof:
                optimizer.zero_grad(set_to_none=set_to_none)
                z = x.expand(4) * w0
                (z * w1).sum().backward()
                optimizer.step()

            # Optimizer instrumentation runs late in the step, so we can detect
            # gradients for both cold and warm start.
            self.assertGradientDetected("w0", prof, _EventType.PyCall, w0.grad, w0)
            self.assertGradientDetected("w1", prof, _EventType.PyCall, w1.grad, w1)

            self.assertGradientDetected("w0", prof, _EventType.TorchOp, w0.grad)
            self.assertGradientDetected("w1", prof, _EventType.TorchOp, w1.grad)
            self.assertOnlyGradients(prof, (w0.grad, w1.grad))

            with profile() as prof:
                for _ in range(2):
                    optimizer.zero_grad(set_to_none=set_to_none)
                    z = x.expand(4) * w0
                    (z * w1).sum().backward()
                    optimizer.step()

            # Inspected state is cached, so if we replace gradients (as is the
            # case for `set_to_none=True`) our python instrumentation will not
            # see them.
            # TODO(robieta): Should `.step()` be excluded from caching?
            self.assertNotEqual(
                self.gradient_detected(prof, _EventType.PyCall, w0.grad, w0),
                set_to_none,
            )

            self.assertNotEqual(
                self.gradient_detected(prof, _EventType.PyCall, w1.grad, w1),
                set_to_none,
            )

            if set_to_none:
                with self.assertRaisesRegex(AssertionError, "Tensor wrongly marked"):
                    self.assertOnlyGradients(prof, (w0.grad, w1.grad))

        check(cold_start=True)
        check(cold_start=False)

    def test_extract_gradients_from_optimizer(self) -> None:
        self._test_extract_gradients_from_optimizer(set_to_none=False)

    def test_extract_gradients_from_optimizer_set_to_none(self) -> None:
        self._test_extract_gradients_from_optimizer(set_to_none=True)

    def test_extract_gradients_from_module_and_optimizer(self) -> None:
        # Module and optimizer are thoroughly tested individually and should be
        # additive. Thus we can manage with a lightweight check that they don't
        # interact adversely.
        model = torch.nn.Sequential(torch.nn.Linear(2, 1), ScaleLayer())
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        with profile() as prof:
            model(torch.ones((2, 2))).sum().backward()
            optimizer.step()

        self.assertGradientDetected(
            "weight", prof, _EventType.PyCall, model[0].weight.grad, model[0].weight
        )


if __name__ == "__main__":
    run_tests()

torch/_C/_profiler.pyi

Lines changed: 54 additions & 9 deletions
@@ -3,6 +3,8 @@ from typing import List, Optional, Tuple, Union
 
 from torch._C import device, dtype, layout
 
+from typing_extensions import Literal
+
 # defined in torch/csrc/profiler/python/init.cpp
 
 class RecordScope(Enum):
@@ -38,11 +40,12 @@ class ProfilerActivity(Enum):
     CUDA = ...
 
 class _EventType(Enum):
-    Allocation = ...
+    TorchOp = ...
     Backend = ...
+    Allocation = ...
+    OutOfMemory = ...
     PyCall = ...
     PyCCall = ...
-    TorchOp = ...
     Kineto = ...
 
 class _ExperimentalConfig:
@@ -71,6 +74,8 @@ class _ProfilerEvent:
     start_tid: int
     start_time_ns: int
     children: List[_ProfilerEvent]
+
+    # TODO(robieta): remove in favor of `self.typed`
     extra_fields: Union[
         _ExtraFields_TorchOp,
         _ExtraFields_Backend,
@@ -81,6 +86,18 @@ class _ProfilerEvent:
         _ExtraFields_Kineto,
     ]
 
+    @property
+    def typed(
+        self,
+    ) -> Union[
+        Tuple[Literal[_EventType.TorchOp], _ExtraFields_TorchOp],
+        Tuple[Literal[_EventType.Backend], _ExtraFields_Backend],
+        Tuple[Literal[_EventType.Allocation], _ExtraFields_Allocation],
+        Tuple[Literal[_EventType.OutOfMemory], _ExtraFields_OutOfMemory],
+        Tuple[Literal[_EventType.PyCall], _ExtraFields_PyCall],
+        Tuple[Literal[_EventType.PyCCall], _ExtraFields_PyCCall],
+        Tuple[Literal[_EventType.Kineto], _ExtraFields_Kineto],
+    ]: ...
     @property
     def name(self) -> str: ...
     @property
@@ -101,6 +118,8 @@ class _TensorMetadata:
     storage_data_ptr: Optional[int]
     id: Optional[int]
 
+    @property
+    def allocation_id(self) -> Optional[int]: ...
     @property
     def layout(self) -> layout: ...
     @property
@@ -129,11 +148,12 @@ class _ExtraFields_Backend: ...
 class _ExtraFields_Allocation:
     ptr: int
     id: Optional[int]
-    allocation_id: Optional[int]
     alloc_size: int
     total_allocated: int
     total_reserved: int
 
+    @property
+    def allocation_id(self) -> Optional[int]: ...
     @property
     def device(self) -> device: ...
 
@@ -147,22 +167,47 @@ class _PyFrameState:
     def file_name(self) -> str: ...
 
 class _NNModuleInfo:
-    @property
-    def params(self) -> List[Tuple[str, int]]: ...
     @property
     def self_ptr(self) -> int: ...
     @property
    def cls_ptr(self) -> int: ...
     @property
     def cls_name(self) -> str: ...
+    @property
+    def parameters(
+        self,
+    ) -> List[Tuple[str, _TensorMetadata, Optional[_TensorMetadata]]]: ...
+
+class _OptimizerInfo:
+    @property
+    def parameters(
+        self,
+    ) -> List[
+        Tuple[
+            # Parameter
+            _TensorMetadata,
+            #
+            # Gradient (if present during optimizer.step())
+            Optional[_TensorMetadata],
+            #
+            # Optimizer state for Parameter as (name, tensor) pairs
+            List[Tuple[str, _TensorMetadata]],
+        ]
+    ]: ...
 
 class _ExtraFields_PyCCall:
-    callsite: _PyFrameState
-    caller: _PyFrameState
-    module: Optional[_NNModuleInfo]
+    @property
+    def caller(self) -> _PyFrameState: ...
 
 class _ExtraFields_PyCall:
-    caller: _PyFrameState
+    @property
+    def callsite(self) -> _PyFrameState: ...
+    @property
+    def caller(self) -> _PyFrameState: ...
+    @property
+    def module(self) -> Optional[_NNModuleInfo]: ...
+    @property
+    def optimizer(self) -> Optional[_OptimizerInfo]: ...
 
 class _ExtraFields_Kineto: ...
torch/csrc/profiler/python/init.cpp

Lines changed: 7 additions & 0 deletions
@@ -251,6 +251,13 @@ void initPythonBindings(PyObject* module) {
       .def_property_readonly("name", &Result::name)
       .def_property_readonly("tag", &Result::tag)
       .def_readonly("extra_fields", &Result::extra_fields_)
+      .def_property_readonly(
+          "typed",
+          [](const Result& r) {
+            return py::make_tuple(
+                r.tag(),
+                py::cast(r.extra_fields_, py::return_value_policy::reference));
+          })
       .def_property_readonly(
           "id",
           [](const Result& r) {