29 changes: 29 additions & 0 deletions test/test_nn.py
@@ -165,6 +165,35 @@ def test_module_backcompat(self):
input = torch.randn(2, 3, dtype=torch.float)
self.assertEqual(m(input).size(), (2, 5))

def test_module_super_init(self):
class MyMixin:
def __init__(self, *a, **kw):
super().__init__(*a, **kw)
self.mixin_init = True

class MyModuleWithMixinBefore(MyMixin, nn.Module):
def __init__(self):
super().__init__()

class MyModuleWithMixinAfter(nn.Module, MyMixin):
def __init__(self):
super().__init__()

self.assertTrue(hasattr(MyModuleWithMixinBefore(), 'mixin_init'))
self.assertFalse(hasattr(MyModuleWithMixinAfter(), 'mixin_init'))

nn.Module.call_super_init = True
self.assertTrue(hasattr(MyModuleWithMixinBefore(), 'mixin_init'))
self.assertTrue(hasattr(MyModuleWithMixinAfter(), 'mixin_init'))
nn.Module.call_super_init = False

MyModuleWithMixinBefore.call_super_init = True
MyModuleWithMixinAfter.call_super_init = True
self.assertTrue(hasattr(MyModuleWithMixinBefore(), 'mixin_init'))
self.assertTrue(hasattr(MyModuleWithMixinAfter(), 'mixin_init'))
MyModuleWithMixinBefore.call_super_init = False
MyModuleWithMixinAfter.call_super_init = False

def test_share_memory(self):
class Net(nn.Module):
def __init__(self):
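Why the ordering of the base classes matters: with MyMixin listed before nn.Module, the MRO places MyMixin.__init__ ahead of Module.__init__, so the subclass's super().__init__() runs the mixin, which in turn calls through to Module.__init__. With nn.Module listed first, Module.__init__ runs first and, unless call_super_init is enabled, never calls super().__init__(), so the mixin is silently skipped. A minimal standalone sketch of the same behavior (the class names Tagged, MixinFirst, and ModuleFirst are illustrative, not part of the test):

import torch.nn as nn

class Tagged:
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)  # keep the cooperative chain alive
        self.tagged = True

class MixinFirst(Tagged, nn.Module):   # MRO: MixinFirst, Tagged, Module, object
    pass

class ModuleFirst(nn.Module, Tagged):  # MRO: ModuleFirst, Module, Tagged, object
    pass

print(hasattr(MixinFirst(), 'tagged'))   # True: Tagged.__init__ runs, then Module.__init__
print(hasattr(ModuleFirst(), 'tagged'))  # False: Module.__init__ stops the chain by default
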
15 changes: 14 additions & 1 deletion torch/nn/modules/module.py
@@ -432,13 +432,23 @@ def forward(self, x):
_state_dict_pre_hooks: Dict[int, Callable]
_load_state_dict_post_hooks: Dict[int, Callable]
_modules: Dict[str, Optional['Module']]
call_super_init: bool = False

def __init__(self) -> None:
def __init__(self, *args, **kwargs) -> None:
"""
Initializes internal Module state, shared by both nn.Module and ScriptModule.
"""
torch._C._log_api_usage_once("python.nn_module")

# Backward compatibility: no args used to be allowed when call_super_init=False
if self.call_super_init is False and bool(kwargs):
raise TypeError("{}.__init__() got an unexpected keyword argument '{}'"
"".format(type(self).__name__, next(iter(kwargs))))

if self.call_super_init is False and bool(args):
raise TypeError("{}.__init__() takes 1 positional argument but {} were"
" given".format(type(self).__name__, len(args) + 1))

"""
Calls super().__setattr__('a', a) instead of the typical self.a = a
to avoid Module.__setattr__ overhead. Module's __setattr__ has special
@@ -462,6 +472,9 @@ def __init__(self) -> None:
super().__setattr__('_load_state_dict_post_hooks', OrderedDict())
super().__setattr__('_modules', OrderedDict())

if self.call_super_init:
super(Module, self).__init__(*args, **kwargs)

forward: Callable[..., Any] = _forward_unimplemented

def register_buffer(self, name: str, tensor: Optional[Tensor], persistent: bool = True) -> None:
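For context, a hedged sketch of how downstream code might opt in so that constructor kwargs flow through to a cooperative mixin (the WithName and NamedLinear names are illustrative assumptions, not part of this PR):

import torch.nn as nn

class WithName:
    """Cooperative mixin whose __init__ consumes a keyword argument."""
    def __init__(self, *args, name="unnamed", **kwargs):
        super().__init__(*args, **kwargs)
        self.name = name

class NamedLinear(nn.Module, WithName):
    # Opting in makes Module.__init__ forward *args/**kwargs up the MRO,
    # so WithName.__init__ actually receives `name`.
    call_super_init = True

    def __init__(self, in_features, out_features, **kwargs):
        super().__init__(**kwargs)
        self.fc = nn.Linear(in_features, out_features)

m = NamedLinear(4, 2, name="probe")
assert m.name == "probe"

Without call_super_init = True, the same call would raise the new TypeError from Module.__init__, since the extra keyword argument would not be forwarded.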