
Commit 82dd693

Dmytro Dzhulgakov authored and facebook-github-bot committed
Split nn.Module._save_to_state_dict to make it overridable (#21933)
Summary:

# Motivation

We allow overriding JIT module serialization with `__getstate__/__setstate__` in order to cover cases where parameters are not serializable. Use cases include MKLDNN integration (https://github.com/pytorch/pytorch/blob/a388c783505987363717bd4da0b166e8d1d7ccb9/torch/utils/mkldnn.py#L18-L26) and the fbgemm prepacked format integration for quantized tensors. However, many eager-mode scripts use the `torch.save(module.state_dict())` form of serialization. There are several ways to make it work:

* make packed_weight itself pickleable (e.g. by binding `__getstate__/__setstate__` at the C++ UDT level)
  * change: we'd need to allow module buffers to be of arbitrary, non-Tensor types
  * pro: no change to state_dict behavior
  * con: might not be directly inspectable by a user calling .state_dict(), especially if packed weights represent several tensors fused together
* make packed_weight a proper Tensor layout
  * pro: no change to state_dict or buffers behavior
  * con: adding new tensor layouts is pretty costly today
  * con: doesn't work if multiple tensors are packed into one interleaved representation
* *[this approach]* allow Modules to override state_dict and return regular tensors
  * pro: most flexible and hackable
  * pro: maintains the semantic meaning of state_dict as all the data necessary to represent the module's state
  * con: complicates state_dict logic
  * con: potential code duplication with `__getstate__/__setstate__`

Based on discussions with zdevito and gchanan we decided to pick the last approach. Rationale: this behavior is fully opt-in and will impact only the modules that need it. For those modules the requirement listed above won't hold, but we do preserve the requirement that all elements of state_dict are tensors. (https://fburl.com/qgybrug4 for internal discussion)

In the future we might also implement one of the other approaches above, but those are more involved.

Pull Request resolved: #21933

Differential Revision: D15937678

Pulled By: dzhulgakov

fbshipit-source-id: 3cb5d1a8304d04def7aabc0969d0a2e7be182367
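To illustrate the chosen approach, here is a minimal sketch of the kind of override this change enables. `PackedLinear`, its fake packing scheme, and the `'weight'` key are hypothetical and for illustration only; the real use cases are the MKLDNN and fbgemm prepacked formats mentioned above.

```python
import torch
import torch.nn as nn


class PackedLinear(nn.Module):
    """Hypothetical module whose weight lives in an opaque packed form."""

    def __init__(self, in_features, out_features):
        super(PackedLinear, self).__init__()
        # Stand-in for a prepacked representation (e.g. MKLDNN/fbgemm); here we
        # simply keep a transposed copy so the sketch runs with plain tensors.
        self.register_buffer('packed_weight',
                             torch.randn(out_features, in_features).t().contiguous())

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        # Export a regular, inspectable tensor instead of the packed buffer,
        # so torch.save(module.state_dict()) keeps working. keep_vars is
        # ignored in this sketch.
        destination[prefix + 'weight'] = self.packed_weight.t().contiguous()

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        # Re-pack the plain tensor on load; real code would also update the
        # missing/unexpected key lists instead of skipping the error handling.
        self.packed_weight.copy_(state_dict[prefix + 'weight'].t())


m = PackedLinear(4, 3)
sd = m.state_dict()        # OrderedDict([('weight', <tensor of shape (3, 4)>)])
m2 = PackedLinear(4, 3)
m2.load_state_dict(sd)     # round-trips through the regular-tensor form
```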
1 parent b2197ef commit 82dd693

File tree: 3 files changed, +61 −14 lines changed

test/test_nn.py

Lines changed: 38 additions & 6 deletions
```diff
@@ -4669,6 +4669,38 @@ def test_load_state_dict_ref_cycle(self):
 
         self.assertEqual(refcycles, 0)
 
+    def test_load_state_dict_custom(self):
+
+        class CustomState(nn.Module):
+            def __init__(self):
+                super(CustomState, self).__init__()
+                self.param = torch.nn.Parameter(torch.ones(1))
+                self.sub = torch.nn.Linear(5, 5)
+
+            def _save_to_state_dict(self, destination, prefix, keep_vars):
+                destination[prefix + "serialized"] = self.param.data + 1
+
+            def _load_from_state_dict(self, state_dict, prefix, local_metadata,
+                                       strict, missing_keys, unexpected_keys,
+                                       error_msgs):
+                # skip some of the error handling
+                self.param.data.copy_(state_dict[prefix + "serialized"] - 1)
+
+        # use sequential to verify nesting
+        m = nn.Sequential(CustomState())
+        m[0].param[0] = 10
+        m[0].sub.weight[0, 0] = 555
+        state_dict = m.state_dict()
+        self.assertEqual(state_dict["0.serialized"].item(), 11)
+        self.assertIn("0.sub.weight", state_dict)
+        self.assertNotIn("0.param", state_dict)
+        del m
+        mm = nn.Sequential(CustomState())
+        self.assertEqual(mm[0].param[0].item(), 1)
+        mm.load_state_dict(state_dict)
+        self.assertEqual(mm[0].param[0].item(), 10)
+        self.assertEqual(mm[0].sub.weight[0, 0].item(), 555)
+
     def test_parameter_assignment(self):
         l = nn.Linear(5, 5)
 
@@ -5985,11 +6017,11 @@ def test_transformer_args_check(self):
         wrong_d_model = 63
         wrong_nhead = 5
 
-        def test(encoder_input_shape, decoder_input_shape,
+        def test(encoder_input_shape, decoder_input_shape,
                  src_mask_len=None, tgt_mask_len=None, memory_mask_size=None):
             encoder_input = torch.randn(encoder_input_shape)
             decoder_input = torch.randn(decoder_input_shape)
-            model = getattr(nn, model_name)(d_model, nhead, num_encoder_layers,
+            model = getattr(nn, model_name)(d_model, nhead, num_encoder_layers,
                                             num_decoder_layers, dim_feedforward, dropout)
 
             if src_mask_len is not None:
@@ -6008,7 +6040,7 @@ def test(encoder_input_shape, decoder_input_shape,
                 memory_task = None
 
             with self.assertRaises(RuntimeError):
-                model(encoder_input, decoder_input,
+                model(encoder_input, decoder_input,
                       src_mask=src_mask, tgt_mask=tgt_mask, memory_mask=memory_task)
 
         correct_encoder_input_shape = (seq_len, bsz, d_model)
@@ -6043,7 +6075,7 @@ def update_shape(shape, dim, new_dim_size):
         encoder_input_shape = correct_encoder_input_shape
         decoder_input_shape = correct_decoder_input_shape
         with self.assertRaises(AssertionError):
-            model = getattr(nn, model_name)(d_model, wrong_nhead, num_encoder_layers,
+            model = getattr(nn, model_name)(d_model, wrong_nhead, num_encoder_layers,
                                             num_decoder_layers, dim_feedforward, dropout)
 
         # Incorrect src_mask
@@ -6062,8 +6094,8 @@ def update_shape(shape, dim, new_dim_size):
         encoder_input_shape = correct_encoder_input_shape
         decoder_input_shape = correct_decoder_input_shape
         wrong_tgt_mask_size = tgt_len + 1
-        test(encoder_input_shape, decoder_input_shape,
-             tgt_mask_len=wrong_tgt_mask_size,
+        test(encoder_input_shape, decoder_input_shape,
+             tgt_mask_len=wrong_tgt_mask_size,
              memory_mask_size=(wrong_tgt_mask_size, wrong_src_mask_size))
 
     def test_rnn_args_check(self):
```

torch/jit/__init__.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -1808,8 +1808,8 @@ def _get_methods(cls):
 _compiled_methods_whitelist = {
     'forward', 'register_buffer', 'register_parameter', 'add_module',
     '_apply', 'apply', 'cuda', 'cpu', 'to', 'type', 'float', 'double', 'half',
-    'state_dict', 'load_state_dict', '_load_from_state_dict',
-    '_named_members', 'parameters', 'named_parameters',
+    'state_dict', '_save_to_state_dict', 'load_state_dict',
+    '_load_from_state_dict', '_named_members', 'parameters', 'named_parameters',
     'buffers', 'named_buffers', 'children', 'named_children', 'modules',
     'named_modules', 'zero_grad', 'share_memory', '_get_name', 'extra_repr',
     '_slow_forward', '_tracing_name', 'eval', 'train',
```

torch/nn/modules/module.py

Lines changed: 21 additions & 6 deletions
```diff
@@ -635,6 +635,26 @@ def _register_state_dict_hook(self, hook):
         self._state_dict_hooks[handle.id] = hook
         return handle
 
+    def _save_to_state_dict(self, destination, prefix, keep_vars):
+        r"""Saves module state to `destination` dictionary, containing a state
+        of the module, but not its descendants. This is called on every
+        submodule in :meth:`~torch.nn.Module.state_dict`.
+
+        In rare cases, subclasses can achieve class-specific behavior by
+        overriding this method with custom logic.
+
+        Arguments:
+            destination (dict): a dict where state will be stored
+            prefix (str): the prefix for parameters and buffers used in this
+                module
+        """
+        for name, param in self._parameters.items():
+            if param is not None:
+                destination[prefix + name] = param if keep_vars else param.data
+        for name, buf in self._buffers.items():
+            if buf is not None:
+                destination[prefix + name] = buf if keep_vars else buf.data
+
     def state_dict(self, destination=None, prefix='', keep_vars=False):
         r"""Returns a dictionary containing a whole state of the module.
 
@@ -655,12 +675,7 @@ def state_dict(self, destination=None, prefix='', keep_vars=False):
             destination = OrderedDict()
             destination._metadata = OrderedDict()
         destination._metadata[prefix[:-1]] = local_metadata = dict(version=self._version)
-        for name, param in self._parameters.items():
-            if param is not None:
-                destination[prefix + name] = param if keep_vars else param.data
-        for name, buf in self._buffers.items():
-            if buf is not None:
-                destination[prefix + name] = buf if keep_vars else buf.data
+        self._save_to_state_dict(destination, prefix, keep_vars)
         for name, module in self._modules.items():
             if module is not None:
                 module.state_dict(destination, prefix + name + '.', keep_vars=keep_vars)
```
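As a usage note (a quick sketch, not part of this change): the split preserves the existing default behavior, so `state_dict()` still stores detached `.data` tensors unless `keep_vars=True` is passed; those values now simply flow through `_save_to_state_dict` for each module before `state_dict` recurses into children.

```python
import torch.nn as nn

lin = nn.Linear(3, 2)

# Default path: Module.state_dict() delegates to _save_to_state_dict, which
# stores param.data (detached) unless keep_vars=True is requested.
plain = lin.state_dict()
with_vars = lin.state_dict(keep_vars=True)

print(plain['weight'].requires_grad)       # False -- .data was stored
print(with_vars['weight'].requires_grad)   # True  -- the Parameter itself
```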
