Skip to content

Commit 34bc6f7

Browse files
committed
[Resubmit] state_dict_pre_hook
ghstack-source-id: 07f7959 Pull Request resolved: #90435
1 parent c92cf6b commit 34bc6f7

File tree

3 files changed

+75
-1
lines changed

3 files changed

+75
-1
lines changed

test/test_nn.py

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from itertools import product
1313
from functools import partial
1414
from collections import OrderedDict
15+
from tempfile import NamedTemporaryFile
1516

1617
import torch
1718

@@ -37,7 +38,7 @@
3738
download_file, get_function_arglist, load_tests, skipIfMps,\
3839
TEST_WITH_UBSAN, IS_PPC, \
3940
parametrize as parametrize_test, subtest, instantiate_parametrized_tests, \
40-
skipIfTorchDynamo
41+
skipIfTorchDynamo, IS_WINDOWS
4142
from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU, TEST_CUDNN, TEST_CUDNN_VERSION
4243
from torch.testing._internal.common_nn import NNTestCase, NewModuleTest, CriterionTest, \
4344
module_tests, criterion_tests, loss_reference_fns, _create_basic_net, \
@@ -2450,6 +2451,60 @@ def hook_fn(module, state_dict, prefix, local_metadata, strict, missing_keys, un
24502451
model[0][0]._register_load_state_dict_pre_hook(hook_fn, with_module=True)
24512452
model.load_state_dict(model.state_dict(), strict=True)
24522453

2454+
@unittest.skipIf(IS_WINDOWS, "Tempfile permission issue on windows")
2455+
def test_register_state_dict_pre_hook_backward_compat(self):
2456+
called = False
2457+
2458+
def my_state_dict_pre_hook(*args, **kwargs):
2459+
nonlocal called
2460+
called = True
2461+
2462+
m = nn.Linear(1, 1)
2463+
self.assertTrue(hasattr(m, '_state_dict_pre_hooks'))
2464+
delattr(m, '_state_dict_pre_hooks')
2465+
# Save and load, ensure we can still call state_dict
2466+
# without running into issues.
2467+
with NamedTemporaryFile() as f:
2468+
# Note that torch.save / torch.load is not recommended
2469+
# to save / load modules.
2470+
torch.save(m, f.name)
2471+
m = torch.load(f.name)
2472+
2473+
# Ensure we can run state_dict without issues
2474+
_ = m.state_dict()
2475+
self.assertFalse(called)
2476+
m.register_state_dict_pre_hook(my_state_dict_pre_hook)
2477+
_ = m.state_dict()
2478+
self.assertTrue(called)
2479+
2480+
def test_register_state_dict_pre_hook(self):
2481+
_state_dict_prefix = "foo."
2482+
state_dict_pre_hook_count = 0
2483+
2484+
class MyModule(torch.nn.Module):
2485+
def __init__(self):
2486+
super().__init__()
2487+
self.a = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3), nn.Linear(3, 3))
2488+
2489+
def forward(self, x):
2490+
return self.a(x)
2491+
2492+
def my_state_dict_pre_hook(module, prefix, keep_vars):
2493+
nonlocal keep_var_setting
2494+
self.assertEqual(keep_vars, keep_var_setting)
2495+
nonlocal state_dict_pre_hook_count
2496+
state_dict_pre_hook_count += 1
2497+
self.assertTrue(prefix.startswith(_state_dict_prefix))
2498+
2499+
mod = MyModule()
2500+
mod.register_state_dict_pre_hook(my_state_dict_pre_hook)
2501+
# Test to ensure submodules run the hook as well.
2502+
mod.a.register_state_dict_pre_hook(my_state_dict_pre_hook)
2503+
for keep_var_setting in [True, False]:
2504+
_ = mod.state_dict(prefix=_state_dict_prefix, keep_vars=keep_var_setting)
2505+
self.assertEqual(2, state_dict_pre_hook_count)
2506+
state_dict_pre_hook_count = 0
2507+
24532508
@skipIfTorchDynamo("TorchDynamo fails here for unknown reasons")
24542509
def test_load_state_dict_ref_cycle(self):
24552510
# load_state_dict shouldn't cause a reference cycle involving Tensors

