@@ -3264,96 +3264,6 @@ def stuff3(x):
32643264 return torch.ones(x), x
32653265 self.checkScript(stuff3, ([3, 2],))
32663266
3267- def test_for_in_tensors(self):
3268- def test_sizes(x):
3269- sumz = 0
3270- for s in x:
3271- sumz += 1
3272- return sumz
3273- self.checkScript(test_sizes, (torch.rand(5, 4, 3, 2, 1),))
3274- self.checkScript(test_sizes, (torch.rand(777),))
3275- self.checkScript(test_sizes, (torch.rand(0),))
3276-
3277- def test_for_in_tensors_rank0(self):
3278- with self.assertRaisesRegex(RuntimeError, "of a 0-d tensor"):
3279- @torch.jit.script
3280- def test_sizes(x):
3281- sumz = 0
3282- for s in x:
3283- sumz += 1
3284- return sumz
3285-
3286- test_sizes(torch.tensor(1))
3287-
3288- def test_for_in_tensors_fail_scalar(self):
3289- with self.assertRaisesRegex(RuntimeError, "'float' object is not iterable"):
3290- @torch.jit.script
3291- def test_sizes(x):
3292- # type: (float) -> int
3293- sumz = 0
3294- for s in x: # noqa
3295- sumz += 1
3296- return sumz
3297-
3298- test_sizes(0.0)
3299-
3300- def test_for_in_tensors_nested(self):
3301- def test_sizes(x):
3302- sumz = 0
3303- for n in x:
3304- for t in n:
3305- sumz += 1
3306- return sumz
3307-
3308- self.checkScript(test_sizes, (torch.rand(5, 4, 3, 2, 1),))
3309-
3310- # to avoid defining sum_list in multiple tests
3311- def get_sum_list_fn(self):
3312- def sum_list(a):
3313- # type: (List[int]) -> int
3314- sum = 0
3315- for i in a:
3316- sum += i
3317-
3318- return sum
3319-
3320- return sum_list
3321-
3322- def test_sum_list_diff_elms(self):
3323- self.checkScript(self.get_sum_list_fn(), ([1, 2, 3, 4, 5],))
3324-
3325- def test_sum_list_empty(self):
3326- self.checkScript(self.get_sum_list_fn(), ([],))
3327-
3328- def test_sum_list_one(self):
3329- self.checkScript(self.get_sum_list_fn(), ([1],))
3330-
3331- def test_sum_list_literal(self):
3332-
3333- def sum_list():
3334- # type: () -> int
3335- sum = 0
3336- for i in [1, 2, 3, 4, 5]:
3337- sum += i
3338-
3339- return sum
3340-
3341- self.checkScript(sum_list, ())
3342-
3343- def test_sum_list_wrong_type(self):
3344-
3345- with self.assertRaisesRegex(RuntimeError, "'int' object is not iterable"):
3346- @torch.jit.script
3347- def sum_list(a):
3348- # type: (int) -> int
3349- sum = 0
3350- for i in a: # noqa: T484
3351- sum += i
3352-
3353- return sum
3354-
3355- sum_list(1)
3356-
33573267 def test_bool_list_io(self):
33583268 @torch.jit.script
33593269 def stuff4(x):
@@ -5256,87 +5166,6 @@ def func(a, b):
52565166 inputs = self._make_scalar_vars([4321, 1234], torch.int64)
52575167 self.checkScript(func, inputs, optimize=True)
52585168
5259- def test_script_for_in_range(self):
5260- def fn():
5261- c = 0
5262- for i in range(100):
5263- c += i
5264- return c
5265- self.checkScript(fn, (), outputs=4950, optimize=True)
5266-
5267- def test_script_for_in_range_dynamic(self):
5268- def fn():
5269- c = 0
5270- for i in range(100):
5271- acc = 0
5272- for j in range(i):
5273- acc += j
5274- c += acc
5275- return c
5276- self.checkScript(fn, (), optimize=False)
5277-
5278- def test_script_for_in_range_ast(self):
5279- @torch.jit.script
5280- def test_script_for_in_range_ast():
5281- c = 0
5282- for i in range(100):
5283- acc = 0
5284- for j in range(i):
5285- acc += j
5286- c += acc
5287- return c
5288-
5289- self.assertEqual(test_script_for_in_range_ast(), 161700)
5290-
5291- def test_script_for_in_range_if_ast(self):
5292- @torch.jit.script
5293- def test_script_for_in_range_if_ast(x):
5294- output = x
5295- for i in range(20):
5296- if i == 0:
5297- output = x.unsqueeze(0)
5298- else:
5299- output = torch.cat((output, x.unsqueeze(0)), dim=0)
5300- return output
5301- inputs = self._make_scalar_vars([0], torch.int64)
5302-
5303- self.assertEqual(test_script_for_in_range_if_ast(*inputs).shape[0], 20)
5304-
5305- def test_script_for_in_range_start_end(self):
5306- def fn():
5307- x = 0
5308- for i in range(7, 100):
5309- x += i
5310- return x
5311- self.checkScript(fn, (), outputs=4929, optimize=True)
5312-
5313- def test_script_for_in_range_start_end_step(self):
5314- def fn(start, end, step):
5315- # type: (int, int, int) -> int
5316- x = 0
5317- for i in range(start, end, step):
5318- x += i
5319- return x
5320-
5321- def check(inp):
5322- self.checkScript(fn, inp, outputs=fn(*inp), optimize=True)
5323- check((7, 100, 7))
5324- check((7, 100, -7))
5325- check((2, -11, -3))
5326- check((2, -11, 3))
5327- check((2, 10, 3))
5328- check((-2, -10, -10))
5329-
5330- def test_script_for_zero_step(self):
5331- @torch.jit.script
5332- def fn():
5333- x = 0
5334- for i in range(2, -11, 0):
5335- x += i
5336- return x
5337- with self.assertRaisesRegex(RuntimeError, "must not be zero"):
5338- fn()
5339-
53405169 def test_script_optional_none(self):
53415170 def none_stmt(x):
53425171 output = None
@@ -9222,7 +9051,88 @@ def return_stmt(x):
92229051 return x
92239052 self.checkScript(return_stmt, (torch.rand(1),))
92249053
9225- def test_for_range_no_arg(self):
9054+ def test_for_in_range(self):
9055+ def fn():
9056+ c = 0
9057+ for i in range(100):
9058+ c += i
9059+ return c
9060+ self.checkScript(fn, (), outputs=4950, optimize=True)
9061+
9062+ def test_for_in_range_dynamic(self):
9063+ def fn():
9064+ c = 0
9065+ for i in range(100):
9066+ acc = 0
9067+ for j in range(i):
9068+ acc += j
9069+ c += acc
9070+ return c
9071+ self.checkScript(fn, (), optimize=False)
9072+
9073+ def test_for_in_range_ast(self):
9074+ @torch.jit.script
9075+ def test_script_for_in_range_ast():
9076+ c = 0
9077+ for i in range(100):
9078+ acc = 0
9079+ for j in range(i):
9080+ acc += j
9081+ c += acc
9082+ return c
9083+
9084+ self.assertEqual(test_script_for_in_range_ast(), 161700)
9085+
9086+ def test_for_in_range_if_ast(self):
9087+ @torch.jit.script
9088+ def test_script_for_in_range_if_ast(x):
9089+ output = x
9090+ for i in range(20):
9091+ if i == 0:
9092+ output = x.unsqueeze(0)
9093+ else:
9094+ output = torch.cat((output, x.unsqueeze(0)), dim=0)
9095+ return output
9096+ inputs = self._make_scalar_vars([0], torch.int64)
9097+
9098+ self.assertEqual(test_script_for_in_range_if_ast(*inputs).shape[0], 20)
9099+
9100+ def test_for_in_range_start_end(self):
9101+ def fn():
9102+ x = 0
9103+ for i in range(7, 100):
9104+ x += i
9105+ return x
9106+ self.checkScript(fn, (), outputs=4929, optimize=True)
9107+
9108+ def test_for_in_range_start_end_step(self):
9109+ def fn(start, end, step):
9110+ # type: (int, int, int) -> int
9111+ x = 0
9112+ for i in range(start, end, step):
9113+ x += i
9114+ return x
9115+
9116+ def check(inp):
9117+ self.checkScript(fn, inp, outputs=fn(*inp), optimize=True)
9118+ check((7, 100, 7))
9119+ check((7, 100, -7))
9120+ check((2, -11, -3))
9121+ check((2, -11, 3))
9122+ check((2, 10, 3))
9123+ check((-2, -10, -10))
9124+
9125+ def test_for_in_range_zero_step(self):
9126+ @torch.jit.script
9127+ def fn():
9128+ x = 0
9129+ for i in range(2, -11, 0):
9130+ x += i
9131+ return x
9132+ with self.assertRaisesRegex(RuntimeError, "must not be zero"):
9133+ fn()
9134+
9135+ def test_for_in_range_no_arg(self):
92269136 with self.assertRaisesRegex(RuntimeError, r'range expected at least 1 arguments, got 0'):
92279137 @torch.jit.script
92289138 def range_no_arg(x):
@@ -9353,6 +9263,96 @@ def fn_enumerate_zip(x, y):
93539263
93549264 self.checkScript(fn_enumerate_zip, ([1, 2, 3, 4], [2, 3, 4, 5]))
93559265
9266+ def test_for_in_tensors(self):
9267+ def test_sizes(x):
9268+ sumz = 0
9269+ for s in x:
9270+ sumz += 1
9271+ return sumz
9272+ self.checkScript(test_sizes, (torch.rand(5, 4, 3, 2, 1),))
9273+ self.checkScript(test_sizes, (torch.rand(777),))
9274+ self.checkScript(test_sizes, (torch.rand(0),))
9275+
9276+ def test_for_in_tensors_rank0(self):
9277+ with self.assertRaisesRegex(RuntimeError, "of a 0-d tensor"):
9278+ @torch.jit.script
9279+ def test_sizes(x):
9280+ sumz = 0
9281+ for s in x:
9282+ sumz += 1
9283+ return sumz
9284+
9285+ test_sizes(torch.tensor(1))
9286+
9287+ def test_for_in_tensors_fail_scalar(self):
9288+ with self.assertRaisesRegex(RuntimeError, "'float' object is not iterable"):
9289+ @torch.jit.script
9290+ def test_sizes(x):
9291+ # type: (float) -> int
9292+ sumz = 0
9293+ for s in x: # noqa
9294+ sumz += 1
9295+ return sumz
9296+
9297+ test_sizes(0.0)
9298+
9299+ def test_for_in_tensors_nested(self):
9300+ def test_sizes(x):
9301+ sumz = 0
9302+ for n in x:
9303+ for t in n:
9304+ sumz += 1
9305+ return sumz
9306+
9307+ self.checkScript(test_sizes, (torch.rand(5, 4, 3, 2, 1),))
9308+
9309+ # to avoid defining sum_list in multiple tests
9310+ def get_sum_list_fn(self):
9311+ def sum_list(a):
9312+ # type: (List[int]) -> int
9313+ sum = 0
9314+ for i in a:
9315+ sum += i
9316+
9317+ return sum
9318+
9319+ return sum_list
9320+
9321+ def test_sum_list_diff_elms(self):
9322+ self.checkScript(self.get_sum_list_fn(), ([1, 2, 3, 4, 5],))
9323+
9324+ def test_sum_list_empty(self):
9325+ self.checkScript(self.get_sum_list_fn(), ([],))
9326+
9327+ def test_sum_list_one(self):
9328+ self.checkScript(self.get_sum_list_fn(), ([1],))
9329+
9330+ def test_sum_list_literal(self):
9331+
9332+ def sum_list():
9333+ # type: () -> int
9334+ sum = 0
9335+ for i in [1, 2, 3, 4, 5]:
9336+ sum += i
9337+
9338+ return sum
9339+
9340+ self.checkScript(sum_list, ())
9341+
9342+ def test_sum_list_wrong_type(self):
9343+
9344+ with self.assertRaisesRegex(RuntimeError, "'int' object is not iterable"):
9345+ @torch.jit.script
9346+ def sum_list(a):
9347+ # type: (int) -> int
9348+ sum = 0
9349+ for i in a: # noqa: T484
9350+ sum += i
9351+
9352+ return sum
9353+
9354+ sum_list(1)
9355+
93569356 def test_list_iterables(self):
93579357 with self.assertRaisesRegex(RuntimeError, 'List of iterables is not supported currently'):
93589358 cu = torch.jit.CompilationUnit('''
0 commit comments