44 changes: 38 additions & 6 deletions test/test_nn.py
@@ -4669,6 +4669,38 @@ def test_load_state_dict_ref_cycle(self):

self.assertEqual(refcycles, 0)

def test_load_state_dict_custom(self):

class CustomState(nn.Module):
def __init__(self):
super(CustomState, self).__init__()
self.param = torch.nn.Parameter(torch.ones(1))
self.sub = torch.nn.Linear(5, 5)

def _save_to_state_dict(self, destination, prefix, keep_vars):
destination[prefix + "serialized"] = self.param.data + 1

def _load_from_state_dict(self, state_dict, prefix, local_metadata,
Contributor: Do we really want to be overriding this method directly?

Collaborator (author): It seems to be symmetric to the _save_to_state_dict (mostly, except hooks). Better suggestions?

(A standalone sketch of this symmetry follows the test below.)

strict, missing_keys, unexpected_keys,
error_msgs):
# skip some of the error handling
self.param.data.copy_(state_dict[prefix + "serialized"] - 1)

# use sequential to verify nesting
m = nn.Sequential(CustomState())
m[0].param[0] = 10
m[0].sub.weight[0, 0] = 555
state_dict = m.state_dict()
self.assertEqual(state_dict["0.serialized"].item(), 11)
self.assertIn("0.sub.weight", state_dict)
self.assertNotIn("0.param", state_dict)
del m
mm = nn.Sequential(CustomState())
self.assertEqual(mm[0].param[0].item(), 1)
mm.load_state_dict(state_dict)
self.assertEqual(mm[0].param[0].item(), 10)
self.assertEqual(mm[0].sub.weight[0, 0].item(), 555)
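
Aside, not part of the diff: the symmetry the author describes is easiest to see as a standalone round trip. A minimal sketch with hypothetical names, where the checkpoint entry holds a transposed copy of the weight and loading undoes the transpose:

import torch
import torch.nn as nn

class TransposedCheckpoint(nn.Module):
    # Hypothetical module: the in-memory weight is (2, 3), but the serialized
    # entry "weight_t" holds its transpose; the two overrides mirror each other.
    def __init__(self):
        super(TransposedCheckpoint, self).__init__()
        self.weight = nn.Parameter(torch.zeros(2, 3))

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        # Forward transform: store the transpose under a custom key.
        destination[prefix + "weight_t"] = self.weight.data.t().contiguous()

    def _load_from_state_dict(self, state_dict, prefix, local_metadata,
                              strict, missing_keys, unexpected_keys, error_msgs):
        # Inverse transform: transpose back into the in-memory layout.
        self.weight.data.copy_(state_dict[prefix + "weight_t"].t())

src = TransposedCheckpoint()
src.weight.data.normal_()
dst = TransposedCheckpoint()
dst.load_state_dict(src.state_dict())  # the checkpoint only contains "weight_t"
assert torch.equal(dst.weight.data, src.weight.data)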

def test_parameter_assignment(self):
l = nn.Linear(5, 5)

@@ -5985,11 +6017,11 @@ def test_transformer_args_check(self):
wrong_d_model = 63
wrong_nhead = 5

def test(encoder_input_shape, decoder_input_shape,
src_mask_len=None, tgt_mask_len=None, memory_mask_size=None):
encoder_input = torch.randn(encoder_input_shape)
decoder_input = torch.randn(decoder_input_shape)
model = getattr(nn, model_name)(d_model, nhead, num_encoder_layers,
num_decoder_layers, dim_feedforward, dropout)

if src_mask_len is not None:
@@ -6008,7 +6040,7 @@ def test(encoder_input_shape, decoder_input_shape,
memory_task = None

with self.assertRaises(RuntimeError):
model(encoder_input, decoder_input,
src_mask=src_mask, tgt_mask=tgt_mask, memory_mask=memory_task)

correct_encoder_input_shape = (seq_len, bsz, d_model)
@@ -6043,7 +6075,7 @@ def update_shape(shape, dim, new_dim_size):
encoder_input_shape = correct_encoder_input_shape
decoder_input_shape = correct_decoder_input_shape
with self.assertRaises(AssertionError):
model = getattr(nn, model_name)(d_model, wrong_nhead, num_encoder_layers,
num_decoder_layers, dim_feedforward, dropout)

# Incorrect src_mask
@@ -6062,8 +6094,8 @@ def update_shape(shape, dim, new_dim_size):
encoder_input_shape = correct_encoder_input_shape
decoder_input_shape = correct_decoder_input_shape
wrong_tgt_mask_size = tgt_len + 1
test(encoder_input_shape, decoder_input_shape,
tgt_mask_len=wrong_tgt_mask_size,
memory_mask_size=(wrong_tgt_mask_size, wrong_src_mask_size))

def test_rnn_args_check(self):
4 changes: 2 additions & 2 deletions torch/jit/__init__.py
@@ -1808,8 +1808,8 @@ def _get_methods(cls):
_compiled_methods_whitelist = {
'forward', 'register_buffer', 'register_parameter', 'add_module',
'_apply', 'apply', 'cuda', 'cpu', 'to', 'type', 'float', 'double', 'half',
'state_dict', 'load_state_dict', '_load_from_state_dict',
'_named_members', 'parameters', 'named_parameters',
'state_dict', '_save_to_state_dict', 'load_state_dict',
'_load_from_state_dict', '_named_members', 'parameters', 'named_parameters',
'buffers', 'named_buffers', 'children', 'named_children', 'modules',
'named_modules', 'zero_grad', 'share_memory', '_get_name', 'extra_repr',
'_slow_forward', '_tracing_name', 'eval', 'train',
27 changes: 21 additions & 6 deletions torch/nn/modules/module.py
@@ -635,6 +635,26 @@ def _register_state_dict_hook(self, hook):
self._state_dict_hooks[handle.id] = hook
return handle

def _save_to_state_dict(self, destination, prefix, keep_vars):
r"""Saves module state to `destination` dictionary, containing a state
of the module, but not its descendants. This is called on every
submodule in :meth:`~torch.nn.Module.state_dict`.

In rare cases, subclasses can achieve class-specific behavior by
overriding this method with custom logic.
Contributor: It's probably better to give a specific example of when using this method is appropriate.

(An illustrative sketch appears after this file's diff, below.)

Arguments:
destination (dict): a dict where state will be stored
prefix (str): the prefix for parameters and buffers used in this
module
"""
for name, param in self._parameters.items():
if param is not None:
destination[prefix + name] = param if keep_vars else param.data
for name, buf in self._buffers.items():
if buf is not None:
destination[prefix + name] = buf if keep_vars else buf.data

def state_dict(self, destination=None, prefix='', keep_vars=False):
r"""Returns a dictionary containing a whole state of the module.

Expand All @@ -655,12 +675,7 @@ def state_dict(self, destination=None, prefix='', keep_vars=False):
destination = OrderedDict()
destination._metadata = OrderedDict()
destination._metadata[prefix[:-1]] = local_metadata = dict(version=self._version)
for name, param in self._parameters.items():
if param is not None:
destination[prefix + name] = param if keep_vars else param.data
for name, buf in self._buffers.items():
if buf is not None:
destination[prefix + name] = buf if keep_vars else buf.data
self._save_to_state_dict(destination, prefix, keep_vars)
for name, module in self._modules.items():
if module is not None:
module.state_dict(destination, prefix + name + '.', keep_vars=keep_vars)
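
Regarding the reviewer's request above for a specific example: one case where overriding both hooks is plausible is a module whose checkpoint layout differs from its in-memory layout. The following is an illustrative sketch under hypothetical names, not part of this PR: two in-memory parameters are serialized as a single fused tensor (say, to match an older checkpoint format), and the load override reports a missing entry instead of silently skipping it.

import torch
import torch.nn as nn

class FusedCheckpoint(nn.Module):
    # Hypothetical module: two parameters in memory, one fused tensor on disk.
    def __init__(self):
        super(FusedCheckpoint, self).__init__()
        self.w1 = nn.Parameter(torch.zeros(4))
        self.w2 = nn.Parameter(torch.zeros(4))

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        # Serialize the fused layout instead of the raw parameters.
        destination[prefix + "fused"] = torch.cat([self.w1.data, self.w2.data])

    def _load_from_state_dict(self, state_dict, prefix, local_metadata,
                              strict, missing_keys, unexpected_keys, error_msgs):
        key = prefix + "fused"
        if key not in state_dict:
            # Cooperate with strict loading by reporting the missing entry.
            missing_keys.append(key)
            return
        fused = state_dict[key]
        self.w1.data.copy_(fused[:4])
        self.w2.data.copy_(fused[4:])

m = FusedCheckpoint()
m.w1.data.fill_(1.0)
m.w2.data.fill_(2.0)
sd = m.state_dict()
assert set(sd.keys()) == {"fused"}  # neither "w1" nor "w2" is serialized
m2 = FusedCheckpoint()
m2.load_state_dict(sd)
assert torch.equal(m2.w1.data, m.w1.data) and torch.equal(m2.w2.data, m.w2.data)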