Skip to content

Commit 2c110b0

Browse files
author
Taylor Robie
committed
[Profiler] Memory profiler part 2: Config validation
Pull Request resolved: #86853. Memory profiling requires `record_shapes`, `profile_memory`, and `with_stack`. This PR just adds a skeleton endpoint with a good error message if certain flags are missing. ghstack-source-id: 172169616. Differential Revision: [D39920801](https://our.internmc.facebook.com/intern/diff/D39920801/)
1 parent 991052d commit 2c110b0

File tree

3 files changed

+55
-4
lines changed

3 files changed

+55
-4
lines changed

test/profiler/test_memory_profiler.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,28 @@
1212
torch.profiler.profile, record_shapes=True, profile_memory=True, with_stack=True
1313
)
1414

15+
@skipIfTorchDynamo("TorchDynamo removes profiler altogether.")
class TestMemoryProfiler(TestCase):
    """Validates the `_memory_profile` config-check entry point."""

    def test_config_check(self) -> None:
        # No flags set: every required flag should be named in the error.
        with torch.profiler.profile() as prof:
            pass
        expected = r"record_shapes=True, profile_memory=True, with_stack=True"
        with self.assertRaisesRegex(ValueError, expected):
            prof._memory_profile()

        # Only profile_memory missing: the message lists just that one flag.
        with torch.profiler.profile(record_shapes=True, with_stack=True) as prof:
            pass
        expected = r"^profile_memory=True required for memory profiling\.$"
        with self.assertRaisesRegex(ValueError, expected):
            prof._memory_profile()

        # Fully configured profiler (module-level `profile` partial sets all
        # three flags): a MemoryProfile instance is produced.
        with profile() as prof:
            pass
        self.assertIsInstance(prof._memory_profile(), _memory_profiler.MemoryProfile)
36+
1537

1638
class ScaleLayer(torch.nn.Module):
1739
def __init__(self) -> None:

torch/profiler/_memory_profiler.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,13 @@
22
from typing import Any, Iterator, Optional, Tuple
33

44
import torch
5-
from torch._C._profiler import _EventType, _ProfilerEvent, _TensorMetadata, RecordScope
5+
from torch._C._autograd import _ProfilerResult
6+
from torch._C._profiler import (
7+
_EventType,
8+
_ProfilerEvent,
9+
_TensorMetadata,
10+
RecordScope,
11+
)
612

713

814
@dataclasses.dataclass
@@ -112,3 +118,8 @@ def extract_gradients(
112118
p_grad_key = TensorKey.from_tensor(p_grad)
113119
if p_grad_key is not None:
114120
yield TensorKey.from_tensor(p), p_grad_key
121+
122+
123+
class MemoryProfile:
    """Skeleton memory-profile analysis object.

    Introduced as an endpoint only; the actual analysis of the profiler
    result is not implemented yet in this commit.
    """

    def __init__(self, result: _ProfilerResult) -> None:
        # The raw profiler result is accepted but not yet processed.
        pass

torch/profiler/profiler.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,19 @@
1616
_ExperimentalConfig,
1717
_remove_execution_graph_observer,
1818
)
19-
from torch.autograd import ProfilerActivity, kineto_available
19+
from torch.autograd import kineto_available, ProfilerActivity
20+
from torch.profiler import _memory_profiler
21+
22+
23+
# Public API of this module, consumed by `from torch.profiler.profiler import *`.
__all__ = [
    "supported_activities",
    "ProfilerAction",
    "schedule",
    "tensorboard_trace_handler",
    "profile",
    "ExecutionGraphObserver",
]
2031

21-
__all__ = ['supported_activities', 'ProfilerAction', 'schedule', 'tensorboard_trace_handler', 'profile',
22-
'ExecutionGraphObserver']
2332

2433
def supported_activities():
2534
"""
@@ -208,6 +217,15 @@ def _get_distributed_info(self):
208217
"world_size": dist.get_world_size()
209218
}
210219

220+
def _memory_profile(self) -> _memory_profiler.MemoryProfile:
221+
required = ("record_shapes", "profile_memory", "with_stack")
222+
missing = [f"{i}=True" for i in required if not getattr(self, i)]
223+
if missing:
224+
raise ValueError(f"{', '.join(missing)} required for memory profiling.")
225+
226+
assert self.profiler is not None and self.profiler.kineto_results is not None
227+
return _memory_profiler.MemoryProfile(self.profiler.kineto_results)
228+
211229

212230
class ProfilerAction(Enum):
213231
"""

0 commit comments

Comments (0)