Commit 1db0f73

Taylor Robie authored and pytorchmergebot committed
[Profiler] Account for caching when assigning IDs (#88917)
The Python tracer caches information about module and optimizer state. That means that for subsequent calls, the presence of a Tensor in these fields does not imply that the Tensor is still live, only that it was live during the first call. (I should perhaps rename the fields to something like `stale_parameters` to convey this.) Unless we discard subsequent calls, ID assignment gets tripped up when it sees a Tensor that was already released.

Differential Revision: [D41226827](https://our.internmc.facebook.com/intern/diff/D41226827/)
Pull Request resolved: #88917
Approved by: https://github.com/chaekit
1 parent ee44123 commit 1db0f73

1 file changed

+8
-2
lines changed

torch/csrc/profiler/data_flow.cpp

Lines changed: 8 additions & 2 deletions
@@ -69,6 +69,10 @@ void calculateUniqueTensorIDs(
   // --------------------------------------------------------------------------
   {
     RawTensors raw_tensors;
+
+    // The python tracer caches values, so it's only safe to use the first case.
+    ska::flat_hash_set<PyModuleSelf> seen_modules;
+    ska::flat_hash_set<PyOptimizerSelf> seen_optimizers;
     for (auto& result : sorted_results) {
       result->visit(c10::overloaded(
           [&](ExtraFields<EventType::TorchOp>& torch_op) {
@@ -78,15 +82,17 @@ void calculateUniqueTensorIDs(
           },
           [&](ExtraFields<EventType::PyCall>& py_call) {
             // torch.nn.Module
-            if (py_call.module_.has_value()) {
+            if (py_call.module_.has_value() &&
+                seen_modules.insert(py_call.module_->self_).second) {
               for (auto& p : py_call.module_->parameters_) {
                 raw_tensors(p.metadata_);
                 raw_tensors(p.grad_metadata_);
               }
             }
 
             // torch.optim.Optimizer
-            if (py_call.optimizer_.has_value()) {
+            if (py_call.optimizer_.has_value() &&
+                seen_optimizers.insert(py_call.optimizer_->self_).second) {
               for (auto& p : py_call.optimizer_->parameters_) {
                 raw_tensors(p.metadata_);
                 raw_tensors(p.grad_metadata_);
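
The guard added in this change relies on the set's `insert` returning a pair whose `.second` is true only when the key was not already present, so each `self` is processed on its first call and skipped afterwards. The following is a minimal standalone sketch of that idiom, not the profiler's code: `CachedCall` and `collectFirstSeen` are hypothetical names invented for illustration, and `std::unordered_set` is used as a stand-in on the assumption that `ska::flat_hash_set::insert` has the same return convention.

```cpp
#include <cstdint>
#include <unordered_set>
#include <vector>

// Hypothetical stand-in for a traced Python call: the identity of the
// module/optimizer (`self`) plus the parameter snapshot that was cached
// the first time that object was seen.
struct CachedCall {
  std::uintptr_t self;
  std::vector<int> cached_params;
};

// Collect parameters only from the first call observed for each `self`.
// Later calls carry the same cached snapshot, which may describe objects
// that no longer exist, so they are skipped.
std::vector<int> collectFirstSeen(const std::vector<CachedCall>& calls) {
  std::unordered_set<std::uintptr_t> seen;
  std::vector<int> out;
  for (const auto& call : calls) {
    // insert() returns {iterator, bool}; the bool is true only when the
    // key was newly inserted, i.e. this is the first call for `self`.
    if (seen.insert(call.self).second) {
      out.insert(
          out.end(), call.cached_params.begin(), call.cached_params.end());
    }
  }
  return out;
}
```

Calling `collectFirstSeen` on a sequence in which the same `self` appears several times yields each snapshot once, in order of first appearance. Folding the membership test and the insertion into a single `insert` call avoids a separate lookup before each insertion.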
