Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 43 additions & 34 deletions test/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -646,44 +646,53 @@ def forward_hook(m, input, output):
self.assertEqual(input.grad, expected_grad)

def test_hook_buffer_registration(self):
def buffer_registration_hook(module, name, buffer):
buffer.registered = True
handle = torch.nn.modules.module.register_module_buffer_registration_hook(
buffer_registration_hook
)
try:
l, n, s = self._create_basic_net()
for b in s.buffers():
self.assertTrue(getattr(b, "registered", False))
finally:
handle.remove()
for return_buffer in (True, False):
def buffer_registration_hook(module, name, buffer):
buffer.registered = True
if return_buffer:
return buffer
handle = torch.nn.modules.module.register_module_buffer_registration_hook(
buffer_registration_hook
)
try:
l, n, s = self._create_basic_net()
for b in s.buffers():
self.assertTrue(getattr(b, "registered", False))
finally:
handle.remove()

def test_hook_submodule_registration(self):
def module_registration_hook(module, name, submodule):
module.registered = True
submodule.registered = True
handle = torch.nn.modules.module.register_module_module_registration_hook(
module_registration_hook
)
try:
l, n, s = self._create_basic_net()
for m in s.modules():
self.assertTrue(getattr(m, "registered", False))
finally:
handle.remove()
for return_submodule in (True, False):
def module_registration_hook(module, name, submodule):
module.registered = True
submodule.registered = True
if return_submodule:
return submodule
handle = torch.nn.modules.module.register_module_module_registration_hook(
module_registration_hook
)
try:
l, n, s = self._create_basic_net()
for m in s.modules():
self.assertTrue(getattr(m, "registered", False))
finally:
handle.remove()

def test_hook_parameter_registration(self):
def parameter_registration_hook(module, name, parameter):
parameter.registered = True
handle = torch.nn.modules.module.register_module_parameter_registration_hook(
parameter_registration_hook
)
try:
l, n, s = self._create_basic_net()
for p in s.parameters():
self.assertTrue(getattr(p, "registered", False))
finally:
handle.remove()
for return_parameter in (True, False):
def parameter_registration_hook(module, name, parameter):
parameter.registered = True
if return_parameter:
return parameter
handle = torch.nn.modules.module.register_module_parameter_registration_hook(
parameter_registration_hook
)
try:
l, n, s = self._create_basic_net()
for p in s.parameters():
self.assertTrue(getattr(p, "registered", False))
finally:
handle.remove()

def test_to(self):
m = nn.Linear(3, 5)
Expand Down
24 changes: 18 additions & 6 deletions torch/nn/modules/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,7 +500,9 @@ def register_buffer(self, name: str, tensor: Optional[Tensor], persistent: bool
.format(torch.typename(tensor), name))
else:
for hook in _global_buffer_registration_hooks.values():
tensor = hook(self, name, tensor) or tensor
output = hook(self, name, tensor)
if output is not None:
tensor = output
self._buffers[name] = tensor
if persistent:
self._non_persistent_buffers_set.discard(name)
Expand Down Expand Up @@ -548,7 +550,9 @@ def register_parameter(self, name: str, param: Optional[Parameter]) -> None:
"the forward() method.".format(name))
else:
for hook in _global_parameter_registration_hooks.values():
param = hook(self, name, param) or param
output = hook(self, name, param)
if output is not None:
param = output
self._parameters[name] = param

def add_module(self, name: str, module: Optional['Module']) -> None:
Expand All @@ -574,7 +578,9 @@ def add_module(self, name: str, module: Optional['Module']) -> None:
elif name == '':
raise KeyError("module name can't be empty string \"\"")
for hook in _global_module_registration_hooks.values():
module = hook(self, name, module) or module
output = hook(self, name, module)
if output is not None:
module = output
self._modules[name] = module

def register_module(self, name: str, module: Optional['Module']) -> None:
Expand Down Expand Up @@ -1468,15 +1474,19 @@ def remove_from(*dicts_or_sets):
"cannot assign module before Module.__init__() call")
remove_from(self.__dict__, self._parameters, self._buffers, self._non_persistent_buffers_set)
for hook in _global_module_registration_hooks.values():
value = hook(self, name, value) or value
output = hook(self, name, value)
if output is not None:
value = output
modules[name] = value
elif modules is not None and name in modules:
if value is not None:
raise TypeError("cannot assign '{}' as child module '{}' "
"(torch.nn.Module or None expected)"
.format(torch.typename(value), name))
for hook in _global_module_registration_hooks.values():
value = hook(self, name, value) or value
output = hook(self, name, value)
if output is not None:
value = output
modules[name] = value
else:
buffers = self.__dict__.get('_buffers')
Expand All @@ -1486,7 +1496,9 @@ def remove_from(*dicts_or_sets):
"(torch.Tensor or None expected)"
.format(torch.typename(value), name))
for hook in _global_buffer_registration_hooks.values():
value = hook(self, name, value) or value
output = hook(self, name, value)
if output is not None:
value = output
buffers[name] = value
else:
super().__setattr__(name, value)
Expand Down