
Commit 8706ce7

Author: Taylor Robie (committed)

[Profiler] Memory profiler part 1: Gradient identification

Pull Request resolved: #86802

There are multiple ways to identify that a Tensor is a gradient (a subset of which also give additional context), so to start off I've made a utility to handle that determination.

ghstack-source-id: 170233186
Differential Revision: [D39920730](https://our.internmc.facebook.com/intern/diff/D39920730/)
1 parent 58d037c commit 8706ce7
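The tests added in this commit exercise the new utility roughly as follows. This is a distilled sketch based on the test code below, not part of the diff itself; it assumes a torch build that includes `_memory_profiler.extract_gradients` and `_utils.traverse_dfs` as introduced here.

# Sketch only: profile a backward pass, then walk the profiler event tree and
# ask the new utility which Tensors it identified as gradients.
import torch
from torch.profiler import _memory_profiler, _utils

w = torch.ones((1,), requires_grad=True)
with torch.profiler.profile(
    record_shapes=True, profile_memory=True, with_stack=True
) as prof:
    (torch.ones((1,)) * w).sum().backward()

tree = prof.profiler.kineto_results.experimental_event_tree()
for node in _utils.traverse_dfs(tree):
    for param_key, grad_key in _memory_profiler.extract_gradients(node):
        # TorchOp-based detection yields only the gradient; PyCall-based
        # detection (nn.Module / optimizer) can also yield the parameter key.
        print(node.tag, param_key, grad_key)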

File tree

5 files changed (+370, -8 lines)


mypy-strict.ini

Lines changed: 1 addition & 0 deletions
@@ -40,6 +40,7 @@ files =
     .github,
     benchmarks/instruction_counts,
     tools,
+    torch/profiler/_memory_profiler.py,
     torch/utils/_pytree.py,
     torch/utils/benchmark/utils/common.py,
     torch/utils/benchmark/utils/timer.py,
Lines changed: 224 additions & 0 deletions
@@ -0,0 +1,224 @@
# Owner(s): ["oncall: profiler"]
import functools
from typing import Iterator, Optional

import torch
from torch._C._profiler import _EventType
from torch.profiler import _memory_profiler, _utils
from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase


profile = functools.partial(
    torch.profiler.profile, record_shapes=True, profile_memory=True, with_stack=True
)


class ScaleLayer(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.scale = torch.nn.Parameter(torch.rand(()), requires_grad=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.scale


@skipIfTorchDynamo("TorchDynamo changes Python calls that memory profiling relies on.")
class TestIdentifyGradients(TestCase):
    def gradient_detected(
        self,
        prof: torch.profiler.profile,
        ctx: _EventType,
        grad_tensor: torch.Tensor,
        parameter: Optional[torch.Tensor] = None,
    ) -> None:

        # This is not an exhaustive check, but for the purpose of unit testing
        # it is sufficient.
        def key_matches_tensor(key, tensor) -> bool:
            # Vacuous case.
            if tensor is None:
                return True

            if key is None:
                return False

            return tensor.storage().data_ptr() == key.storage_ptr

        tree = prof.profiler.kineto_results.experimental_event_tree()
        for node in _utils.traverse_dfs(tree):
            for p_key, p_grad_key in _memory_profiler.extract_gradients(node):
                if node.tag == ctx and key_matches_tensor(p_grad_key, grad_tensor):
                    if parameter is None:
                        return True  # Don't need to check parameter; we're done.

                    elif p_key is not None:
                        # For a complex workflow a gradient could correspond to
                        # different parameters at different points in a trace.
                        # However this will not happen in the relatively simple
                        # cases tested here, so if `extract_gradients` identifies
                        # the parameter corresponding to a particular gradient it
                        # must be the one we expect.
                        self.assertTrue(key_matches_tensor(p_key, parameter))
                        return True

        return False

    def assertGradientDetected(self, name: str, *args, **kwargs) -> None:
        self.assertTrue(
            self.gradient_detected(*args, **kwargs),
            f"Failed to identify gradient `{name}` from profile.",
        )

    def assertOnlyGradients(
        self, prof: torch.profiler.profile, tensors: Iterator[torch.Tensor]
    ) -> None:
        allowed_set = {t.storage().data_ptr() for t in tensors}

        tree = prof.profiler.kineto_results.experimental_event_tree()
        for node in _utils.traverse_dfs(tree):
            for _, p_grad_key in _memory_profiler.extract_gradients(node):
                self.assertTrue(
                    p_grad_key.storage_ptr in allowed_set,
                    f"Tensor wrongly marked as gradient: {node.name}: {p_grad_key}",
                )

    def test_extract_gradients_low_level(self) -> None:
        x = torch.ones((1,))
        w0 = torch.ones((1,), requires_grad=True)
        w1 = torch.ones((1,), requires_grad=True)

        def check(cold_start: bool):
            self.assertEqual(w0.grad is None, cold_start)
            self.assertEqual(w1.grad is None, cold_start)
            with profile() as prof:
                z = x.expand(4) * w0
                (z * w1).sum().backward()

            # Gradient detection through op inspection does not provide a
            # reference to the parameter corresponding to the gradient.
            self.assertGradientDetected("w0", prof, _EventType.TorchOp, w0.grad)
            self.assertGradientDetected("w1", prof, _EventType.TorchOp, w1.grad)
            self.assertOnlyGradients(prof, (w0.grad, w1.grad))

        check(cold_start=True)
        check(cold_start=False)

    def test_extract_gradients_from_module(self) -> None:
        model = torch.nn.Sequential(torch.nn.Linear(2, 1), ScaleLayer())
        named_parameters = {name: p for name, p in model.named_parameters()}
        self.assertEqual(len(named_parameters), 3)

        def assert_only_gradients(prof: torch.profiler.profile):
            gradients = tuple(i.grad for i in named_parameters.values())
            self.assertFalse(any(i is None for i in gradients))
            self.assertOnlyGradients(prof, gradients)

        def check(cold_start: bool):
            x = torch.ones((2, 2))
            with profile() as prof:
                model(x).sum().backward()

            for name, p in named_parameters.items():
                # The first time we run a module none of the `.grad` fields
                # have been initialized. This is fine; in that case we can
                # detect everything we need in the profiled section.
                self.assertNotEqual(
                    self.gradient_detected(prof, _EventType.PyCall, p.grad, p),
                    cold_start,
                    name,
                )

                # Op based detection should still identify the gradients.
                self.assertGradientDetected(name, prof, _EventType.TorchOp, p.grad)
            assert_only_gradients(prof)

            # We can detect gradients even when `.backward()` is not called.
            with profile() as prof:
                model(torch.ones((2, 2)))

            for name, p in named_parameters.items():
                self.assertGradientDetected(name, prof, _EventType.PyCall, p.grad, p)
                self.assertFalse(
                    self.gradient_detected(prof, _EventType.TorchOp, p.grad), name
                )
            assert_only_gradients(prof)

        check(cold_start=True)
        check(cold_start=False)

    def _test_extract_gradients_from_optimizer(self, set_to_none: bool) -> None:

        x = torch.ones((1,))
        w0 = torch.ones((1,), requires_grad=True)
        w1 = torch.ones((1,), requires_grad=True)
        optimizer = torch.optim.SGD((w0, w1), lr=0.1, momentum=0.9)

        def check(cold_start: bool):
            self.assertEqual(w0.grad is None, cold_start)
            self.assertEqual(w1.grad is None, cold_start)
            with profile() as prof:
                optimizer.zero_grad(set_to_none=set_to_none)
                z = x.expand(4) * w0
                (z * w1).sum().backward()
                optimizer.step()

            # Optimizer instrumentation runs late in the step, so we can detect
            # gradients for both cold and warm start.
            self.assertGradientDetected("w0", prof, _EventType.PyCall, w0.grad, w0)
            self.assertGradientDetected("w1", prof, _EventType.PyCall, w1.grad, w1)

            self.assertGradientDetected("w0", prof, _EventType.TorchOp, w0.grad)
            self.assertGradientDetected("w1", prof, _EventType.TorchOp, w1.grad)
            self.assertOnlyGradients(prof, (w0.grad, w1.grad))

            with profile() as prof:
                for _ in range(2):
                    optimizer.zero_grad(set_to_none=set_to_none)
                    z = x.expand(4) * w0
                    (z * w1).sum().backward()
                    optimizer.step()

            # Inspected state is cached, so if we replace gradients (as is the
            # case for `set_to_none=True`) our python instrumentation will not
            # see them.
            # TODO(robieta): Should `.step()` be excluded from caching?
            self.assertNotEqual(
                self.gradient_detected(prof, _EventType.PyCall, w0.grad, w0),
                set_to_none,
            )

            self.assertNotEqual(
                self.gradient_detected(prof, _EventType.PyCall, w1.grad, w1),
                set_to_none,
            )

            if set_to_none:
                with self.assertRaisesRegex(AssertionError, "Tensor wrongly marked"):
                    self.assertOnlyGradients(prof, (w0.grad, w1.grad))

        check(cold_start=True)
        check(cold_start=False)

    def test_extract_gradients_from_optimizer(self) -> None:
        self._test_extract_gradients_from_optimizer(set_to_none=False)

    def test_extract_gradients_from_optimizer_set_to_none(self) -> None:
        self._test_extract_gradients_from_optimizer(set_to_none=True)

    def test_extract_gradients_from_module_and_optimizer(self) -> None:
        # Module and optimizer are thoroughly tested individually and should be
        # additive. Thus we can manage with a lightweight check that they don't
        # interact adversely.
        model = torch.nn.Sequential(torch.nn.Linear(2, 1), ScaleLayer())
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        with profile() as prof:
            model(torch.ones((2, 2))).sum().backward()
            optimizer.step()

        self.assertGradientDetected(
            "weight", prof, _EventType.PyCall, model[0].weight.grad, model[0].weight
        )


if __name__ == "__main__":
    run_tests()

torch/_C/_profiler.pyi

Lines changed: 50 additions & 8 deletions
@@ -3,6 +3,8 @@ from typing import List, Optional, Tuple, Union
 
 from torch._C import device, dtype, layout
 
+from typing_extensions import Literal
+
 # defined in torch/csrc/profiler/python/init.cpp
 
 class RecordScope(Enum):
@@ -38,11 +40,12 @@ class ProfilerActivity(Enum):
     CUDA = ...
 
 class _EventType(Enum):
-    Allocation = ...
+    TorchOp = ...
     Backend = ...
+    Allocation = ...
+    OutOfMemory = ...
     PyCall = ...
     PyCCall = ...
-    TorchOp = ...
     Kineto = ...
 
 class _ExperimentalConfig:
@@ -71,6 +74,8 @@ class _ProfilerEvent:
     start_tid: int
     start_time_ns: int
    children: List[_ProfilerEvent]
+
+    # TODO(robieta): remove in favor of `self.typed`
     extra_fields: Union[
         _ExtraFields_TorchOp,
         _ExtraFields_Backend,
@@ -81,6 +86,18 @@ class _ProfilerEvent:
         _ExtraFields_Kineto,
     ]
 
+    @property
+    def typed(
+        self,
+    ) -> Union[
+        Tuple[Literal[_EventType.TorchOp], _ExtraFields_TorchOp],
+        Tuple[Literal[_EventType.Backend], _ExtraFields_Backend],
+        Tuple[Literal[_EventType.Allocation], _ExtraFields_Allocation],
+        Tuple[Literal[_EventType.OutOfMemory], _ExtraFields_OutOfMemory],
+        Tuple[Literal[_EventType.PyCall], _ExtraFields_PyCall],
+        Tuple[Literal[_EventType.PyCCall], _ExtraFields_PyCCall],
+        Tuple[Literal[_EventType.Kineto], _ExtraFields_Kineto],
+    ]: ...
     @property
     def name(self) -> str: ...
     @property
@@ -145,21 +162,46 @@ class _PyFrameState:
     def file_name(self) -> str: ...
 
 class _NNModuleInfo:
-    @property
-    def params(self) -> List[Tuple[str, int]]: ...
     @property
     def self_ptr(self) -> int: ...
     @property
     def cls_ptr(self) -> int: ...
     @property
     def cls_name(self) -> str: ...
+    @property
+    def parameters(
+        self,
+    ) -> List[Tuple[str, _TensorMetadata, Optional[_TensorMetadata]]]: ...
+
+class _OptimizerInfo:
+    @property
+    def parameters(
+        self,
+    ) -> List[
+        Tuple[
+            # Parameter
+            _TensorMetadata,
+            #
+            # Gradient (if present during optimizer.step())
+            Optional[_TensorMetadata],
+            #
+            # Optimizer state for Parameter as (name, tensor) pairs
+            List[Tuple[str, _TensorMetadata]],
+        ]
+    ]: ...
 
 class _ExtraFields_PyCCall:
-    callsite: _PyFrameState
-    caller: _PyFrameState
-    module: Optional[_NNModuleInfo]
+    @property
+    def caller(self) -> _PyFrameState: ...
 
 class _ExtraFields_PyCall:
-    caller: _PyFrameState
+    @property
+    def callsite(self) -> _PyFrameState: ...
+    @property
+    def caller(self) -> _PyFrameState: ...
+    @property
+    def module(self) -> Optional[_NNModuleInfo]: ...
+    @property
+    def optimizer(self) -> Optional[_OptimizerInfo]: ...
 
 class _ExtraFields_Kineto: ...
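A hedged sketch of how a caller might consume the stubbed PyCall fields above (`module`, `optimizer`, and their `parameters` lists). The field layout follows the .pyi shown here; the helper name and the printing are illustrative only, not part of this commit.

from torch._C._profiler import _EventType, _ExtraFields_PyCall

def report_pycall_parameters(node) -> None:
    # Only Python-call events carry nn.Module / optimizer context.
    if node.tag != _EventType.PyCall:
        return
    fields = node.extra_fields
    assert isinstance(fields, _ExtraFields_PyCall)

    if fields.module is not None:
        # Each entry: (name, parameter metadata, gradient metadata or None)
        for name, param, grad in fields.module.parameters:
            print("module param:", name, param, grad)

    if fields.optimizer is not None:
        # Each entry: (parameter, gradient if present during step(),
        #              [(state name, tensor), ...])
        for param, grad, state in fields.optimizer.parameters:
            print("optimizer param:", param, grad, [name for name, _ in state])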

torch/csrc/profiler/python/init.cpp

Lines changed: 5 additions & 0 deletions
@@ -222,6 +222,11 @@ void initPythonBindings(PyObject* module) {
       .def_property_readonly("name", &Result::name)
       .def_property_readonly("tag", &Result::tag)
       .def_readonly("extra_fields", &Result::extra_fields_)
+      .def_property_readonly(
+          "typed",
+          [](const Result& r) {
+            return py::make_tuple(r.tag(), r.extra_fields_);
+          })
       .def_property_readonly(
           "id",
           [](const Result& r) {
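The `typed` binding above returns a `(tag, extra_fields)` pair, mirroring the `Union[Tuple[Literal[...], ...]]` annotation in the .pyi, so callers can unpack and branch on the event type in one step. A minimal usage sketch follows; the `describe` helper is hypothetical.

from torch._C._profiler import _EventType

def describe(event) -> str:
    tag, fields = event.typed  # tag tells us which extra-fields type `fields` is
    if tag == _EventType.TorchOp:
        return f"op: {event.name}"
    if tag == _EventType.PyCall:
        return f"python call in {fields.caller.file_name}"
    return event.name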
