68 changes: 66 additions & 2 deletions test/test_jit.py
@@ -13370,11 +13370,75 @@ def forward(self, x):
                x += self.sequential(x)
                return x

        self.checkModule(M(), (torch.randn(5, 5),))

    def test_attributes(self):
        untyped_values = (
            ('my_dict', {"I": "am", "a test": "test"}),
            ('my_float', 2.3),
            ('my_int', 99),
            ('my_bool', False),
            ('my_tuple', (1, 2, 3, 4)),
            ('my_list', [(1, 2), (3, 4)]),
            # ('my_tensor', torch.randn(2, 2)),
            ('my_int_list', [1, 2, 3, 4]),
            # ('my_tensor_list', [torch.ones(2, 2) + i for i in range(4)]),
            ('my_bool_list', [True, True, False, True]),
            ('my_float_list', [1., 2., 3., 4.]),
            ('my_str_list', ['hello', 'bye']),
        )
        typed_values = (
            ('my_empty_list', []),
            ('my_empty_dict', {}),
            ('my_none', None),
        )

        class M(torch.nn.Module):
            # TODO: re-enable this once this test is in a Python 3-only syntax
            # file
            # my_empty_list : List[int]
            # my_empty_dict : Dict[str, int]
            # my_none : Optional[int]

            def __init__(self):
                super(M, self).__init__()

            def forward(self, x):
                return (
                    self.my_dict,
                    self.my_float,
                    self.my_int,
                    self.my_bool,
                    # self.my_tensor,
                    self.my_int_list,
                    # self.my_tensor_list,
                    self.my_bool_list,
                    self.my_float_list,
                    self.my_str_list,
                    self.my_empty_list,
                    self.my_empty_dict,
                    self.my_none,
                )

        # TODO: as a followup, fix this test
        # We can't define class attributes like we should be doing:
        # class M(torch.nn.Module):
        #     my_empty_list : List[int]
        #     my_empty_dict : Dict[str, int]
        #     my_none : Optional[int]
        #     my_out_of_line_attribute: List[int] = [1, 2, 3]
        # since there's no string frontend for Python classes (so the `define`
        # trick doesn't work).
        M.__annotations__ = {
            'my_empty_list': List[int],
            'my_empty_dict': Dict[str, int],
            'my_none': Optional[int],
        }

        m = M()
        self.checkModule(m, (torch.randn(5, 5),))
        m.module_list.add_module('new', Inner())
        for name, value in untyped_values + typed_values:
            setattr(m, name, value)

        self.checkModule(m, (torch.randn(5, 5),))
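
For reference, a minimal sketch (not part of the diff) of the Python 3-only class-annotation style the TODOs refer to, assuming the recursive torch.jit.script entry point this PR builds on reads `__annotations__` the same way; `AnnotatedM` is a made-up name for illustration:

    from typing import Dict, List, Optional

    import torch

    class AnnotatedM(torch.nn.Module):
        # Class-level annotations replace the manual __annotations__ dict above
        my_empty_list: List[int]
        my_empty_dict: Dict[str, int]
        my_none: Optional[int]

        def forward(self, x):
            return (self.my_empty_list, self.my_empty_dict, self.my_none)

    m = AnnotatedM()
    m.my_empty_list = []
    m.my_empty_dict = {}
    m.my_none = None
    scripted = torch.jit.script(m)  # attribute types come from the annotations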


9 changes: 9 additions & 0 deletions torch/csrc/jit/init.cpp
@@ -338,6 +338,15 @@ void initJITBindings(PyObject* module) {
      .def(
          "_jit_set_inline_everything_mode",
          [](bool enabled) { script::getInlineEverythingMode() = enabled; })
      .def(
          "_jit_try_infer_type",
          [](py::object obj) -> TypePtr {
            auto match = tryToInferType(obj);
            if (match.type) {
              return *match.type;
            }
            return nullptr;
          })
      .def(
          "_jit_fuser_get_fused_kernel_code",
          [](Graph& g, std::vector<at::Tensor> inps) {
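
From Python, the new binding can be exercised directly; a quick sketch (the exact repr of the returned JIT type may differ, and uninferable values should come back as None per the nullptr return above):

    import torch

    # Inferable values return a JIT type; uninferable ones return None.
    print(torch._C._jit_try_infer_type(99))          # e.g. int
    print(torch._C._jit_try_infer_type([1., 2.]))    # e.g. List[float]
    print(torch._C._jit_try_infer_type(object()))    # None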
19 changes: 19 additions & 0 deletions torch/jit/__init__.py
@@ -1658,6 +1658,25 @@ def __init__(self, original, stubs):
                continue
            ScriptModule.__setattr__(self, name, getattr(original, name))

        # Copy annotations; pull types from `__annotations__`, or try to
        # infer the type if possible
        class_annotations = getattr(original, '__annotations__', {})
        for name in dir(original):

Contributor: Doesn't dir(value) get a ton of arbitrary Python implementation details? These are all nn.Modules; we could try registering a hook in __setattr__ or something else.

Contributor Author: As far as I know we don't really have any other choice, since things like __dict__ only have the current class's info (not anything from superclasses), plus not everything goes through __setattr__ (e.g. static members). This is mostly safe, since the only things it takes are types it can infer, so it would just pick up things like __module__ and __doc__.

Contributor: Looks like it would also pick up _version, inplace, and dump_patches.

Contributor: Why doesn't it work if you add it here?

    object.__setattr__(self, name, value)

Contributor Author: That'd be mostly fine, but if something is defined at the class level it doesn't go through __setattr__; it's just on the class type itself. We need to copy those too (including any up the inheritance chain), so it doesn't work generally.
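
A minimal illustration of the point being made here, using only plain Python (the class and attribute names below are made up for the example):

    class Base(object):
        dump_patches = False  # class-level: never passes through __setattr__

    class Child(Base):
        pass

    c = Child()
    c.my_int = 99  # instance-level: does pass through __setattr__

    print('my_int' in c.__dict__)        # True  (instance attribute)
    print('dump_patches' in c.__dict__)  # False (lives on Base, not the instance)
    print('dump_patches' in dir(c))      # True  (dir() walks the inheritance chain)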

            if name in ("training", "__dict__"):
                # TODO: removing this skip should let us remove the code to
                # add training as an attribute in python_sugared_value.cpp
                continue
            if hasattr(self, name):
                # Don't re-copy properties
                continue
            item = getattr(original, name)
            if name in class_annotations:
                the_type = torch.jit.annotations.ann_to_type(class_annotations[name])
            else:
                the_type = torch._C._jit_try_infer_type(item)
            if the_type is not None:
                self._c._register_attribute(name, the_type, item)

        # Copy overloads
        self.__dict__["_overloads"] = dict(getattr(original, "__overloads__", {}))

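Putting the new loop together, a self-contained sketch of the type-resolution order (annotation first, then inference), with a hypothetical try_infer_type standing in for the C++ binding and print standing in for _register_attribute:

    from typing import Dict, List, Optional

    def try_infer_type(value):
        # Hypothetical stand-in for torch._C._jit_try_infer_type.
        if isinstance(value, bool):  # check bool before int: bool subclasses int
            return 'bool'
        if isinstance(value, int):
            return 'int'
        if isinstance(value, float):
            return 'float'
        return None  # not inferable

    class Original(object):
        my_none: Optional[int] = None  # typed via a class-level annotation

    o = Original()
    o.my_int = 99  # untyped; the type must be inferred

    class_annotations = getattr(Original, '__annotations__', {})
    for name in dir(o):
        if name in ("training", "__dict__"):
            continue
        item = getattr(o, name)
        if name in class_annotations:
            the_type = class_annotations[name]  # annotation wins
        else:
            the_type = try_infer_type(item)
        if the_type is not None:
            print(name, the_type)  # the real code calls self._c._register_attribute

Dunders like __module__ and __doc__ still flow through the loop, which is exactly the trade-off discussed in the review thread above: they are only kept if their type is inferable.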