33import torch .nn as nn
44from parameterized import parameterized
55from torch .testing ._internal .common_utils import run_tests
6- from torch_tensorrt .dynamo .conversion import UnsupportedOperatorException
7-
86from torch_tensorrt import Input
7+ from torch_tensorrt .dynamo .conversion import UnsupportedOperatorException
98
109from .harness import DispatchTestCase
1110
@@ -29,16 +28,32 @@ def forward(self, x):
2928 inputs = [torch .randn (1 , 2 , 3 )]
3029 self .run_test (Unsqueeze (dim ), inputs )
3130
32- # Testing with more than one dynamic dims results in following error:
33- # AssertionError: Currently we don't support unsqueeze with more than one dynamic dims.
34-
3531 @parameterized .expand (
3632 [
37- ("negative_dim_dynamic" , - 4 ),
38- ("positive_dim_dynamic" , 1 ),
33+ ("1_dynamic_shape_2d_-3" , - 3 , (2 , 5 ), (3 , 5 ), (4 , 5 )),
34+ ("1_dynamic_shape_2d_-2" , - 2 , (2 , 3 ), (2 , 4 ), (2 , 5 )),
35+ ("1_dynamic_shape_2d_-1" , - 1 , (2 , 3 ), (2 , 4 ), (2 , 5 )),
36+ ("1_dynamic_shape_2d_0" , 0 , (2 , 3 ), (2 , 4 ), (2 , 5 )),
37+ ("1_dynamic_shape_2d_1" , 1 , (2 , 3 ), (2 , 4 ), (2 , 5 )),
38+ ("1_dynamic_shape_2d_2" , 2 , (2 , 3 ), (2 , 4 ), (2 , 5 )),
39+ ("2_dynamic_shape_3d_-1" , - 1 , (2 , 2 , 3 ), (4 , 3 , 3 ), (5 , 5 , 3 )),
40+ ("2_dynamic_shape_3d_0" , 2 , (2 , 2 , 3 ), (4 , 3 , 3 ), (5 , 5 , 3 )),
41+ ("2_dynamic_shape_3d_1" , 1 , (2 , 2 , 3 ), (4 , 3 , 3 ), (5 , 6 , 3 )),
42+ ("2_dynamic_shape_3d_2" , 2 , (2 , 2 , 3 ), (4 , 3 , 3 ), (6 , 5 , 3 )),
43+ ("4_dynamic_shape_4d_-4" , - 4 , (1 , 2 , 3 , 4 ), (2 , 2 , 3 , 5 ), (3 , 3 , 5 , 5 )),
44+ ("4_dynamic_shape_4d_-3" , - 3 , (1 , 2 , 3 , 4 ), (2 , 2 , 3 , 5 ), (3 , 3 , 5 , 5 )),
45+ ("4_dynamic_shape_4d_-2" , - 2 , (1 , 2 , 3 , 4 ), (2 , 2 , 3 , 5 ), (4 , 3 , 5 , 6 )),
46+ ("4_dynamic_shape_4d_-1" , - 1 , (1 , 2 , 3 , 4 ), (2 , 2 , 3 , 5 ), (4 , 3 , 5 , 6 )),
47+ ("4_dynamic_shape_4d_0" , 0 , (1 , 2 , 3 , 4 ), (2 , 2 , 5 , 7 ), (2 , 3 , 6 , 8 )),
48+ ("4_dynamic_shape_4d_1" , 1 , (1 , 2 , 3 , 4 ), (2 , 2 , 3 , 5 ), (3 , 3 , 5 , 5 )),
49+ ("4_dynamic_shape_4d_2" , 2 , (1 , 2 , 3 , 4 ), (2 , 2 , 3 , 5 ), (3 , 3 , 5 , 5 )),
50+ ("4_dynamic_shape_4d_3" , 3 , (1 , 2 , 3 , 4 ), (2 , 2 , 3 , 5 ), (3 , 3 , 5 , 5 )),
51+ ("4_dynamic_shape_4d_4" , 4 , (1 , 2 , 3 , 4 ), (2 , 2 , 3 , 5 ), (3 , 3 , 5 , 5 )),
3952 ]
4053 )
41- def test_unsqueeze_with_dynamic_shape (self , _ , dim ):
54+ def test_unsqueeze_with_dynamic_shape (
55+ self , _ , dim , min_shape , opt_shape , max_shape
56+ ):
4257 class Unsqueeze (nn .Module ):
4358 def __init__ (self , dim ):
4459 super ().__init__ ()
@@ -49,9 +64,10 @@ def forward(self, x):
4964
5065 input_specs = [
5166 Input (
52- shape = (- 1 , 2 , 3 ),
5367 dtype = torch .float32 ,
54- shape_ranges = [((1 , 2 , 3 ), (2 , 2 , 3 ), (3 , 2 , 3 ))],
68+ min_shape = min_shape ,
69+ opt_shape = opt_shape ,
70+ max_shape = max_shape ,
5571 ),
5672 ]
5773 self .run_test_with_dynamic_shape (Unsqueeze (dim ), input_specs )
0 commit comments