225 changes: 173 additions & 52 deletions test/test_nn.py
@@ -1717,11 +1717,16 @@ def test_weight_norm(self):
m = torch.nn.utils.weight_norm(m, dim=None)
self.assertEqual(m(input), expected_output)

with self.assertRaisesRegex(RuntimeError, 'register two weight_norm hooks'):
m = torch.nn.utils.weight_norm(m)
m = torch.nn.utils.weight_norm(m)

def test_weight_norm_pickle(self):
m = torch.nn.utils.weight_norm(nn.Linear(5, 7))
m = pickle.loads(pickle.dumps(m))
self.assertIsInstance(m, nn.Linear)

@skipIfRocm
def test_spectral_norm(self):
input = torch.randn(3, 5)
m = nn.Linear(5, 7)
@@ -1734,8 +1739,9 @@ def test_spectral_norm(self):
# weight_u should be just a reused buffer
self.assertTrue(hasattr(m, 'weight_u'))
self.assertTrue('weight_u' in m._buffers)
self.assertTrue('weight' in m._buffers)
self.assertTrue('weight_v' in m._buffers)
# weight should be a plain attribute, not counted as a buffer or a param
self.assertFalse('weight' in m._buffers)
self.assertFalse('weight' in m._parameters)
# it should also be sharing storage as `weight_orig`
self.assertEqual(m.weight_orig.storage(), m.weight.storage())
@@ -1749,58 +1755,173 @@ def test_spectral_norm(self):
self.assertTrue(hasattr(m, 'weight'))
self.assertTrue('weight' in m._parameters)

@unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
def test_spectral_norm_dp(self):
for requires_grad in (True, False):
m = nn.Linear(5, 7).to(torch.device('cuda'))
m.weight.requires_grad_(requires_grad)
with self.assertRaisesRegex(RuntimeError, 'register two spectral_norm hooks'):
m = torch.nn.utils.spectral_norm(m)
dpm = torch.nn.DataParallel(m, [0, 1])
self.assertTrue(hasattr(m, 'weight_u'))
u0 = m.weight_u.clone()

# assert that u is updated
input = torch.randn(2, 5, device=torch.device('cuda'))
dpm(input)
self.assertNotEqual(u0, m.weight_u)

# test that eval works
dpm.eval()
eval_out0 = dpm(input)
self.assertEqual(eval_out0, dpm(input))

def test_spectral_norm_eval_remove(self):
inp = torch.randn(3, 5)
m = nn.Linear(5, 7)
m = torch.nn.utils.spectral_norm(m)
x0 = m(inp)
m.eval()
# test that eval mode, removing, and adding+removing don't change weight and output
x1 = m(inp)
x2 = m(inp)
self.assertEqual(x0, x1)
self.assertEqual(x0, x2)
# test that we can backward several times without running into problems
x1 = m(inp)
x1.sum().backward()
x1 = m(inp)
x1.sum().backward()
# test removing
m = torch.nn.utils.remove_spectral_norm(m)
x3 = m(inp)
self.assertEqual(x0, x3)
m = torch.nn.utils.spectral_norm(m)
m = torch.nn.utils.remove_spectral_norm(m)
x4 = m(inp)
self.assertEqual(x0, x4)
# check that removing after train doesn't change output
m.train()
m = torch.nn.utils.spectral_norm(m)
for i in range(5):
x0 = m(inp)
m = torch.nn.utils.remove_spectral_norm(m)
x1 = m(inp)
self.assertEqual(x0, x1)
m = torch.nn.utils.spectral_norm(m)

# test correctness in training/eval modes and cpu/multi-gpu settings
for apply_dp in (True, False):
if apply_dp:
if not TEST_MULTIGPU:
continue
device = torch.device('cuda:0')

def maybe_wrap(m):
return torch.nn.DataParallel(m, [0, 1])
else:
device = torch.device('cpu')

def maybe_wrap(m):
return m

for requires_grad in (True, False):
m = nn.Linear(3, 4).to(device)
m.weight.requires_grad_(requires_grad)
m = torch.nn.utils.spectral_norm(m)
wrapped_m = maybe_wrap(m)
self.assertTrue(hasattr(m, 'weight_u'))
u0 = m.weight_u.clone()
v0 = m.weight_v.clone()

# TEST TRAINING BEHAVIOR

# assert that u and v are updated
input = torch.randn(2, 3, device=device)
out = wrapped_m(input)
self.assertNotEqual(u0, m.weight_u)
self.assertNotEqual(v0, m.weight_v)

# assert that backprop reaches weight_orig
# can't use gradcheck because the function changes as we run
# forward passes through it in training mode
if requires_grad:
torch.autograd.grad(out.sum(), m.weight_orig)

# test backward works with multiple forwards
# it uses training mode so we need to reset the `u` and `v` vectors
# to the same values at the beginning for the finite difference test to pass
saved_u = m.weight_u.clone()
saved_v = m.weight_v.clone()

def fn(input):
m.weight_u.data.copy_(saved_u)
m.weight_v.data.copy_(saved_v)
out0 = wrapped_m(input)
out1 = wrapped_m(input)
return out0 + out1

torch.autograd.gradcheck(fn, (input.clone().requires_grad_(),))

# test removing
pre_remove_out = wrapped_m(input)
m = torch.nn.utils.remove_spectral_norm(m)
self.assertEqual(wrapped_m(input), pre_remove_out)

m = torch.nn.utils.spectral_norm(m)
for i in range(3):

pre_remove_out = wrapped_m(input)
m = torch.nn.utils.remove_spectral_norm(m)
self.assertEqual(wrapped_m(input), pre_remove_out)

# TEST EVAL BEHAVIOR

m = torch.nn.utils.spectral_norm(m)
wrapped_m(input)

last_train_out = wrapped_m(input)
last_train_u = m.weight_u.clone()
last_train_v = m.weight_v.clone()
wrapped_m.zero_grad()
wrapped_m.eval()

eval_out0 = wrapped_m(input)
# assert eval gives same result as last training iteration
self.assertEqual(eval_out0, last_train_out)
# assert that doing more iterations in eval doesn't change things
self.assertEqual(eval_out0, wrapped_m(input))
self.assertEqual(last_train_u, m.weight_u)
self.assertEqual(last_train_v, m.weight_v)

# test backward works with multiple forwards in mixed training
# and eval modes
# it uses training mode so we need to reset the `u` and `v` vectors
# to the same values at the beginning for the finite difference test to pass
saved_u = m.weight_u.clone()
saved_v = m.weight_v.clone()

def fn(input):
m.weight_u.data.copy_(saved_u)
m.weight_v.data.copy_(saved_v)
wrapped_m.train()
out0 = wrapped_m(input)
wrapped_m.eval()
out1 = wrapped_m(input)
wrapped_m.train()
out2 = wrapped_m(input)
wrapped_m.eval()
out3 = wrapped_m(input)
return out0 + out1 + out2 + out3

torch.autograd.gradcheck(fn, (input.clone().requires_grad_(),))

# assert that backprop reaches weight_orig in eval
if requires_grad:
def fn(weight):
return wrapped_m(input)

torch.autograd.gradcheck(fn, (m.weight_orig,))

def test_spectral_norm_load_state_dict(self):
inp = torch.randn(2, 3)
for activate_times in (0, 3):
# Test backward compatibility
# At version None -> 1: weight is no longer a buffer and the v vector becomes a buffer
m = nn.Linear(3, 5)
snm = torch.nn.utils.spectral_norm(m)
snm.train()
for _ in range(activate_times):
snm(inp)

# craft a version None state_dict
version_none_state_dict = deepcopy(snm.state_dict())
self.assertEqual({'weight_orig', 'bias', 'weight_u', 'weight_v'}, set(version_none_state_dict.keys()))
self.assertIn('spectral_norm', version_none_state_dict._metadata[''])
del version_none_state_dict._metadata['']['spectral_norm'] # remove metadata info
del version_none_state_dict['weight_v'] # remove v vector
version_none_state_dict['weight'] = snm.weight.detach().clone() # set W as a buffer

# normal state_dict
version_latest_state_dict = deepcopy(snm.state_dict())

snm.eval()
out0_eval = snm(inp)
snm.train()
out1_train = snm(inp)
out2_train = snm(inp)
snm.eval()
out3_eval = snm(inp)

snm.load_state_dict(version_none_state_dict)
if activate_times > 0:
# since loading a version None state dict assumes that the values in
# the state dict have gone through at least one forward, we only test
# for equivalence when activate_times > 0.
snm.eval()
self.assertEqual(out0_eval, snm(inp))
snm.train()
self.assertEqual(out1_train, snm(inp))
self.assertEqual(out2_train, snm(inp))
snm.eval()
self.assertEqual(out3_eval, snm(inp))

# Test normal loading
snm.load_state_dict(version_latest_state_dict)
snm.eval()
self.assertEqual(out0_eval, snm(inp))
snm.train()
self.assertEqual(out1_train, snm(inp))
self.assertEqual(out2_train, snm(inp))
snm.eval()
self.assertEqual(out3_eval, snm(inp))

def test_spectral_norm_dim(self):
inp = torch.randn(2, 3, 10, 12)
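For reference, a minimal sketch (not part of the diff) of the attribute layout the updated test_spectral_norm asserts, using the same public utilities exercised above:

import torch
import torch.nn as nn
from torch.nn.utils import spectral_norm, remove_spectral_norm

m = spectral_norm(nn.Linear(5, 7))
# the power-iteration vectors are registered as buffers
assert 'weight_u' in m._buffers and 'weight_v' in m._buffers
# weight is a plain attribute, neither a buffer nor a parameter
assert 'weight' not in m._buffers and 'weight' not in m._parameters
# the trainable parameter lives under weight_orig
assert 'weight_orig' in m._parameters

m(torch.randn(3, 5))          # a forward pass updates u and v in training mode
m = remove_spectral_norm(m)   # restores a plain weight parameter
assert 'weight' in m._parameters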
6 changes: 3 additions & 3 deletions torch/nn/modules/batchnorm.py
@@ -70,9 +70,9 @@ def extra_repr(self):
return '{num_features}, eps={eps}, momentum={momentum}, affine={affine}, ' \
'track_running_stats={track_running_stats}'.format(**self.__dict__)

def _load_from_state_dict(self, state_dict, prefix, metadata, strict,
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
version = metadata.get('version', None)
version = local_metadata.get('version', None)

if (version is None or version < 2) and self.track_running_stats:
# at version 2: added num_batches_tracked buffer
@@ -82,7 +82,7 @@ def _load_from_state_dict(self, state_dict, prefix, metadata, strict,
state_dict[num_batches_tracked_key] = torch.tensor(0, dtype=torch.long)

super(_BatchNorm, self)._load_from_state_dict(
state_dict, prefix, metadata, strict,
state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs)


6 changes: 3 additions & 3 deletions torch/nn/modules/instancenorm.py
@@ -11,9 +11,9 @@ def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=False,
def _check_input_dim(self, input):
raise NotImplementedError

def _load_from_state_dict(self, state_dict, prefix, metadata, strict,
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
version = metadata.get('version', None)
version = local_metadata.get('version', None)
# at version 1: removed running_mean and running_var when
# track_running_stats=False (default)
if version is None and not self.track_running_stats:
@@ -38,7 +38,7 @@ def _load_from_state_dict(self, state_dict, prefix, metadata, strict,
state_dict.pop(key)

super(_InstanceNorm, self)._load_from_state_dict(
state_dict, prefix, metadata, strict,
state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs)

def forward(self, input):
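Both normalization modules follow the same migration pattern keyed on the version stored in local_metadata. A minimal sketch of that pattern (not part of the diff; the Counter module and its 'calls' buffer are hypothetical, for illustration only):

import torch
import torch.nn as nn

class Counter(nn.Module):
    # bump _version whenever the buffer/parameter layout changes; it is
    # written into state_dict._metadata by Module.state_dict()
    _version = 2

    def __init__(self):
        super(Counter, self).__init__()
        # version 2 added this buffer
        self.register_buffer('calls', torch.tensor(0, dtype=torch.long))

    def forward(self, x):
        self.calls += 1
        return x

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        version = local_metadata.get('version', None)
        if version is None or version < 2:
            # old checkpoints have no 'calls' entry; add a default so that
            # strict loading does not report it as missing
            state_dict.setdefault(prefix + 'calls', torch.tensor(0, dtype=torch.long))
        super(Counter, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs)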
48 changes: 42 additions & 6 deletions torch/nn/modules/module.py
@@ -66,6 +66,8 @@ def __init__(self):
self._backward_hooks = OrderedDict()
self._forward_hooks = OrderedDict()
self._forward_pre_hooks = OrderedDict()
self._state_dict_hooks = OrderedDict()
self._load_state_dict_pre_hooks = OrderedDict()
self._modules = OrderedDict()
self.training = True

@@ -498,8 +500,13 @@ def __call__(self, *input, **kwargs):

def __setstate__(self, state):
self.__dict__.update(state)
# Support loading old checkpoints that don't have the following attrs:
if '_forward_pre_hooks' not in self.__dict__:
self._forward_pre_hooks = OrderedDict()
if '_state_dict_hooks' not in self.__dict__:
self._state_dict_hooks = OrderedDict()
if '_load_state_dict_pre_hooks' not in self.__dict__:
self._load_state_dict_pre_hooks = OrderedDict()

def __getattr__(self, name):
if '_parameters' in self.__dict__:
@@ -571,6 +578,17 @@ def __delattr__(self, name):
else:
object.__delattr__(self, name)

def _register_state_dict_hook(self, hook):
r"""These hooks will be called with arguments: `self`, `state_dict`,
`prefix`, `local_metadata`, after the `state_dict` of `self` is set.
Note that only parameters and buffers of `self` or its children are
guaranteed to exist in `state_dict`. The hooks may modify `state_dict`
in place or return a new one.
"""
handle = hooks.RemovableHandle(self._state_dict_hooks)
self._state_dict_hooks[handle.id] = hook
return handle

def state_dict(self, destination=None, prefix='', keep_vars=False):
r"""Returns a dictionary containing a whole state of the module.

@@ -590,7 +608,7 @@ def state_dict(self, destination=None, prefix='', keep_vars=False):
if destination is None:
destination = OrderedDict()
destination._metadata = OrderedDict()
destination._metadata[prefix[:-1]] = dict(version=self._version)
destination._metadata[prefix[:-1]] = local_metadata = dict(version=self._version)
for name, param in self._parameters.items():
if param is not None:
destination[prefix + name] = param if keep_vars else param.data
@@ -600,16 +618,31 @@ def state_dict(self, destination=None, prefix='', keep_vars=False):
for name, module in self._modules.items():
if module is not None:
module.state_dict(destination, prefix + name + '.', keep_vars=keep_vars)
for hook in self._state_dict_hooks.values():
hook_result = hook(self, destination, prefix, local_metadata)
if hook_result is not None:
destination = hook_result
return destination

def _load_from_state_dict(self, state_dict, prefix, metadata, strict, missing_keys, unexpected_keys, error_msgs):
def _register_load_state_dict_pre_hook(self, hook):
r"""These hooks will be called with arguments: `state_dict`, `prefix`,
`local_metadata`, `strict`, `missing_keys`, `unexpected_keys`,
`error_msgs`, before loading `state_dict` into `self`. These arguments
are exactly the same as those of `_load_from_state_dict`.
"""
handle = hooks.RemovableHandle(self._load_state_dict_pre_hooks)
self._load_state_dict_pre_hooks[handle.id] = hook
return handle

def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
r"""Copies parameters and buffers from :attr:`state_dict` into only
this module, but not its descendants. This is called on every submodule
in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this
module in input :attr:`state_dict` is provided as :attr`metadata`.
For state dicts without meta data, :attr`metadata` is empty.
module in input :attr:`state_dict` is provided as :attr:`local_metadata`.
For state dicts without metadata, :attr:`local_metadata` is empty.
Subclasses can achieve class-specific backward compatible loading using
the version number at `metadata.get("version", None)`.
the version number at `local_metadata.get("version", None)`.

.. note::
:attr:`state_dict` is not the same object as the input
@@ -621,7 +654,7 @@ def _load_from_state_dict(self, state_dict, prefix, metadata, strict, missing_ke
persistent buffers.
prefix (str): the prefix for parameters and buffers used in this
module
metadata (dict): a dict containing the metadata for this moodule.
local_metadata (dict): a dict containing the metadata for this module.
See
strict (bool): whether to strictly enforce that the keys in
:attr:`state_dict` with :attr:`prefix` match the names of
@@ -634,6 +667,9 @@ def _load_from_state_dict(self, state_dict, prefix, metadata, strict, missing_ke
list, and will be reported together in
:meth:`~torch.nn.Module.load_state_dict`
"""
for hook in self._load_state_dict_pre_hooks.values():
hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)

local_name_params = itertools.chain(self._parameters.items(), self._buffers.items())
local_state = {k: v.data for k, v in local_name_params if v is not None}

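A minimal sketch (not part of the diff) of the two new private hooks in use; the hook functions, the 'weight_copy' key, and the 'my_flag' metadata entry are hypothetical names for illustration:

import torch
import torch.nn as nn

lin = nn.Linear(3, 4)

def state_dict_hook(module, destination, prefix, local_metadata):
    # called after this module's parameters and buffers were written;
    # may edit destination in place (or return a replacement dict)
    destination[prefix + 'weight_copy'] = destination[prefix + 'weight'].clone()
    local_metadata['my_flag'] = 1

def load_pre_hook(state_dict, prefix, local_metadata, strict,
                  missing_keys, unexpected_keys, error_msgs):
    # called before loading into this module; drop the derived entry so that
    # strict loading does not flag it as unexpected
    state_dict.pop(prefix + 'weight_copy', None)

lin._register_state_dict_hook(state_dict_hook)
lin._register_load_state_dict_pre_hook(load_pre_hook)

sd = lin.state_dict()     # contains 'weight', 'bias' and the derived 'weight_copy'
lin.load_state_dict(sd)   # the pre-hook strips 'weight_copy' before strict checking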