
Commit afad3e4

David Riazati authored and facebook-github-bot committed
Add support for class annotations (#21379)
Summary: This adds support for inferred attributes (everything except empty lists, dicts, and tuples), as well as PEP 526-style annotations on a class, which eliminates the need for `torch.jit.Attribute`.

Pull Request resolved: #21379
Differential Revision: D15718537
Pulled By: driazati
fbshipit-source-id: b7481ae3d7ee421613e931b7dc3427ef2a99757f
1 parent 85528fe commit afad3e4
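
Concretely, the summary describes moving from explicit `torch.jit.Attribute` wrappers to plain assignments whose types are inferred, with PEP 526 class annotations (Python 3.6+) covering only the cases inference cannot handle. A rough before/after sketch; the `OldStyle`/`NewStyle` modules and their attribute values are illustrative, not taken from this diff:

import torch
from typing import Dict, List

# Old style: non-Tensor attributes had to be wrapped explicitly.
class OldStyle(torch.jit.ScriptModule):
    def __init__(self):
        super(OldStyle, self).__init__()
        self.words = torch.jit.Attribute({"hello": 1}, Dict[str, int])
        self.names = torch.jit.Attribute([], List[str])

# New style: types are inferred from the assigned values; only empty
# containers (and None/Optional) still need a PEP 526 class annotation.
class NewStyle(torch.nn.Module):
    names: List[str]  # cannot be inferred from the empty list below

    def __init__(self):
        super(NewStyle, self).__init__()
        self.words = {"hello": 1}  # inferred as Dict[str, int]
        self.names = []

    def forward(self, x):
        return x + len(self.names)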

File tree

3 files changed (+94, -2 lines)

test/test_jit.py

Lines changed: 66 additions & 2 deletions
@@ -13377,11 +13377,75 @@ def forward(self, x):
                 x += self.sequential(x)
                 return x

+        self.checkModule(M(), (torch.randn(5, 5),))
+
+    def test_attributes(self):
+        untyped_values = (
+            ('my_dict', {"I": "am", "a test": "test"}),
+            ('my_float', 2.3),
+            ('my_int', 99),
+            ('my_bool', False),
+            ('my_tuple', (1, 2, 3, 4)),
+            ('my_list', [(1, 2), (3, 4)]),
+            # ('my_tensor', torch.randn(2, 2)),
+            ('my_int_list', [1, 2, 3, 4]),
+            # ('my_tensor_list', [torch.ones(2, 2) + i for i in range(4)]),
+            ('my_bool_list', [True, True, False, True]),
+            ('my_float_list', [1., 2., 3., 4.]),
+            ('my_str_list', ['hello', 'bye']),
+        )
+        typed_values = (
+            ('my_empty_list', []),
+            ('my_empty_dict', {}),
+            ('my_none', None),
+        )

+        class M(torch.nn.Module):
+            # TODO: re-enable this once this test is in a Python 3-only syntax
+            # file
+            # my_empty_list : List[int]
+            # my_empty_dict : Dict[str, int]
+            # my_none : Optional[int]
+
+            def __init__(self):
+                super(M, self).__init__()
+
+            def forward(self, x):
+                return (
+                    self.my_dict,
+                    self.my_float,
+                    self.my_int,
+                    self.my_bool,
+                    # self.my_tensor,
+                    self.my_int_list,
+                    # self.my_tensor_list,
+                    self.my_bool_list,
+                    self.my_float_list,
+                    self.my_str_list,
+                    self.my_empty_list,
+                    self.my_empty_dict,
+                    self.my_none,
+                )
+
+        # TODO: as a followup, fix this test
+        # We can't define class attributes like we should be doing:
+        #   class M(torch.nn.Module):
+        #       my_empty_list : List[int]
+        #       my_empty_dict : Dict[str, int]
+        #       my_none : Optional[int]
+        #       my_out_of_line_attribute: List[int] = [1, 2, 3]
+        # since there's no string frontend for Python classes (so the `define`)
+        # trick doesn't work.
+        M.__annotations__ = {
+            'my_empty_list': List[int],
+            'my_empty_dict': Dict[str, int],
+            'my_none': Optional[int],
+        }

         m = M()
-        self.checkModule(m, (torch.randn(5, 5),))
-        m.module_list.add_module('new', Inner())
+        for name, value in untyped_values + typed_values:
+            setattr(m, name, value)
+
         self.checkModule(m, (torch.randn(5, 5),))

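The commented-out annotations inside the test are only there because test_jit.py still has to parse under Python 2; the TODO in the test describes the intended Python 3 form, roughly as follows (a sketch of the test's own TODO, not code from this diff):

from typing import Dict, List, Optional
import torch

class M(torch.nn.Module):
    # Class-level PEP 526 annotations supply the types that cannot be
    # inferred from the assigned values (empty containers, None).
    my_empty_list: List[int]
    my_empty_dict: Dict[str, int]
    my_none: Optional[int]

    def forward(self, x):
        return (self.my_empty_list, self.my_empty_dict, self.my_none)

Until the test moves to a Python 3-only file, assigning `M.__annotations__` directly has the same effect, since class-level annotations are stored in that dict anyway.
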
torch/csrc/jit/init.cpp

Lines changed: 9 additions & 0 deletions
@@ -338,6 +338,15 @@ void initJITBindings(PyObject* module) {
       .def(
           "_jit_set_inline_everything_mode",
           [](bool enabled) { script::getInlineEverythingMode() = enabled; })
+      .def(
+          "_jit_try_infer_type",
+          [](py::object obj) -> TypePtr {
+            auto match = tryToInferType(obj);
+            if (match.type) {
+              return *match.type;
+            }
+            return nullptr;
+          })
       .def(
           "_jit_fuser_get_fused_kernel_code",
           [](Graph& g, std::vector<at::Tensor> inps) {
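
The new `_jit_try_infer_type` binding wraps `tryToInferType` and returns a JIT type when one can be inferred from a Python value, or None when it cannot. A minimal sketch of probing this internal, underscore-prefixed binding from Python; the comments state expectations based on the commit summary rather than verified output:

import torch

# A value with an unambiguous element type should yield a JIT type object.
print(torch._C._jit_try_infer_type([1, 2, 3]))  # expected: a List[int]-style type

# Per the summary, empty lists/dicts/tuples are ambiguous, so inference is
# expected to fail here and the binding should return None, forcing callers
# to fall back to an explicit annotation.
print(torch._C._jit_try_infer_type([]))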

torch/jit/__init__.py

Lines changed: 19 additions & 0 deletions
@@ -1658,6 +1658,25 @@ def __init__(self, original, stubs):
                    continue
                ScriptModule.__setattr__(self, name, getattr(original, name))

+            # Copy annotations, pull types from `__annotations__` or try to infer
+            # the type if possible
+            class_annotations = getattr(original, '__annotations__', {})
+            for name in dir(original):
+                if name in ("training", "__dict__"):
+                    # TODO: removing this skip should let us remove the code to add training as an
+                    # attribute in python_sugared_value.cpp
+                    continue
+                if hasattr(self, name):
+                    # Don't re-copy properties
+                    continue
+                item = getattr(original, name)
+                if name in class_annotations:
+                    the_type = torch.jit.annotations.ann_to_type(class_annotations[name])
+                else:
+                    the_type = torch._C._jit_try_infer_type(item)
+                if the_type is not None:
+                    self._c._register_attribute(name, the_type, item)
+
            # Copy overloads
            self.__dict__["_overloads"] = dict(getattr(original, "__overloads__", {}))
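
The lookup order in the new block is: an explicit entry in the class's `__annotations__` wins, otherwise the type is inferred from the attribute's current value, and attributes whose type cannot be determined are skipped. A hypothetical standalone helper mirroring that order (the name `resolve_attribute_type` is illustrative, not part of this change):

import torch

def resolve_attribute_type(original, name):
    # Mirrors the lookup in the __init__(self, original, stubs) shown above:
    # explicit class annotation first, then runtime inference; returns None
    # if neither produces a type.
    class_annotations = getattr(original, '__annotations__', {})
    if name in class_annotations:
        return torch.jit.annotations.ann_to_type(class_annotations[name])
    return torch._C._jit_try_infer_type(getattr(original, name))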