torch/distributed/nn/api/remote_module.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,10 @@
6868
"_forward_pre_hooks",
6969
"_forward_pre_hooks_with_kwargs",
7070
"_state_dict_hooks",
71+
"_state_dict_pre_hooks",
7172
"_load_state_dict_pre_hooks",
7273
"_load_state_dict_post_hooks",
74+
"_state_dict_pre_hooks",
7375
"_modules",
7476
# The two attributes below are generated methods, not available at pickling time.
7577
"forward_async",

torch/nn/modules/module.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -429,6 +429,7 @@ def forward(self, x):
429429
_forward_pre_hooks_with_kwargs: Dict[int, bool]
430430
_state_dict_hooks: Dict[int, Callable]
431431
_load_state_dict_pre_hooks: Dict[int, Callable]
432+
_state_dict_pre_hooks: Dict[int, Callable]
432433
_load_state_dict_post_hooks: Dict[int, Callable]
433434
_modules: Dict[str, Optional['Module']]
434435

@@ -456,6 +457,7 @@ def __init__(self) -> None:
456457
super().__setattr__('_forward_pre_hooks', OrderedDict())
457458
super().__setattr__('_forward_pre_hooks_with_kwargs', OrderedDict())
458459
super().__setattr__('_state_dict_hooks', OrderedDict())
460+
super().__setattr__('_state_dict_pre_hooks', OrderedDict())
459461
super().__setattr__('_load_state_dict_pre_hooks', OrderedDict())
460462
super().__setattr__('_load_state_dict_post_hooks', OrderedDict())
461463
super().__setattr__('_modules', OrderedDict())
@@ -1560,6 +1562,8 @@ def __setstate__(self, state):
15601562
self._forward_hooks_with_kwargs = OrderedDict()
15611563
if '_state_dict_hooks' not in self.__dict__:
15621564
self._state_dict_hooks = OrderedDict()
1565+
if '_state_dict_pre_hooks' not in self.__dict__:
1566+
self._state_dict_pre_hooks = OrderedDict()
15631567
if '_load_state_dict_pre_hooks' not in self.__dict__:
15641568
self._load_state_dict_pre_hooks = OrderedDict()
15651569
if '_load_state_dict_post_hooks' not in self.__dict__:
@@ -1668,6 +1672,16 @@ def _register_state_dict_hook(self, hook):
16681672
self._state_dict_hooks[handle.id] = hook
16691673
return handle
16701674

1675+
def register_state_dict_pre_hook(self, hook):
1676+
r"""These hooks will be called with arguments: ``self``, ``prefix``,
1677+
and ``keep_vars`` before calling ``state_dict`` on ``self``. The registered
1678+
hooks can be used to perform pre-processing before the ``state_dict``
1679+
call is made.
1680+
"""
1681+
handle = hooks.RemovableHandle(self._state_dict_pre_hooks)
1682+
self._state_dict_pre_hooks[handle.id] = hook
1683+
return handle
1684+
16711685
def _save_to_state_dict(self, destination, prefix, keep_vars):
16721686
r"""Saves module state to `destination` dictionary, containing a state
16731687
of the module, but not its descendants. This is called on every
@@ -1681,6 +1695,9 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
16811695
prefix (str): the prefix for parameters and buffers used in this
16821696
module
16831697
"""
1698+
for hook in self._state_dict_pre_hooks.values():
1699+
hook(self, prefix, keep_vars)
1700+
16841701
for name, param in self._parameters.items():
16851702
if param is not None:
16861703
destination[prefix + name] = param if keep_vars else param.detach()

0 commit comments

Comments
 (0)