
Commit d6815e1

mrshenli authored and facebook-github-bot committed
Only record grad_fn in C++ Scatter and Gather when required (#20286)

Summary: C++ `Scatter` and `Gather` always set autograd history for input data tensors, regardless of whether they require grad. This hits an assertion failure in `set_history(Tensor, shared_ptr<Function> grad_fn)`, where `grad_fn` cannot be nullptr. After this PR, C++ `Scatter` and `Gather` only record `grad_fn` when required.

Pull Request resolved: #20286
Differential Revision: D15266610
Pulled By: mrshenli
fbshipit-source-id: 641df0ea36e7c922b5820c8dc3f83e2a050412b5
1 parent 2179d5b commit d6815e1

File tree

1 file changed: +6 −2 lines


torch/csrc/autograd/functions/comm.cpp

Lines changed: 6 additions & 2 deletions
```diff
@@ -59,7 +59,9 @@ variable_list Scatter::apply(variable_list&& inputs) {
     }
   }
 
-  set_history(variables, grad_fn);
+  if (grad_fn) {
+    set_history(variables, grad_fn);
+  }
 
   return variables;
 }
@@ -120,7 +122,9 @@ variable_list Gather::apply(variable_list&& inputs) {
   const auto destination_index =
       destination_device_.is_cpu() ? -1 : destination_device_.index();
   auto variable = torch::cuda::gather(tensors, dim_, destination_index);
-  set_history(variable, grad_fn);
+  if (grad_fn) {
+    set_history(variable, grad_fn);
+  }
   return {variable};
 }
```
0 commit comments

Comments
 (0)