Skip to content

Commit 45b91bd

Browse files
wanchaolfacebook-github-bot
authored andcommitted
refactor all for in range/tensor tests to be together with other for loop tests (#21950)
Summary: Pull Request resolved: #21950 ghimport-source-id: b249131 Test Plan: Imported from OSS Differential Revision: D15948546 Pulled By: wanchaol fbshipit-source-id: 34dde28902ae5b8affbf6e4deaaffdb1d8ddd6ec
1 parent e0f5ab2 commit 45b91bd

File tree

1 file changed

+172
-172
lines changed

1 file changed

+172
-172
lines changed

test/test_jit.py

Lines changed: 172 additions & 172 deletions
Original file line numberDiff line numberDiff line change
@@ -3264,96 +3264,6 @@ def stuff3(x):
32643264
return torch.ones(x), x
32653265
self.checkScript(stuff3, ([3, 2],))
32663266

3267-
def test_for_in_tensors(self):
3268-
def test_sizes(x):
3269-
sumz = 0
3270-
for s in x:
3271-
sumz += 1
3272-
return sumz
3273-
self.checkScript(test_sizes, (torch.rand(5, 4, 3, 2, 1),))
3274-
self.checkScript(test_sizes, (torch.rand(777),))
3275-
self.checkScript(test_sizes, (torch.rand(0),))
3276-
3277-
def test_for_in_tensors_rank0(self):
3278-
with self.assertRaisesRegex(RuntimeError, "of a 0-d tensor"):
3279-
@torch.jit.script
3280-
def test_sizes(x):
3281-
sumz = 0
3282-
for s in x:
3283-
sumz += 1
3284-
return sumz
3285-
3286-
test_sizes(torch.tensor(1))
3287-
3288-
def test_for_in_tensors_fail_scalar(self):
3289-
with self.assertRaisesRegex(RuntimeError, "'float' object is not iterable"):
3290-
@torch.jit.script
3291-
def test_sizes(x):
3292-
# type: (float) -> int
3293-
sumz = 0
3294-
for s in x: # noqa
3295-
sumz += 1
3296-
return sumz
3297-
3298-
test_sizes(0.0)
3299-
3300-
def test_for_in_tensors_nested(self):
3301-
def test_sizes(x):
3302-
sumz = 0
3303-
for n in x:
3304-
for t in n:
3305-
sumz += 1
3306-
return sumz
3307-
3308-
self.checkScript(test_sizes, (torch.rand(5, 4, 3, 2, 1),))
3309-
3310-
# to avoid defining sum_list in multiple tests
3311-
def get_sum_list_fn(self):
3312-
def sum_list(a):
3313-
# type: (List[int]) -> int
3314-
sum = 0
3315-
for i in a:
3316-
sum += i
3317-
3318-
return sum
3319-
3320-
return sum_list
3321-
3322-
def test_sum_list_diff_elms(self):
3323-
self.checkScript(self.get_sum_list_fn(), ([1, 2, 3, 4, 5],))
3324-
3325-
def test_sum_list_empty(self):
3326-
self.checkScript(self.get_sum_list_fn(), ([],))
3327-
3328-
def test_sum_list_one(self):
3329-
self.checkScript(self.get_sum_list_fn(), ([1],))
3330-
3331-
def test_sum_list_literal(self):
3332-
3333-
def sum_list():
3334-
# type: () -> int
3335-
sum = 0
3336-
for i in [1, 2, 3, 4, 5]:
3337-
sum += i
3338-
3339-
return sum
3340-
3341-
self.checkScript(sum_list, ())
3342-
3343-
def test_sum_list_wrong_type(self):
3344-
3345-
with self.assertRaisesRegex(RuntimeError, "'int' object is not iterable"):
3346-
@torch.jit.script
3347-
def sum_list(a):
3348-
# type: (int) -> int
3349-
sum = 0
3350-
for i in a: # noqa: T484
3351-
sum += i
3352-
3353-
return sum
3354-
3355-
sum_list(1)
3356-
33573267
def test_bool_list_io(self):
33583268
@torch.jit.script
33593269
def stuff4(x):
@@ -5256,87 +5166,6 @@ def func(a, b):
52565166
inputs = self._make_scalar_vars([4321, 1234], torch.int64)
52575167
self.checkScript(func, inputs, optimize=True)
52585168

5259-
def test_script_for_in_range(self):
5260-
def fn():
5261-
c = 0
5262-
for i in range(100):
5263-
c += i
5264-
return c
5265-
self.checkScript(fn, (), outputs=4950, optimize=True)
5266-
5267-
def test_script_for_in_range_dynamic(self):
5268-
def fn():
5269-
c = 0
5270-
for i in range(100):
5271-
acc = 0
5272-
for j in range(i):
5273-
acc += j
5274-
c += acc
5275-
return c
5276-
self.checkScript(fn, (), optimize=False)
5277-
5278-
def test_script_for_in_range_ast(self):
5279-
@torch.jit.script
5280-
def test_script_for_in_range_ast():
5281-
c = 0
5282-
for i in range(100):
5283-
acc = 0
5284-
for j in range(i):
5285-
acc += j
5286-
c += acc
5287-
return c
5288-
5289-
self.assertEqual(test_script_for_in_range_ast(), 161700)
5290-
5291-
def test_script_for_in_range_if_ast(self):
5292-
@torch.jit.script
5293-
def test_script_for_in_range_if_ast(x):
5294-
output = x
5295-
for i in range(20):
5296-
if i == 0:
5297-
output = x.unsqueeze(0)
5298-
else:
5299-
output = torch.cat((output, x.unsqueeze(0)), dim=0)
5300-
return output
5301-
inputs = self._make_scalar_vars([0], torch.int64)
5302-
5303-
self.assertEqual(test_script_for_in_range_if_ast(*inputs).shape[0], 20)
5304-
5305-
def test_script_for_in_range_start_end(self):
5306-
def fn():
5307-
x = 0
5308-
for i in range(7, 100):
5309-
x += i
5310-
return x
5311-
self.checkScript(fn, (), outputs=4929, optimize=True)
5312-
5313-
def test_script_for_in_range_start_end_step(self):
5314-
def fn(start, end, step):
5315-
# type: (int, int, int) -> int
5316-
x = 0
5317-
for i in range(start, end, step):
5318-
x += i
5319-
return x
5320-
5321-
def check(inp):
5322-
self.checkScript(fn, inp, outputs=fn(*inp), optimize=True)
5323-
check((7, 100, 7))
5324-
check((7, 100, -7))
5325-
check((2, -11, -3))
5326-
check((2, -11, 3))
5327-
check((2, 10, 3))
5328-
check((-2, -10, -10))
5329-
5330-
def test_script_for_zero_step(self):
5331-
@torch.jit.script
5332-
def fn():
5333-
x = 0
5334-
for i in range(2, -11, 0):
5335-
x += i
5336-
return x
5337-
with self.assertRaisesRegex(RuntimeError, "must not be zero"):
5338-
fn()
5339-
53405169
def test_script_optional_none(self):
53415170
def none_stmt(x):
53425171
output = None
@@ -9222,7 +9051,88 @@ def return_stmt(x):
92229051
return x
92239052
self.checkScript(return_stmt, (torch.rand(1),))
92249053

9225-
def test_for_range_no_arg(self):
9054+
def test_for_in_range(self):
9055+
def fn():
9056+
c = 0
9057+
for i in range(100):
9058+
c += i
9059+
return c
9060+
self.checkScript(fn, (), outputs=4950, optimize=True)
9061+
9062+
def test_for_in_range_dynamic(self):
9063+
def fn():
9064+
c = 0
9065+
for i in range(100):
9066+
acc = 0
9067+
for j in range(i):
9068+
acc += j
9069+
c += acc
9070+
return c
9071+
self.checkScript(fn, (), optimize=False)
9072+
9073+
def test_for_in_range_ast(self):
9074+
@torch.jit.script
9075+
def test_script_for_in_range_ast():
9076+
c = 0
9077+
for i in range(100):
9078+
acc = 0
9079+
for j in range(i):
9080+
acc += j
9081+
c += acc
9082+
return c
9083+
9084+
self.assertEqual(test_script_for_in_range_ast(), 161700)
9085+
9086+
def test_for_in_range_if_ast(self):
9087+
@torch.jit.script
9088+
def test_script_for_in_range_if_ast(x):
9089+
output = x
9090+
for i in range(20):
9091+
if i == 0:
9092+
output = x.unsqueeze(0)
9093+
else:
9094+
output = torch.cat((output, x.unsqueeze(0)), dim=0)
9095+
return output
9096+
inputs = self._make_scalar_vars([0], torch.int64)
9097+
9098+
self.assertEqual(test_script_for_in_range_if_ast(*inputs).shape[0], 20)
9099+
9100+
def test_for_in_range_start_end(self):
9101+
def fn():
9102+
x = 0
9103+
for i in range(7, 100):
9104+
x += i
9105+
return x
9106+
self.checkScript(fn, (), outputs=4929, optimize=True)
9107+
9108+
def test_for_in_range_start_end_step(self):
9109+
def fn(start, end, step):
9110+
# type: (int, int, int) -> int
9111+
x = 0
9112+
for i in range(start, end, step):
9113+
x += i
9114+
return x
9115+
9116+
def check(inp):
9117+
self.checkScript(fn, inp, outputs=fn(*inp), optimize=True)
9118+
check((7, 100, 7))
9119+
check((7, 100, -7))
9120+
check((2, -11, -3))
9121+
check((2, -11, 3))
9122+
check((2, 10, 3))
9123+
check((-2, -10, -10))
9124+
9125+
def test_for_in_range_zero_step(self):
9126+
@torch.jit.script
9127+
def fn():
9128+
x = 0
9129+
for i in range(2, -11, 0):
9130+
x += i
9131+
return x
9132+
with self.assertRaisesRegex(RuntimeError, "must not be zero"):
9133+
fn()
9134+
9135+
def test_for_in_range_no_arg(self):
92269136
with self.assertRaisesRegex(RuntimeError, r'range expected at least 1 arguments, got 0'):
92279137
@torch.jit.script
92289138
def range_no_arg(x):
@@ -9353,6 +9263,96 @@ def fn_enumerate_zip(x, y):
93539263

93549264
self.checkScript(fn_enumerate_zip, ([1, 2, 3, 4], [2, 3, 4, 5]))
93559265

9266+
def test_for_in_tensors(self):
9267+
def test_sizes(x):
9268+
sumz = 0
9269+
for s in x:
9270+
sumz += 1
9271+
return sumz
9272+
self.checkScript(test_sizes, (torch.rand(5, 4, 3, 2, 1),))
9273+
self.checkScript(test_sizes, (torch.rand(777),))
9274+
self.checkScript(test_sizes, (torch.rand(0),))
9275+
9276+
def test_for_in_tensors_rank0(self):
9277+
with self.assertRaisesRegex(RuntimeError, "of a 0-d tensor"):
9278+
@torch.jit.script
9279+
def test_sizes(x):
9280+
sumz = 0
9281+
for s in x:
9282+
sumz += 1
9283+
return sumz
9284+
9285+
test_sizes(torch.tensor(1))
9286+
9287+
def test_for_in_tensors_fail_scalar(self):
9288+
with self.assertRaisesRegex(RuntimeError, "'float' object is not iterable"):
9289+
@torch.jit.script
9290+
def test_sizes(x):
9291+
# type: (float) -> int
9292+
sumz = 0
9293+
for s in x: # noqa
9294+
sumz += 1
9295+
return sumz
9296+
9297+
test_sizes(0.0)
9298+
9299+
def test_for_in_tensors_nested(self):
9300+
def test_sizes(x):
9301+
sumz = 0
9302+
for n in x:
9303+
for t in n:
9304+
sumz += 1
9305+
return sumz
9306+
9307+
self.checkScript(test_sizes, (torch.rand(5, 4, 3, 2, 1),))
9308+
9309+
# to avoid defining sum_list in multiple tests
9310+
def get_sum_list_fn(self):
9311+
def sum_list(a):
9312+
# type: (List[int]) -> int
9313+
sum = 0
9314+
for i in a:
9315+
sum += i
9316+
9317+
return sum
9318+
9319+
return sum_list
9320+
9321+
def test_sum_list_diff_elms(self):
9322+
self.checkScript(self.get_sum_list_fn(), ([1, 2, 3, 4, 5],))
9323+
9324+
def test_sum_list_empty(self):
9325+
self.checkScript(self.get_sum_list_fn(), ([],))
9326+
9327+
def test_sum_list_one(self):
9328+
self.checkScript(self.get_sum_list_fn(), ([1],))
9329+
9330+
def test_sum_list_literal(self):
9331+
9332+
def sum_list():
9333+
# type: () -> int
9334+
sum = 0
9335+
for i in [1, 2, 3, 4, 5]:
9336+
sum += i
9337+
9338+
return sum
9339+
9340+
self.checkScript(sum_list, ())
9341+
9342+
def test_sum_list_wrong_type(self):
9343+
9344+
with self.assertRaisesRegex(RuntimeError, "'int' object is not iterable"):
9345+
@torch.jit.script
9346+
def sum_list(a):
9347+
# type: (int) -> int
9348+
sum = 0
9349+
for i in a: # noqa: T484
9350+
sum += i
9351+
9352+
return sum
9353+
9354+
sum_list(1)
9355+
93569356
def test_list_iterables(self):
93579357
with self.assertRaisesRegex(RuntimeError, 'List of iterables is not supported currently'):
93589358
cu = torch.jit.CompilationUnit('''

0 commit comments

Comments
 (0)