
Commit 7abdc30

Don't allow requires_grad to be set on integer Tensor constructors in tensor_new (#7185)
* Don't allow requires_grad to be set on integer Tensor constructors in tensor_new.
* Fix autograd test.
* Fix test_distributions.
* Fix test_jit.
* Fix NN tests.
1 parent 431c80a commit 7abdc30
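For context, a minimal sketch of the behavior this change enforces (the error text comes from the relocated set_requires_grad check; the snippet itself is illustrative and not part of the diff):

import torch

# Integer-dtype tensor constructors now reject requires_grad=True:
x = torch.tensor([1], requires_grad=True)
# -> RuntimeError: only Tensors of floating point dtype can require gradients

# Floating-point constructors are unaffected:
x = torch.tensor([1.], requires_grad=True)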

8 files changed: +34 −28 lines

test/test_autograd.py

Lines changed: 1 addition & 1 deletion
@@ -1774,7 +1774,7 @@ def backward(self, grad_output):
         self.assertEqual(x.grad.data, torch.ones(x.size()))
 
     def test_set_grad_enabled(self):
-        x = torch.tensor([1], requires_grad=True)
+        x = torch.tensor([1.], requires_grad=True)
         with torch.set_grad_enabled(False):
             y = x * 2
         self.assertFalse(y.requires_grad)

test/test_distributions.py

Lines changed: 3 additions & 1 deletion
@@ -1877,7 +1877,9 @@ def test_cdf_log_prob(self):
         for Dist, params in EXAMPLES:
             for i, param in enumerate(params):
                 dist = Dist(**param)
-                samples = torch.tensor(dist.sample().data, requires_grad=True)
+                samples = torch.tensor(dist.sample().data)
+                if samples.dtype.is_floating_point:
+                    samples.requires_grad_()
                 try:
                     cdfs = dist.cdf(samples)
                     pdfs = dist.log_prob(samples).exp()

test/test_jit.py

Lines changed: 1 addition & 1 deletion
@@ -458,7 +458,7 @@ def forward(ctx, x):
            def backward(ctx, go):
                return go
 
-        x = torch.tensor([0], requires_grad=True)
+        x = torch.tensor([0.], requires_grad=True)
 
        def fn(x):
            y = RegularFn.apply(x)

test/test_nn.py

Lines changed: 1 addition & 1 deletion
@@ -2046,7 +2046,7 @@ def test_pad(self):
                           lambda: F.pad(torch.randn(1, 1, 2), (2, 1), mode='reflect'))
 
     def test_pad_scalar_error(self):
-        inputs = torch.tensor(0, requires_grad=True)
+        inputs = torch.tensor(0., requires_grad=True)
         self.assertRaises(AssertionError, lambda: F.pad(inputs, (1, 1)))
         self.assertRaises(AssertionError, lambda: F.pad(inputs, (1,)))

test/test_torch.py

Lines changed: 7 additions & 11 deletions
@@ -1903,8 +1903,8 @@ def get_int64_dtype(dtype):
             check_value(torch.empty(shape, out=out, device=device, layout=layout, requires_grad=rg),
                         dtype, layout, device, None, rg)
             check_value(v.new_empty(shape), dtype, layout, device, None, False)
-            check_value(v.new_empty(shape, dtype=int64_dtype, device=device, requires_grad=rg),
-                        int64_dtype, layout, device, None, rg)
+            check_value(v.new_empty(shape, dtype=int64_dtype, device=device, requires_grad=False),
+                        int64_dtype, layout, device, None, False)
             check_value(torch.empty_like(v), dtype, layout, device, None, False)
             check_value(torch.empty_like(v, dtype=int64_dtype, layout=layout, device=device, requires_grad=False),
                         int64_dtype, layout, device, None, False)
@@ -1917,8 +1917,8 @@ def get_int64_dtype(dtype):
             out = v.new()
             check_value(torch.full(shape, fv + 2, out=out, device=device, layout=layout, requires_grad=rg),
                         dtype, layout, device, fv + 2, rg)
-            check_value(v.new_full(shape, fv + 3, dtype=int64_dtype, device=device, requires_grad=rg),
-                        int64_dtype, layout, device, fv + 3, rg)
+            check_value(v.new_full(shape, fv + 3, dtype=int64_dtype, device=device, requires_grad=False),
+                        int64_dtype, layout, device, fv + 3, False)
             check_value(torch.full_like(v, fv + 4), dtype, layout, device, fv + 4, False)
             check_value(torch.full_like(v, fv + 5,
                                         dtype=int64_dtype, layout=layout, device=device, requires_grad=False),
@@ -2697,12 +2697,12 @@ def test_contiguous(self):
 
     def test_scalars_as_floats(self):
         "zero-dim variables that don't require grad should bind to scalar arguments"
-        x = torch.tensor(2)
-        y = torch.tensor(3)
+        x = torch.tensor(2.)
+        y = torch.tensor(3.)
         # 3 + (3 * 3) * 2
         self.assertEqual(y.addcmul(y, y, value=x), 21)
 
-        x = torch.tensor(2, requires_grad=True)
+        x = torch.tensor(2., requires_grad=True)
         self.assertRaises(Exception, lambda: y.addcmul(y, y, value=x))
 
     @staticmethod
@@ -6123,8 +6123,6 @@ def test_parsing_int64(self):
         self.assertEqual(x, torch.cumsum(torch.ones(5, 5), torch.tensor(0)))
         # doesn't accept floating point variables
         self.assertRaises(TypeError, lambda: torch.cumsum(torch.ones(5, 5), torch.tensor(0.)))
-        # doesn't accept variables with requires_grad
-        self.assertRaises(TypeError, lambda: torch.cumsum(torch.ones(5, 5), torch.tensor(0, requires_grad=True)))
 
     def test_parsing_double(self):
         # accepts floating point and integer arguments
@@ -6136,8 +6134,6 @@ def test_parsing_double(self):
         self.assertTrue(torch.isclose(x, x, torch.tensor(1), torch.tensor(1)).all())
         self.assertTrue(torch.isclose(x, x, torch.tensor(1.5), torch.tensor(1.)).all())
         # doesn't accept variables with requires_grad
-        self.assertRaises(TypeError,
-                          lambda: torch.isclose(x, x, torch.tensor(1, requires_grad=True), torch.tensor(1)).all())
         self.assertRaises(TypeError,
                           lambda: torch.isclose(x, x, torch.tensor(1.5), torch.tensor(1., requires_grad=True)).all())
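The new_empty / new_full changes above follow from the same restriction applied to the Tensor factory methods; roughly, and as an illustrative sketch rather than a line from the diff:

import torch

v = torch.randn(2, 3)
# An integer dtype combined with requires_grad=True now raises:
v.new_full((2, 3), 7, dtype=torch.int64, requires_grad=True)
# -> RuntimeError: only Tensors of floating point dtype can require gradients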

tools/autograd/templates/python_torch_functions.cpp

Lines changed: 3 additions & 8 deletions
@@ -12,6 +12,7 @@
 #include "torch/csrc/DynamicTypes.h"
 #include "torch/csrc/Exceptions.h"
 #include "torch/csrc/autograd/python_variable.h"
+#include "torch/csrc/autograd/utils/python_variables.h"
 #include "torch/csrc/autograd/utils/wrap_outputs.h"
 #include "torch/csrc/utils/python_arg_parser.h"
 #include "torch/csrc/utils/tensor_new.h"
@@ -26,18 +27,12 @@ using at::Tensor;
 using at::Scalar;
 using at::ScalarType;
 using at::Backend;
+using torch::autograd::utils::set_requires_grad;
+
 using namespace torch::autograd::utils;
 
 namespace torch { namespace autograd {
 
-static Tensor set_requires_grad(Tensor self, bool requires_grad) {
-  if (requires_grad && !at::isFloatingType(self.type().scalarType())) {
-    throw std::runtime_error("only Tensors of floating point dtype can require gradients");
-  }
-  as_variable_ref(self).set_requires_grad(requires_grad);
-  return self;
-}
-
 static void check_out_type_matches(Tensor result,
     ScalarType scalarType, bool scalarType_is_none,
     const THPLayout& layout, bool layout_is_none,
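Since the torch.* factory functions generated from this template now go through the shared set_requires_grad helper, the same check applies to them; a hedged sketch of the resulting behavior (not taken from the diff):

import torch

# torch.* factories with an integer dtype also reject requires_grad=True:
torch.full((2, 2), 3, dtype=torch.int64, requires_grad=True)
# -> RuntimeError: only Tensors of floating point dtype can require gradients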
torch/csrc/autograd/utils/python_variables.h

Lines changed: 16 additions & 0 deletions

@@ -0,0 +1,16 @@
+#pragma once
+
+#include <ATen/ATen.h>
+#include "torch/csrc/autograd/python_variable.h"
+
+namespace torch { namespace autograd { namespace utils {
+
+inline at::Tensor set_requires_grad(at::Tensor self, bool requires_grad) {
+  if (requires_grad && !at::isFloatingType(self.type().scalarType())) {
+    throw std::runtime_error("only Tensors of floating point dtype can require gradients");
+  }
+  as_variable_ref(self).set_requires_grad(requires_grad);
+  return self;
+}
+
+}}} // namespace torch::autograd::utils

torch/csrc/utils/tensor_new.cpp

Lines changed: 2 additions & 5 deletions
@@ -9,6 +9,7 @@
 #include "torch/csrc/Exceptions.h"
 #include "torch/csrc/Size.h"
 #include "torch/csrc/autograd/variable.h"
+#include "torch/csrc/autograd/utils/python_variables.h"
 #include "torch/csrc/utils/auto_gil.h"
 #include "torch/csrc/utils/auto_gpu.h"
 #include "torch/csrc/utils/cuda_lazy_init.h"
@@ -23,6 +24,7 @@
 static const int MAX_DIMS = 128;
 
 using namespace at;
+using torch::autograd::utils::set_requires_grad;
 
 namespace torch { namespace utils {
 
@@ -389,11 +391,6 @@ static const Type& typeWithDefault(PythonArgs& r, int64_t dtype_idx, int64_t dev
   return torch::getType(scalartype, *torch::getLayout(type.backend()), device_type);
 }
 
-static Tensor set_requires_grad(Tensor self, bool requires_grad) {
-  static_cast<torch::autograd::Variable&>(self).set_requires_grad(requires_grad);
-  return self;
-}
-
 Tensor sparse_coo_tensor_ctor(const Type& type, PyObject* args, PyObject* kwargs) {
   Backend sparse_backend = type.is_cuda() ? kSparseCUDA : kSparseCPU;
   const auto& default_sparse_type = type.toBackend(sparse_backend);
