Skip to content

Commit 04fcbd6

Browse files
committed
fix TorchTest.test_empty_full to not use requires_grad on int tensors.
1 parent 61a5d19 commit 04fcbd6

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

test/test_torch.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1781,7 +1781,7 @@ def get_int64_dtype(dtype):
17811781
check_value(torch.empty(shape), default_dtype, torch.strided, -1, None, False)
17821782
check_value(torch.full(shape, -5), default_dtype, torch.strided, -1, None, False)
17831783
for dtype in dtypes:
1784-
for rg in [True, False]:
1784+
for rg in {dtype.is_floating_point, False}:
17851785
int64_dtype = get_int64_dtype(dtype)
17861786
v = torch.empty(shape, dtype=dtype, device=device, layout=layout, requires_grad=rg)
17871787
check_value(v, dtype, layout, device, None, rg)
@@ -1792,8 +1792,8 @@ def get_int64_dtype(dtype):
17921792
check_value(v.new_empty(shape, dtype=int64_dtype, device=device, requires_grad=rg),
17931793
int64_dtype, layout, device, None, rg)
17941794
check_value(torch.empty_like(v), dtype, layout, device, None, False)
1795-
check_value(torch.empty_like(v, dtype=int64_dtype, layout=layout, device=device, requires_grad=rg),
1796-
int64_dtype, layout, device, None, rg)
1795+
check_value(torch.empty_like(v, dtype=int64_dtype, layout=layout, device=device, requires_grad=False),
1796+
int64_dtype, layout, device, None, False)
17971797

17981798
if dtype is not torch.float16 and layout != torch.sparse_coo:
17991799
fv = 3
@@ -1807,8 +1807,8 @@ def get_int64_dtype(dtype):
18071807
int64_dtype, layout, device, fv + 3, rg)
18081808
check_value(torch.full_like(v, fv + 4), dtype, layout, device, fv + 4, False)
18091809
check_value(torch.full_like(v, fv + 5,
1810-
dtype=int64_dtype, layout=layout, device=device, requires_grad=rg),
1811-
int64_dtype, layout, device, fv + 5, rg)
1810+
dtype=int64_dtype, layout=layout, device=device, requires_grad=False),
1811+
int64_dtype, layout, device, fv + 5, False)
18121812

18131813
def test_empty_full(self):
18141814
self._test_empty_full(self, torch.testing.get_all_dtypes(), torch.strided, torch.device('cpu'))

0 commit comments

Comments
 (0)