
Commit aa0352c

yushangdi authored and pytorchmergebot committed
[custom ops] add default value support for device types (#129792)
Fixes #129371.

The first case in issue #129371 appears to already be supported by the current code, since it handles string default values. This PR adds support for device type default values.

Pull Request resolved: #129792
Approved by: https://github.com/zou3519
1 parent d7680a5 commit aa0352c
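As a user-level sketch of what this enables: a custom op whose signature has `torch.device` defaults can now have its schema inferred. The op name `mylib::to_device`, its body, and the use of the `torch.library.custom_op` decorator are illustrative assumptions; only the handling of device-typed defaults comes from this PR.

import torch
from torch import Tensor

# Hypothetical example op; the device-typed defaults are the point.
@torch.library.custom_op("mylib::to_device", mutates_args=())
def to_device(
    x: Tensor,
    device: torch.device = torch.device("cpu:0"),  # torch.device default
    fallback: torch.device = "cpu",                # string default on a device-typed arg
) -> Tensor:
    return x.to(device).clone()

# The inferred schema renders both defaults as quoted strings,
# e.g. ... Device device="cpu:0", Device fallback="cpu" ...
print(torch.ops.mylib.to_device.default._schema)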

File tree

2 files changed: +64 −3 lines changed


test/test_custom_ops.py

Lines changed: 63 additions & 2 deletions

@@ -678,6 +678,29 @@ def g(
             """(Tensor(a0!) x, Tensor(a1!)[] y, Tensor(a2!)[] z, Tensor(a3!)?[] w) -> ()""",
         )
 
+        def h(
+            x: Tensor,
+            a: Optional[int] = None,
+            b: float = 3.14,
+            c: bool = True,
+            d: int = 3,
+            e: str = "foo",
+            f: torch.dtype = torch.float,
+            g: torch.dtype = torch.float32,
+            h: torch.dtype = torch.int,
+            i: torch.device = torch.device("cpu:0"),
+            j: torch.device = "cpu",
+        ) -> None:
+            pass
+
+        self.assertExpectedInline(
+            infer_schema(h),
+            (
+                """(Tensor x, SymInt? a=None, float b=3.14, bool c=True, SymInt d=3, str e="foo", """
+                """ScalarType f=float32, ScalarType g=float32, ScalarType h=int32, Device i="cpu:0", Device j="cpu") -> ()"""
+            ),
+        )
+
     def test_infer_schema_unsupported(self):
         with self.assertRaisesRegex(ValueError, "varargs"):

@@ -2439,15 +2462,53 @@ def f(
             f: torch.dtype = torch.float,
             g: torch.dtype = torch.float32,
             h: torch.dtype = torch.int,
+            i: torch.device = torch.device("cpu:0"),
+            j: torch.device = "cpu",
         ) -> Tensor:
-            defaults.extend([a, b, c, d, e, f, g, h])
+            defaults.extend([a, b, c, d, e, f, g, h, i, j])
            return x.clone()
 
         x = torch.randn(3)
         f(x)
         self.assertEqual(
             defaults,
-            [None, 3.14, True, 3, "foo", torch.float, torch.float32, torch.int],
+            [
+                None,
+                3.14,
+                True,
+                3,
+                "foo",
+                torch.float,
+                torch.float32,
+                torch.int,
+                torch.device("cpu:0"),
+                "cpu",
+            ],
+        )
+        default_values = [
+            arg.default_value
+            for arg in torch.ops._torch_testing.f.default._schema.arguments
+        ]
+        # enum values taken from c10/core/ScalarType.h
+        type_enum = {
+            "float": 6,
+            "int": 3,
+        }
+        self.assertEqual(
+            default_values,
+            [
+                None,
+                None,
+                3.14,
+                True,
+                3,
+                "foo",
+                type_enum["float"],
+                type_enum["float"],
+                type_enum["int"],
+                torch.device("cpu:0"),
+                torch.device("cpu"),
+            ],
         )
 
     def test_mutated_error(self):

torch/_library/infer_schema.py

Lines changed: 1 addition & 1 deletion

@@ -100,7 +100,7 @@ def convert_type_string(annotation_type: str):
         default_repr = None
         if param.default is None or isinstance(param.default, (int, float, bool)):
             default_repr = str(param.default)
-        elif isinstance(param.default, str):
+        elif isinstance(param.default, (str, torch.device)):
            default_repr = f'"{param.default}"'
         elif isinstance(param.default, torch.dtype):
             dtype_repr = str(param.default)
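For illustration, the touched branch paraphrased as a standalone sketch: the helper name `render_default` is made up and the dtype branch is simplified; only the quoting of `torch.device` defaults mirrors the one-line change above.

import torch

def render_default(default):
    # torch.device defaults are now quoted the same way string defaults are,
    # so torch.device("cpu:0") shows up in the schema text as "cpu:0".
    if default is None or isinstance(default, (int, float, bool)):
        return str(default)
    elif isinstance(default, (str, torch.device)):
        return f'"{default}"'
    elif isinstance(default, torch.dtype):
        return str(default).removeprefix("torch.")  # simplified dtype handling
    raise ValueError(f"unsupported default: {default!r}")

print(render_default(torch.device("cpu:0")))  # "cpu:0"
print(render_default("cpu"))                  # "cpu"
print(render_default(torch.float32))          # float32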
