Commit a6bfa16

torch.arange: add numpy-style type inference. (#7016)
* torch.arange: add numpy-style type inference. This is a backwards-compatibility breaking change.
* Fix flake8.
* Use at::optional.
* Remove unneeded header files.
* Use reference wrapper.
* Update arange for test.
* Address review comments.
1 parent bdd27ea commit a6bfa16

11 files changed (+191, -63 lines)

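For context, the new inference rules mirror NumPy: torch.arange now infers torch.int64 when every argument is integral, and the default (floating) dtype when any argument is a float. A minimal sketch of the resulting behavior, assuming the default dtype is torch.float32 (this follows the test_arange_inference cases added below in test/test_torch.py):

    import torch

    # All-integer arguments now infer int64 (previously: the default float type).
    torch.arange(5).dtype         # torch.int64
    torch.arange(1, 17).dtype     # torch.int64

    # Any floating-point argument infers the default dtype instead.
    torch.arange(5.).dtype        # torch.float32
    torch.arange(1., 17).dtype    # torch.float32
    torch.arange(0, 8, 2.).dtype  # torch.float32

This is the backwards-compatibility break the commit message calls out: code that relied on arange(int, int) returning a float tensor must be updated, which is what the test changes below do.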

test/test_autograd.py

Lines changed: 8 additions & 8 deletions
@@ -638,7 +638,7 @@ def backward(self, dy):
         self.assertFalse(y.requires_grad)

     def test_indexing(self):
-        x = torch.arange(1, 17).view(4, 4)
+        x = torch.arange(1., 17).view(4, 4)
         y = Variable(x, requires_grad=True)

         def compare(x, y, idx, indexed_tensor, indexed_var):
@@ -681,7 +681,7 @@ def check_index(x, y, idx):
         check_index(x, y, ([0]))
         check_index(x, y, ([0], ))

-        x = torch.arange(1, 49).view(4, 3, 4)
+        x = torch.arange(1., 49).view(4, 3, 4)
         y = Variable(x, requires_grad=True)

         check_index(x, y, (slice(None), [0], [0]))
@@ -717,7 +717,7 @@ def check_index(x, y, idx):
             compare(x, y, seq, indexed_tensor, indexed_var)

     def test_indexing_duplicates(self):
-        x = torch.arange(1, 17).view(4, 4)
+        x = torch.arange(1., 17).view(4, 4)
         y = Variable(x, requires_grad=True)

         idx = torch.LongTensor([1, 1, 3, 2, 1, 2])
@@ -728,7 +728,7 @@ def test_indexing_duplicates(self):
         self.assertEqual(y.grad.data, expected_grad)

         # with advanced indexing
-        x = torch.arange(1, 17).view(4, 4)
+        x = torch.arange(1., 17).view(4, 4)
         y = Variable(x, requires_grad=True)

         idx = [[1, 1, 3, 2, 1, 2], [0]]
@@ -740,7 +740,7 @@ def test_indexing_duplicates(self):

         self.assertEqual(y.grad.data, expected_grad)

-        x = torch.arange(1, 17).view(4, 4)
+        x = torch.arange(1., 17).view(4, 4)
         y = Variable(x, requires_grad=True)
         idx = [[[1, 2], [0, 0]], [[0, 1], [1, 1]]]
         y[idx].sum().backward()
@@ -750,7 +750,7 @@ def test_indexing_duplicates(self):
                                       [0, 0, 0, 0]])
         self.assertEqual(y.grad.data, expected_grad)

-        x = torch.arange(1, 65).view(4, 4, 4)
+        x = torch.arange(1., 65).view(4, 4, 4)
         y = Variable(x, requires_grad=True)

         idx = [[1, 1, 1], slice(None), slice(None)]
@@ -1952,7 +1952,7 @@ def test_dir(self):
             self.assertTrue(hasattr(x, key))

     def test_as_strided(self):
-        x = Variable(torch.arange(0, 25).view(5, 5), requires_grad=True)
+        x = Variable(torch.arange(0., 25).view(5, 5), requires_grad=True)

         def as_strided(x):
             return x.as_strided([3, 3], [6, 2], 2)
@@ -2253,7 +2253,7 @@ def make_nonzero_det(A, sign=None, min_singular_value=0.1):
 def random_fullrank_matrix_distinct_singular_value(l):
     A = torch.randn(l, l)
     u, _, v = A.svd()
-    s = torch.arange(1, l + 1).mul_(1.0 / (l + 1))
+    s = torch.arange(1., l + 1).mul_(1.0 / (l + 1))
     return u.mm(torch.diag(s)).mm(v.t())
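The updates above show the migration pattern repeated throughout this commit: make one endpoint a float literal so the result stays floating-point. A hedged sketch of the equivalent options (the dtype keyword argument is assumed to be available in this build):

    import torch

    # Before this change, torch.arange(1, 17) returned a float tensor;
    # it now returns int64. To keep float results, use a float literal...
    x = torch.arange(1., 17).view(4, 4)

    # ...or request the dtype explicitly.
    x = torch.arange(1, 17, dtype=torch.float32).view(4, 4)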

test/test_indexing.py

Lines changed: 2 additions & 2 deletions
@@ -218,7 +218,7 @@ def test_int_assignment(self):
         self.assertEqual(x.tolist(), [[0, 1], [5, 6]])

     def test_byte_tensor_assignment(self):
-        x = torch.arange(0, 16).view(4, 4)
+        x = torch.arange(0., 16).view(4, 4)
         b = torch.ByteTensor([True, False, True, False])
         value = torch.tensor([3., 4., 5., 6.])
         x[b] = value
@@ -475,7 +475,7 @@ def test_index_is_larger(self):

     def test_broadcast_subspace(self):
         a = torch.zeros((100, 100))
-        v = torch.arange(0, 100)[:, None]
+        v = torch.arange(0., 100)[:, None]
         b = torch.arange(99, -1, -1).long()
         a[b] = v
         expected = b.double().unsqueeze(1).expand(100, 100)

test/test_jit.py

Lines changed: 3 additions & 3 deletions
@@ -1890,8 +1890,8 @@ def func(x, y):
             w = -q
             return w * w

-        x = torch.arange(4, requires_grad=True)
-        y = torch.arange(0, 8, 2, requires_grad=True)
+        x = torch.arange(4., requires_grad=True)
+        y = torch.arange(0., 8, 2, requires_grad=True)
         self.checkScript(func, [x, y], optimize=True, capture_output=True)

     def test_multiple_assignment(self):
@@ -2041,7 +2041,7 @@ def fn(x, slope):
             c = F.prelu(x, slope)
             return a, b, c

-        x = torch.arange(-3, 4)
+        x = torch.arange(-3., 4)
         slope = torch.tensor([0.5])
         self.checkScript(fn, [x, slope], optimize=True)

test/test_multiprocessing.py

Lines changed: 5 additions & 5 deletions
@@ -92,7 +92,7 @@ def autograd_sharing(queue, ready, master_modified):
     ready.set()
     master_modified.wait()

-    expected_var = torch.arange(1, 26).view(5, 5)
+    expected_var = torch.arange(1., 26).view(5, 5)
     expected_var[0, 0] = 1000
     is_ok = var.data.equal(expected_var)
     var.data[:] = torch.ones(5, 5)
@@ -314,7 +314,7 @@ def test_cuda_small_tensors(self):
         tensors = []
         for i in range(5):
             device = i % 2
-            tensors += [torch.arange(i * 5, (i + 1) * 5).cuda(device)]
+            tensors += [torch.arange(i * 5., (i + 1) * 5).cuda(device)]

         inq = ctx.Queue()
         outq = ctx.Queue()
@@ -329,7 +329,7 @@ def test_cuda_small_tensors(self):

         for i, tensor in enumerate(tensors):
             v, device, tensor_size, storage_size = results[i]
-            self.assertEqual(v, torch.arange(i * 5, (i + 1) * 5).sum())
+            self.assertEqual(v, torch.arange(i * 5., (i + 1) * 5).sum())
             self.assertEqual(device, i % 2)
             self.assertEqual(tensor_size, 5)
             self.assertEqual(storage_size, 5)
@@ -412,12 +412,12 @@ def _test_autograd_sharing(self, var):

     def test_variable_sharing(self):
         for requires_grad in [True, False]:
-            var = Variable(torch.arange(1, 26).view(5, 5),
+            var = Variable(torch.arange(1., 26).view(5, 5),
                            requires_grad=requires_grad)
             self._test_autograd_sharing(var)

     def test_parameter_sharing(self):
-        param = Parameter(torch.arange(1, 26).view(5, 5))
+        param = Parameter(torch.arange(1., 26).view(5, 5))
         self._test_autograd_sharing(param)

     def test_empty_shared(self):

test/test_nn.py

Lines changed: 7 additions & 7 deletions
@@ -1237,7 +1237,7 @@ def compare_scaling(grads):
             self.assertEqual(scale.std(), 0)
             return scale[0]

-        grads = torch.arange(1, 101).view(10, 10), torch.ones(10).div(1000)
+        grads = torch.arange(1., 101).view(10, 10), torch.ones(10).div(1000)
         for norm_type in [0.5, 1.5, 2, 4, 'inf']:
             for p, g in zip(l.parameters(), grads):
                 p._grad = Variable(g.clone().view_as(p.data))
@@ -1267,7 +1267,7 @@ def test_clip_grad_value(self):
         l = nn.Linear(10, 10)
         clip_value = 2.5

-        grad_w, grad_b = torch.arange(-50, 50).view(10, 10).div_(5), torch.ones(10).mul_(2)
+        grad_w, grad_b = torch.arange(-50., 50).view(10, 10).div_(5), torch.ones(10).mul_(2)
         for grad_list in [[grad_w, grad_b], [grad_w, None]]:
             for p, g in zip(l.parameters(), grad_list):
                 p._grad = g.clone().view_as(p.data) if g is not None else g
@@ -1290,7 +1290,7 @@ def test_vector_to_parameters(self):
         fc1 = nn.Linear(10, 20)
         model = nn.Sequential(conv1, fc1)

-        vec = Variable(torch.arange(0, 980))
+        vec = Variable(torch.arange(0., 980))
         vector_to_parameters(vec, model.parameters())

         sample = next(model.parameters())[0, 0, 0]
@@ -3191,10 +3191,10 @@ def pad(tensor, length):
         max_length = lengths[0]
         batch_sizes = [sum(map(bool, filter(lambda x: x >= i, lengths))) for i in range(1, max_length + 1)]
         offset = 0
-        padded = torch.cat([pad(i * 100 + torch.arange(1, 5 * l + 1).view(l, 1, 5), max_length)
+        padded = torch.cat([pad(i * 100 + torch.arange(1., 5 * l + 1).view(l, 1, 5), max_length)
                             for i, l in enumerate(lengths, 1)], 1)
         padded = torch.tensor(padded, requires_grad=True)
-        expected_data = [[torch.arange(1, 6) + (i + 1) * 100 + 5 * n for i in range(batch_size)]
+        expected_data = [[torch.arange(1., 6) + (i + 1) * 100 + 5 * n for i in range(batch_size)]
                          for n, batch_size in enumerate(batch_sizes)]
         expected_data = list(itertools.chain.from_iterable(expected_data))
         expected_data = torch.stack(expected_data, dim=0)
@@ -4320,7 +4320,7 @@ def test_shape(N, C, IH, IW, H, W, padding_mode):
         # test known input on CPU
         for padding_mode in ['zeros', 'border']:

-            input = Variable(torch.arange(1, 11).view(1, 1, 2, 5))
+            input = Variable(torch.arange(1., 11).view(1, 1, 2, 5))
             grid = Variable(torch.Tensor(
                 [[-0.9, -1.4, 0, 0.2, 1],
                  [-1, -0.333, 0, 0.5, 1],
@@ -4430,7 +4430,7 @@ def test_shape(N, C, ID, IH, IW, D, H, W, padding_mode):

     def test_affine_grid(self):
         # test known input on CPU
-        input = Variable(torch.arange(1, 7).view(1, 2, 3))
+        input = Variable(torch.arange(1., 7).view(1, 2, 3))
         output = F.affine_grid(input, torch.Size([1, 1, 2, 2]))
         groundtruth = torch.Tensor(
             [[[0, -3], [2, 5]], [[4, 7], [6, 15]]]).view(1, 2, 2, 2)

test/test_torch.py

Lines changed: 44 additions & 11 deletions
@@ -998,7 +998,7 @@ def test_remainder(self):
         long_m1 = torch.LongTensor(10, 10).random_(-10, 10)
         long_res1 = long_m1.clone()
         long_res2 = long_m1.clone()
-        long_qs = torch.arange(-5, 5).long()
+        long_qs = torch.arange(-5, 5)
         long_qs[5] = 5  # Can't handle the divisor=0 case
         for col_idx, long_q in enumerate(long_qs):
             # Reference
@@ -2313,6 +2313,39 @@ def test_arange(self):
         self.assertEqual(r1, r2, 0)
         self.assertEqual(r2, r3[:-1], 0)

+    def test_arange_inference(self):
+        saved_dtype = torch.get_default_dtype()
+        torch.set_default_dtype(torch.float32)
+        # end only
+        self.assertIs(torch.float32, torch.arange(1.).dtype)
+        self.assertIs(torch.float32, torch.arange(torch.tensor(1.)).dtype)
+        self.assertIs(torch.float32, torch.arange(torch.tensor(1., dtype=torch.float64)).dtype)
+
+        self.assertIs(torch.int64, torch.arange(1).dtype)
+        self.assertIs(torch.int64, torch.arange(torch.tensor(1)).dtype)
+        self.assertIs(torch.int64, torch.arange(torch.tensor(1, dtype=torch.int16)).dtype)
+
+        # start, end, [step]
+        self.assertIs(torch.float32, torch.arange(1., 3).dtype)
+        self.assertIs(torch.float32, torch.arange(torch.tensor(1., dtype=torch.float64), 3).dtype)
+        self.assertIs(torch.float32, torch.arange(1, 3.).dtype)
+        self.assertIs(torch.float32, torch.arange(torch.tensor(1, dtype=torch.int16), torch.tensor(3.)).dtype)
+        self.assertIs(torch.float32, torch.arange(1, 3, 1.).dtype)
+        self.assertIs(torch.float32,
+                      torch.arange(torch.tensor(1),
+                                   torch.tensor(3, dtype=torch.int16),
+                                   torch.tensor(1., dtype=torch.float64)).dtype)
+
+        self.assertIs(torch.int64, torch.arange(1, 3).dtype)
+        self.assertIs(torch.int64, torch.arange(torch.tensor(1), 3).dtype)
+        self.assertIs(torch.int64, torch.arange(torch.tensor(1), torch.tensor(3, dtype=torch.int16)).dtype)
+        self.assertIs(torch.int64, torch.arange(1, 3, 1).dtype)
+        self.assertIs(torch.int64,
+                      torch.arange(torch.tensor(1),
+                                   torch.tensor(3),
+                                   torch.tensor(1, dtype=torch.int16)).dtype)
+        torch.set_default_dtype(saved_dtype)
+
     @staticmethod
     def _select_broadcastable_dims(dims_full=None):
         # select full dimensionality
@@ -2883,7 +2916,7 @@ def test_median(self):
         self.assertEqual(x, x0, 0)

     def test_mode(self):
-        x = torch.arange(1, SIZE * SIZE + 1).clone().resize_(SIZE, SIZE)
+        x = torch.arange(1., SIZE * SIZE + 1).clone().resize_(SIZE, SIZE)
         x[:2] = 1
         x[:, :2] = 1
         x0 = x.clone()
@@ -3119,7 +3152,7 @@ def test_randn(self):

     def test_slice(self):
         empty = torch.Tensor()
-        x = torch.arange(0, 16).view(4, 4)
+        x = torch.arange(0., 16).view(4, 4)
         self.assertEqual(x.slice(), x)
         self.assertEqual(x.slice(0, 0, 4), x)
         # start and stop are clamped to the size of dim
@@ -3914,7 +3947,7 @@ def naive_stft(x, frame_length, hop, fft_size=None, normalized=False,
             return_size = fft_size
         result = x.new(batch, int((length - frame_length) / float(hop)) + 1, return_size, 2)
         for w in range(return_size):  # freq
-            radians = torch.arange(frame_length) * w * 2 * math.pi / fft_size
+            radians = torch.arange(float(frame_length)) * w * 2 * math.pi / fft_size
             radians = radians.type_as(x)
             re_kernel = radians.cos().mul_(window)
             im_kernel = -radians.sin().mul_(window)
@@ -4576,7 +4609,7 @@ def ri(indices):
         # strided is [[1 3 5 7],
         #             [9 11 13 15]]

-        reference = conv_fn(torch.arange(0, 24).view(3, 8))
+        reference = conv_fn(torch.arange(0., 24).view(3, 8))
         strided = conv_fn(torch.Tensor())
         strided.set_(reference.storage(), 1, size=torch.Size([2, 4]),
                      stride=[8, 2])
@@ -4614,15 +4647,15 @@ def ri(indices):
         # strided is [[10, 11],
         #             [17, 18]]

-        reference = conv_fn(torch.arange(0, 24).view(3, 8))
+        reference = conv_fn(torch.arange(0., 24).view(3, 8))
         strided = conv_fn(torch.Tensor())
         strided.set_(reference.storage(), 10, size=torch.Size([2, 2]),
                      stride=[7, 1])
         self.assertEqual(strided[ri([0]), ri([1])], torch.Tensor([11]))
         strided[ri([0]), ri([1])] = -1
         self.assertEqual(strided[ri([0]), ri([1])], torch.Tensor([-1]))

-        reference = conv_fn(torch.arange(0, 24).view(3, 8))
+        reference = conv_fn(torch.arange(0., 24).view(3, 8))
         strided = conv_fn(torch.Tensor())
         strided.set_(reference.storage(), 10, size=torch.Size([2, 2]),
                      stride=[7, 1])
@@ -4632,7 +4665,7 @@ def ri(indices):
         self.assertEqual(strided[ri([0, 1]), ri([1, 0])], torch.Tensor([-1,
                                                                         2]))

-        reference = conv_fn(torch.arange(0, 24).view(3, 8))
+        reference = conv_fn(torch.arange(0., 24).view(3, 8))
         strided = conv_fn(torch.Tensor())
         strided.set_(reference.storage(), 10, size=torch.Size([2, 2]),
                      stride=[7, 1])
@@ -4727,7 +4760,7 @@ def get_set_tensor(indexed, indexer):
         # 5 6 7 8 9
         # 10 11 12 13 14
         # 15 16 17 18 19
-        reference = conv_fn(torch.arange(0, 20).view(4, 5))
+        reference = conv_fn(torch.arange(0., 20).view(4, 5))

         indices_to_test = [
             # grab the second, fourth columns
@@ -4753,7 +4786,7 @@ def get_set_tensor(indexed, indexer):
                              indexer,
                              get_set_tensor(reference, indexer))

-        reference = conv_fn(torch.arange(0, 160).view(4, 8, 5))
+        reference = conv_fn(torch.arange(0., 160).view(4, 8, 5))

         indices_to_test = [
             [slice(None), slice(None), [0, 3, 4]],
@@ -4804,7 +4837,7 @@ def get_set_tensor(indexed, indexer):
                              indexer,
                              get_set_tensor(reference, indexer))

-        reference = conv_fn(torch.arange(0, 1296).view(3, 9, 8, 6))
+        reference = conv_fn(torch.arange(0., 1296).view(3, 9, 8, 6))

         indices_to_test = [
             [slice(None), slice(None), slice(None), [0, 3, 4]],

tools/autograd/gen_python_functions.py

Lines changed: 1 addition & 1 deletion
@@ -19,7 +19,7 @@
     '.*_forward_out', 'sparse_raw_resize_', '_unsafe_view', 'tensor',
     'sparse_coo_tensor', '_arange.*', '_range.*', '_linspace.*', '_logspace.*',
     '_indexCopy_', 'max_values', 'min_values', 'argmax', 'argmin',
-    '_cumsum.*', '_cumprod.*', '_sum.*', '_prod.*', '_th_sum.*', '_th_prod.*',
+    '_cumsum.*', '_cumprod.*', '_sum.*', '_prod.*', '_th_sum.*', '_th_prod.*', 'arange.*',
 ]

 PY_VARIABLE_METHODS_CPP = CodeTemplate.from_file(template_path + '/python_variable_methods.cpp')
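Adding 'arange.*' to this list presumably excludes arange from the auto-generated Python bindings so that a hand-written binding can perform the mixed-type inference (the neighboring '_arange.*' entry was already skipped). As a rough illustration only, a hypothetical Python rendering of the rule the test cases above encode, not the actual C++ binding:

    import torch

    def infer_arange_dtype(*args):
        # Hypothetical helper: any floating-point argument (Python float or
        # floating tensor) selects the current default dtype; otherwise int64.
        def is_floating(a):
            if isinstance(a, torch.Tensor):
                return a.dtype.is_floating_point
            return isinstance(a, float)
        return torch.get_default_dtype() if any(map(is_floating, args)) else torch.int64

    infer_arange_dtype(1, 3)                                   # torch.int64
    infer_arange_dtype(1, 3, 1.)                               # default dtype
    infer_arange_dtype(torch.tensor(1., dtype=torch.float64))  # default dtype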
