Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def aten_ops_group_norm(
)


@dynamo_tensorrt_converter(torch.ops.aten.cat.default)
@dynamo_tensorrt_converter(torch.ops.aten.cat.default, supports_dynamic_shapes=True)
def aten_ops_cat(
ctx: ConversionContext,
target: Target,
Expand Down Expand Up @@ -1327,7 +1327,7 @@ def aten_ops_abs(
)


@dynamo_tensorrt_converter(torch.ops.aten.sin.default)
@dynamo_tensorrt_converter(torch.ops.aten.sin.default, supports_dynamic_shapes=True)
def aten_ops_sin(
ctx: ConversionContext,
target: Target,
Expand All @@ -1344,7 +1344,7 @@ def aten_ops_sin(
)


@dynamo_tensorrt_converter(torch.ops.aten.cos.default)
@dynamo_tensorrt_converter(torch.ops.aten.cos.default, supports_dynamic_shapes=True)
def aten_ops_cos(
ctx: ConversionContext,
target: Target,
Expand Down
50 changes: 42 additions & 8 deletions tests/py/dynamo/conversion/test_cat_aten.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,16 @@ def forward(self, x, y):

input_specs = [
Input(
shape=(16, -1, 3),
dtype=torch.float32,
shape_ranges=[((16, 2, 3), (16, 3, 3), (16, 32, 3))],
min_shape=(16, 2, 3),
opt_shape=(16, 3, 3),
max_shape=(16, 32, 3),
),
Input(
shape=(16, -1, 3),
dtype=torch.float32,
shape_ranges=[((16, 2, 3), (16, 16, 3), (16, 32, 3))],
min_shape=(16, 2, 3),
opt_shape=(16, 16, 3),
max_shape=(16, 32, 3),
),
]
self.run_test_with_dynamic_shape(
Expand All @@ -71,14 +73,46 @@ def forward(self, x, y):

input_specs = [
Input(
shape=(-1, 16, 3),
dtype=torch.float32,
shape_ranges=[((2, 16, 3), (3, 16, 3), (32, 16, 3))],
min_shape=(2, 16, 3),
opt_shape=(3, 16, 3),
max_shape=(32, 16, 3),
),
Input(
shape=(-1, 16, 3),
dtype=torch.float32,
shape_ranges=[((2, 16, 3), (3, 16, 3), (32, 16, 3))],
min_shape=(2, 16, 3),
opt_shape=(3, 16, 3),
max_shape=(32, 16, 3),
),
]
self.run_test_with_dynamic_shape(
Cat(),
input_specs,
)

@parameterized.expand(
[
("pos", 1),
("neg", -2),
]
)
def test_cat_dynamic_shape_dim(self, _, dim):
class Cat(nn.Module):
def forward(self, x, y):
return torch.ops.aten.cat.default((x, y), dim)

input_specs = [
Input(
dtype=torch.float32,
min_shape=(2, 1, 1),
opt_shape=(3, 1, 2),
max_shape=(4, 1, 3),
),
Input(
dtype=torch.float32,
min_shape=(2, 2, 1),
opt_shape=(3, 3, 2),
max_shape=(4, 4, 3),
),
]
self.run_test_with_dynamic_shape(
Expand Down
49 changes: 49 additions & 0 deletions tests/py/dynamo/conversion/test_cos_aten.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import torch.nn as nn
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input

from .harness import DispatchTestCase

Expand Down Expand Up @@ -44,6 +45,54 @@ def forward(self, input):
inputs,
)

@parameterized.expand(
[
(
"3d_dim_dtype_int32",
(3, 2, 1),
(3, 2, 3),
(3, 3, 4),
torch.int32,
torch.float32,
),
(
"2d_dim_dtype_float16",
(1, 1),
(2, 2),
(4, 4),
torch.float16,
torch.float16,
),
(
"3d_dim_dtype_float",
(1, 1, 1),
(1, 2, 3),
(3, 3, 3),
torch.float,
torch.float,
),
]
)
def test_dynamic_shape_cos(
self, _, min_shape, opt_shape, max_shape, type, output_type
):
class cos(nn.Module):
def forward(self, input):
return torch.ops.aten.cos.default(input)

input_specs = [
Input(
min_shape=min_shape,
opt_shape=opt_shape,
max_shape=max_shape,
dtype=type,
),
]

self.run_test_with_dynamic_shape(
cos(), input_specs, output_dtypes=[output_type]
)


if __name__ == "__main__":
run_tests()
49 changes: 49 additions & 0 deletions tests/py/dynamo/conversion/test_sin_aten.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import torch.nn as nn
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input

from .harness import DispatchTestCase

Expand Down Expand Up @@ -44,6 +45,54 @@ def forward(self, input):
inputs,
)

@parameterized.expand(
[
(
"3d_dim_dtype_int32",
(3, 2, 1),
(3, 2, 3),
(3, 3, 4),
torch.int32,
torch.float32,
),
(
"2d_dim_dtype_float16",
(1, 1),
(2, 2),
(4, 4),
torch.float16,
torch.float16,
),
(
"3d_dim_dtype_float",
(1, 1, 1),
(1, 2, 3),
(3, 3, 3),
torch.float,
torch.float,
),
]
)
def test_dynamic_shape_sin(
self, _, min_shape, opt_shape, max_shape, type, output_type
):
class sin(nn.Module):
def forward(self, input):
return torch.ops.aten.sin.default(input)

input_specs = [
Input(
min_shape=min_shape,
opt_shape=opt_shape,
max_shape=max_shape,
dtype=type,
),
]

self.run_test_with_dynamic_shape(
sin(), input_specs, output_dtypes=[output_type]
)


if __name__ == "__main__":
run_tests()