Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions test/test_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -1527,6 +1527,23 @@ def test_no_grad_modifies_version(self):
self.assertRaisesRegex(RuntimeError, 'modified by an inplace operation',
lambda: z.backward())

def test_no_grad_input(self):
class MyFunction(Function):
@staticmethod
def forward(self, x):
return x

@staticmethod
def backward(self, grad_output):
return grad_output

x = torch.randn(5, requires_grad=True)
with torch.no_grad():
y = MyFunction.apply(x)

self.assertTrue(x.requires_grad)
self.assertIsNone(y.grad_fn)

def test_backward_copy(self):
# This tests checks backward engine for a very subtle bug that appreared
# in one of the initial versions of autograd. Gradients tensors were
Expand Down
9 changes: 9 additions & 0 deletions test/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1905,6 +1905,15 @@ def test_broadcast_not_requiring_grad(self):
input_var = variables[output_idx % len(variables)]
self.assertEqual(input_var.requires_grad, broadcasted_var.requires_grad)

@unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
def test_broadcast_no_grad(self):
x = torch.randn(1, 2, dtype=torch.cuda.float32, requires_grad=True)
with torch.no_grad():
broadcasted = Broadcast.apply((0, 1), x)
self.assertTrue(x.requires_grad)
for output in broadcasted:
self.assertFalse(output.requires_grad)

@unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
def test_replicate(self):
module = nn.Linear(10, 5).float().cuda()
Expand Down
13 changes: 10 additions & 3 deletions torch/csrc/autograd/python_function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -373,15 +373,22 @@ static void _wrap_outputs(THPFunction *self,
auto set_history = [&](Variable& var, uint32_t output_nr, bool is_input, bool is_modified,
bool is_differentiable) {
if (!is_differentiable) {
if (!var.requires_grad()) return;
if (!var.requires_grad()) {
return;
}
// NB: we don't support returning non-differentiable views that could require grad
// (this could happen if someone were to return an input to the function).
if (var.is_view()) {
throw std::runtime_error("Returning Variables sharing storage with other Variables "
"that require grad is not supported in Python functions. "
"Please submit a feature request if you hit this error.");
}
var.detach_();
// Return detached aliases of inputs, instead of changing their requires_grad
// property.
if (is_input) {
var = var.detach();
} else {
var.detach_();
}
} else if (is_modified) {
if (var.is_leaf() && var.requires_grad()) {
throw std::runtime_error("a leaf Variable that requires grad has been used in an in-place operation.");
Expand Down