Commit da87150

Author: Taylor Robie
[Profiler] Memory profiler part 1: Gradient identification
Pull Request resolved: #86802

There are multiple ways to identify that a Tensor is a gradient, a subset of which also provide additional context. So to start off I've made a utility to handle that determination.

ghstack-source-id: 170188444
Differential Revision: [D39920730](https://our.internmc.facebook.com/intern/diff/D39920730/)
1 parent 58d037c commit da87150

5 files changed (+355, -5 lines)
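As orientation for the diffs below, here is a minimal sketch of the usage pattern the new tests rely on when calling the `extract_gradients` utility. The `collect_gradient_ptrs` wrapper is illustrative and not part of the commit; it assumes a profile recorded with `record_shapes`, `profile_memory`, and `with_stack` enabled, as the tests do.

```python
import torch
from torch.profiler import _memory_profiler, _utils


def collect_gradient_ptrs(prof: torch.profiler.profile) -> set:
    """Illustrative helper: gather the storage pointers of every Tensor the
    profiler identifies as a gradient anywhere in the recorded event tree."""
    tree = prof.profiler.kineto_results.experimental_event_tree()
    grad_ptrs = set()
    for node in _utils.traverse_dfs(tree):
        # Each hit is a (parameter_key, gradient_key) pair; the parameter key
        # may be None when only op inspection identified the gradient.
        for _param_key, grad_key in _memory_profiler.extract_gradients(node):
            grad_ptrs.add(grad_key.storage_ptr)
    return grad_ptrs
```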

mypy-strict.ini

Lines changed: 1 addition & 0 deletions
@@ -40,6 +40,7 @@ files =
     .github,
     benchmarks/instruction_counts,
     tools,
+    torch/profiler/_memory_profiler.py,
     torch/utils/_pytree.py,
     torch/utils/benchmark/utils/common.py,
     torch/utils/benchmark/utils/timer.py,
Lines changed: 223 additions & 0 deletions (new test file)

# Owner(s): ["oncall: profiler"]
import functools
from typing import Iterator, Optional

import torch
from torch._C._profiler import _EventType
from torch.profiler import _memory_profiler, _utils
from torch.testing._internal.common_utils import run_tests, TestCase


profile = functools.partial(
    torch.profiler.profile, record_shapes=True, profile_memory=True, with_stack=True
)


class ScaleLayer(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.scale = torch.nn.Parameter(torch.rand(()), requires_grad=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.scale


class TestIdentifyGradients(TestCase):
    def gradient_detected(
        self,
        prof: torch.profiler.profile,
        ctx: _EventType,
        grad_tensor: torch.Tensor,
        parameter: Optional[torch.Tensor] = None,
    ) -> bool:

        # This is not an exhaustive check, but for the purpose of unit testing
        # it is sufficient.
        def key_matches_tensor(key, tensor) -> bool:
            # Vacuous case.
            if tensor is None:
                return True

            if key is None:
                return False

            return tensor.storage().data_ptr() == key.storage_ptr

        tree = prof.profiler.kineto_results.experimental_event_tree()
        for node in _utils.traverse_dfs(tree):
            for p_key, p_grad_key in _memory_profiler.extract_gradients(node):
                if node.tag == ctx and key_matches_tensor(p_grad_key, grad_tensor):
                    if parameter is None:
                        return True  # Don't need to check parameter; we're done.

                    elif p_key is not None:
                        # For a complex workflow a gradient could correspond to
                        # different parameters at different points in a trace.
                        # However this will not happen in the relatively simple
                        # cases tested here, so if `extract_gradients` identifies
                        # the parameter corresponding to a particular gradient it
                        # must be the one we expect.
                        self.assertTrue(key_matches_tensor(p_key, parameter))
                        return True

        return False

    def assertGradientDetected(self, name: str, *args, **kwargs) -> None:
        self.assertTrue(
            self.gradient_detected(*args, **kwargs),
            f"Failed to identify gradient `{name}` from profile.",
        )

    def assertOnlyGradients(
        self, prof: torch.profiler.profile, tensors: Iterator[torch.Tensor]
    ) -> None:
        allowed_set = {t.storage().data_ptr() for t in tensors}

        tree = prof.profiler.kineto_results.experimental_event_tree()
        for node in _utils.traverse_dfs(tree):
            for _, p_grad_key in _memory_profiler.extract_gradients(node):
                self.assertTrue(
                    p_grad_key.storage_ptr in allowed_set,
                    f"Tensor wrongly marked as gradient: {node.name}: {p_grad_key}",
                )

    def test_extract_gradients_low_level(self) -> None:
        x = torch.ones((1,))
        w0 = torch.ones((1,), requires_grad=True)
        w1 = torch.ones((1,), requires_grad=True)

        def check(cold_start: bool):
            self.assertEqual(w0.grad is None, cold_start)
            self.assertEqual(w1.grad is None, cold_start)
            with profile() as prof:
                z = x.expand(4) * w0
                (z * w1).sum().backward()

            # Gradient detection through op inspection does not provide a
            # reference to the parameter corresponding to the gradient.
            self.assertGradientDetected("w0", prof, _EventType.TorchOp, w0.grad)
            self.assertGradientDetected("w1", prof, _EventType.TorchOp, w1.grad)
            self.assertOnlyGradients(prof, (w0.grad, w1.grad))

        check(cold_start=True)
        check(cold_start=False)

    def test_extract_gradients_from_module(self) -> None:
        model = torch.nn.Sequential(torch.nn.Linear(2, 1), ScaleLayer())
        named_parameters = {name: p for name, p in model.named_parameters()}
        self.assertEqual(len(named_parameters), 3)

        def assert_only_gradients(prof: torch.profiler.profile):
            gradients = tuple(i.grad for i in named_parameters.values())
            self.assertFalse(any(i is None for i in gradients))
            self.assertOnlyGradients(prof, gradients)

        def check(cold_start: bool):
            x = torch.ones((2, 2))
            with profile() as prof:
                model(x).sum().backward()

            for name, p in named_parameters.items():
                # The first time we run a module none of the `.grad` fields
                # have been initialized. This is fine; in that case we can
                # detect everything we need in the profiled section.
                self.assertNotEqual(
                    self.gradient_detected(prof, _EventType.PyCall, p.grad, p),
                    cold_start,
                    name,
                )

                # Op based detection should still identify the gradients.
                self.assertGradientDetected(name, prof, _EventType.TorchOp, p.grad)
            assert_only_gradients(prof)

            # We can detect gradients even when `.backward()` is not called.
            with profile() as prof:
                model(torch.ones((2, 2)))

            for name, p in named_parameters.items():
                self.assertGradientDetected(name, prof, _EventType.PyCall, p.grad, p)
                self.assertFalse(
                    self.gradient_detected(prof, _EventType.TorchOp, p.grad), name
                )
            assert_only_gradients(prof)

        check(cold_start=True)
        check(cold_start=False)

    def _test_extract_gradients_from_optimizer(self, set_to_none: bool) -> None:

        x = torch.ones((1,))
        w0 = torch.ones((1,), requires_grad=True)
        w1 = torch.ones((1,), requires_grad=True)
        optimizer = torch.optim.SGD((w0, w1), lr=0.1, momentum=0.9)

        def check(cold_start: bool):
            self.assertEqual(w0.grad is None, cold_start)
            self.assertEqual(w1.grad is None, cold_start)
            with profile() as prof:
                optimizer.zero_grad(set_to_none=set_to_none)
                z = x.expand(4) * w0
                (z * w1).sum().backward()
                optimizer.step()

            # Optimizer instrumentation runs late in the step, so we can detect
            # gradients for both cold and warm start.
            self.assertGradientDetected("w0", prof, _EventType.PyCall, w0.grad, w0)
            self.assertGradientDetected("w1", prof, _EventType.PyCall, w1.grad, w1)

            self.assertGradientDetected("w0", prof, _EventType.TorchOp, w0.grad)
            self.assertGradientDetected("w1", prof, _EventType.TorchOp, w1.grad)
            self.assertOnlyGradients(prof, (w0.grad, w1.grad))

            with profile() as prof:
                for _ in range(2):
                    optimizer.zero_grad(set_to_none=set_to_none)
                    z = x.expand(4) * w0
                    (z * w1).sum().backward()
                    optimizer.step()

            # Inspected state is cached, so if we replace gradients (as is the
            # case for `set_to_none=True`) our python instrumentation will not
            # see them.
            # TODO(robieta): Should `.step()` be excluded from caching?
            self.assertNotEqual(
                self.gradient_detected(prof, _EventType.PyCall, w0.grad, w0),
                set_to_none,
            )

            self.assertNotEqual(
                self.gradient_detected(prof, _EventType.PyCall, w1.grad, w1),
                set_to_none,
            )

            if set_to_none:
                with self.assertRaisesRegex(AssertionError, "Tensor wrongly marked"):
                    self.assertOnlyGradients(prof, (w0.grad, w1.grad))

        check(cold_start=True)
        check(cold_start=False)

    def test_extract_gradients_from_optimizer(self) -> None:
        self._test_extract_gradients_from_optimizer(set_to_none=False)

    def test_extract_gradients_from_optimizer_set_to_none(self) -> None:
        self._test_extract_gradients_from_optimizer(set_to_none=True)

    def test_extract_gradients_from_module_and_optimizer(self) -> None:
        # Module and optimizer are thoroughly tested individually and should be
        # additive. Thus we can manage with a lightweight check that they don't
        # interact adversely.
        model = torch.nn.Sequential(torch.nn.Linear(2, 1), ScaleLayer())
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        with profile() as prof:
            model(torch.ones((2, 2))).sum().backward()
            optimizer.step()

        self.assertGradientDetected(
            "weight", prof, _EventType.PyCall, model[0].weight.grad, model[0].weight
        )


if __name__ == "__main__":
    run_tests()

torch/_C/_profiler.pyi

Lines changed: 44 additions & 5 deletions
@@ -3,6 +3,8 @@ from typing import List, Optional, Tuple, Union
 
 from torch._C import device, dtype, layout
 
+from typing_extensions import Literal
+
 # defined in torch/csrc/profiler/python/init.cpp
 
 class RecordScope(Enum):
@@ -38,11 +40,12 @@ class ProfilerActivity(Enum):
     CUDA = ...
 
 class _EventType(Enum):
-    Allocation = ...
+    TorchOp = ...
     Backend = ...
+    Allocation = ...
+    OutOfMemory = ...
     PyCall = ...
     PyCCall = ...
-    TorchOp = ...
     Kineto = ...
 
 class _ExperimentalConfig:
@@ -71,6 +74,8 @@ class _ProfilerEvent:
     start_tid: int
     start_time_ns: int
     children: List[_ProfilerEvent]
+
+    # TODO(robieta): remove in favor of `self.typed`
     extra_fields: Union[
         _ExtraFields_TorchOp,
         _ExtraFields_Backend,
@@ -81,6 +86,18 @@ class _ProfilerEvent:
         _ExtraFields_Kineto,
     ]
 
+    @property
+    def typed(
+        self,
+    ) -> Union[
+        Tuple[Literal[_EventType.TorchOp], _ExtraFields_TorchOp],
+        Tuple[Literal[_EventType.Backend], _ExtraFields_Backend],
+        Tuple[Literal[_EventType.Allocation], _ExtraFields_Allocation],
+        Tuple[Literal[_EventType.OutOfMemory], _ExtraFields_OutOfMemory],
+        Tuple[Literal[_EventType.PyCall], _ExtraFields_PyCall],
+        Tuple[Literal[_EventType.PyCCall], _ExtraFields_PyCCall],
+        Tuple[Literal[_EventType.Kineto], _ExtraFields_Kineto],
+    ]: ...
     @property
    def name(self) -> str: ...
     @property
@@ -154,10 +171,32 @@ class _NNModuleInfo:
     @property
     def cls_name(self) -> str: ...
 
+class _OptimizerInfo:
+    @property
+    def parameters(
+        self,
+    ) -> List[
+        Tuple[
+            # Parameter
+            _TensorMetadata,
+            #
+            # Gradient (if present during optimizer.step())
+            Optional[_TensorMetadata],
+            #
+            # Optimizer state for Parameter as (name, tensor) pairs
+            List[Tuple[str, _TensorMetadata]],
+        ]
+    ]: ...
+
 class _ExtraFields_PyCCall:
-    callsite: _PyFrameState
-    caller: _PyFrameState
-    module: Optional[_NNModuleInfo]
+    @property
+    def callsite(self) -> _PyFrameState: ...
+    @property
+    def caller(self) -> _PyFrameState: ...
+    @property
+    def module(self) -> Optional[_NNModuleInfo]: ...
+    @property
+    def optimizer(self) -> Optional[_OptimizerInfo]: ...
 
 class _ExtraFields_PyCall:
     caller: _PyFrameState

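The `typed` property pairs an `_EventType` literal with the matching extra-fields object, so downstream code (and type checkers) can branch on the tag instead of probing `extra_fields` directly. A minimal sketch of how it could be consumed, following the stub above; the `iter_optimizer_info` helper and the traversal are illustrative, not part of this commit:

```python
from torch._C._profiler import _EventType
from torch.profiler import _utils


def iter_optimizer_info(root_events):
    """Illustrative helper: yield (parameter, gradient) pairs recorded by
    optimizer instrumentation, using the (tag, extra_fields) pair from `typed`."""
    for event in _utils.traverse_dfs(root_events):
        tag, fields = event.typed  # tag is an _EventType; fields matches the tag
        # Per the stub above, optimizer info is exposed on the Python-call extra fields.
        if tag == _EventType.PyCCall and fields.optimizer is not None:
            # Each entry is (parameter, optional gradient, optimizer state).
            for param, grad, _state in fields.optimizer.parameters:
                yield param, grad
```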
torch/csrc/profiler/python/init.cpp

Lines changed: 5 additions & 0 deletions
@@ -222,6 +222,11 @@ void initPythonBindings(PyObject* module) {
       .def_property_readonly("name", &Result::name)
       .def_property_readonly("tag", &Result::tag)
       .def_readonly("extra_fields", &Result::extra_fields_)
+      .def_property_readonly(
+          "typed",
+          [](const Result& r) {
+            return py::make_tuple(r.tag(), r.extra_fields_);
+          })
       .def_property_readonly(
           "id",
           [](const Result& r) {
