26 changes: 25 additions & 1 deletion test/test_optim.py
@@ -133,7 +133,8 @@ def _test_state_dict(self, weight, bias, input, constructor):

         def fn_base(optimizer, weight, bias):
             optimizer.zero_grad()
-            loss = (weight.mv(input) + bias).pow(2).sum()
+            i = input_cuda if weight.is_cuda else input
+            loss = (weight.mv(i) + bias).pow(2).sum()
             loss.backward()
             return loss

@@ -161,6 +162,29 @@ def fn_base(optimizer, weight, bias):
         # Make sure state dict wasn't modified
         self.assertEqual(state_dict, state_dict_c)
 
+        # Check that state dict can be loaded even when we cast parameters
+        # to a different type and move to a different device.
+        if not torch.cuda.is_available():
+            return
+
+        input_cuda = Variable(input.data.float().cuda())
+        weight_cuda = Variable(weight.data.float().cuda(), requires_grad=True)
+        bias_cuda = Variable(bias.data.float().cuda(), requires_grad=True)
+        optimizer_cuda = constructor(weight_cuda, bias_cuda)
+        fn_cuda = functools.partial(fn_base, optimizer_cuda, weight_cuda, bias_cuda)
+
+        state_dict = deepcopy(optimizer.state_dict())
+        state_dict_c = deepcopy(optimizer.state_dict())
+        optimizer_cuda.load_state_dict(state_dict_c)
+        # Make sure state dict wasn't modified
+        self.assertEqual(state_dict, state_dict_c)
+
+        for i in range(20):
+            optimizer.step(fn)
+            optimizer_cuda.step(fn_cuda)
+            self.assertEqual(weight, weight_cuda)
+            self.assertEqual(bias, bias_cuda)
+
     def _test_basic_cases(self, constructor, ignore_multidevice=False):
         self._test_state_dict(
             torch.randn(10, 5),
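For orientation, here is a minimal sketch (not part of the diff) of the user-facing pattern this new test exercises: checkpointing an optimizer trained on CPU and restoring it into an optimizer built over CUDA parameters. The shapes and hyperparameters below are invented for illustration and mirror the test's fn_base closure.

# Sketch only (not part of the PR): restoring CPU optimizer state into a CUDA optimizer.
import torch
from torch.autograd import Variable
from torch import optim

weight = Variable(torch.randn(10, 5), requires_grad=True)
bias = Variable(torch.randn(10), requires_grad=True)
input = Variable(torch.randn(5))
optimizer = optim.SGD([weight, bias], lr=1e-3, momentum=0.9)

def closure():
    optimizer.zero_grad()
    loss = (weight.mv(input) + bias).pow(2).sum()
    loss.backward()
    return loss

optimizer.step(closure)  # populates momentum buffers (CPU tensors) in optimizer.state

if torch.cuda.is_available():
    weight_cuda = Variable(weight.data.float().cuda(), requires_grad=True)
    bias_cuda = Variable(bias.data.float().cuda(), requires_grad=True)
    optimizer_cuda = optim.SGD([weight_cuda, bias_cuda], lr=1e-3, momentum=0.9)
    # With this change, load_state_dict casts the saved buffers to the type and
    # device of the new parameters, so the CUDA optimizer resumes from the CPU checkpoint.
    optimizer_cuda.load_state_dict(optimizer.state_dict())
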
4 changes: 3 additions & 1 deletion torch/optim/lbfgs.py
@@ -93,7 +93,9 @@ def step(self, closure):
         line_search_fn = group['line_search_fn']
         history_size = group['history_size']
 
-        state = self.state['global_state']
+        # NOTE: LBFGS has only global state, but we register it as state for
+        # the first param, because this helps with casting in load_state_dict
+        state = self.state[self._params[0]]
         state.setdefault('func_evals', 0)
         state.setdefault('n_iter', 0)

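Context for this lbfgs.py change (my reading of the NOTE above, not PR text): the shared LBFGS state used to live under the string key 'global_state', which load_state_dict cannot associate with any parameter and therefore cannot cast. Keying it under the first parameter lets the new cast path move the direction and history buffers along with that parameter. A rough sketch of the resulting keying, with a hypothetical closure:

# Sketch only: how the LBFGS state is keyed after this change.
import torch
from torch.autograd import Variable
from torch import optim

param = Variable(torch.randn(10, 5), requires_grad=True)
optimizer = optim.LBFGS([param], max_iter=4)

def closure():
    optimizer.zero_grad()
    loss = param.pow(2).sum()
    loss.backward()
    return loss

optimizer.step(closure)

# The sole state entry previously lived under the string key 'global_state';
# now it lives under `param` itself, so load_state_dict can find the matching
# parameter and cast the stored tensors to its type and device.
assert param in optimizer.state
assert 'global_state' not in optimizer.state
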
31 changes: 28 additions & 3 deletions torch/optim/optimizer.py
@@ -1,4 +1,4 @@
-from collections import defaultdict
+from collections import defaultdict, Iterable
 
 import torch
 from copy import deepcopy
@@ -96,8 +96,33 @@ def load_state_dict(self, state_dict):
         id_map = {old_id: p for old_id, p in
                   zip(chain(*(g['params'] for g in saved_groups)),
                       chain(*(g['params'] for g in groups)))}
-        state = defaultdict(
-            dict, {id_map.get(k, k): v for k, v in state_dict['state'].items()})
+
+        def cast(param, value):
+            """Make a deep copy of value, casting all tensors to device of param."""
+            if torch.is_tensor(value):
+                # Floating-point types are a bit special here. They are the only ones
+                # that are assumed to always match the type of params.
+                if any(tp in type(param.data).__name__ for tp in {'Half', 'Float', 'Double'}):
+                    value = value.type_as(param.data)
+                value = value.cuda(param.get_device()) if param.is_cuda else value.cpu()
+                return value
+            elif isinstance(value, dict):
+                return {k: cast(param, v) for k, v in value.items()}
+            elif isinstance(value, Iterable):
+                return type(value)(cast(param, v) for v in value)
+            else:
+                return value
+
+        # Copy state assigned to params (and cast tensors to appropriate types).
+        # State that is not assigned to params is copied as is (needed for
+        # backward compatibility).
+        state = defaultdict(dict)
+        for k, v in state_dict['state'].items():
+            if k in id_map:
+                param = id_map[k]
+                state[param] = cast(param, v)
+            else:
+                state[k] = v
 
         # Update parameter groups, setting their 'params' value
         def update_group(group, new_group):
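To make the behaviour of the new cast path concrete, a short sketch (assumptions: Adam as the optimizer, CPU only, invented shapes). Floating-point tensors in the loaded state are converted to the parameter's floating type and device, while non-tensor entries such as Adam's integer step counter fall through the final else branch unchanged.

# Sketch only: effect of cast() when loading state across dtypes.
import torch
from torch.autograd import Variable
from torch import optim

w = Variable(torch.randn(10, 5).double(), requires_grad=True)
optimizer = optim.Adam([w], lr=1e-3)
optimizer.zero_grad()
w.pow(2).sum().backward()
optimizer.step()  # exp_avg / exp_avg_sq buffers are DoubleTensors here

w_float = Variable(w.data.float(), requires_grad=True)
optimizer_float = optim.Adam([w_float], lr=1e-3)
optimizer_float.load_state_dict(optimizer.state_dict())

buf = optimizer_float.state[w_float]['exp_avg']
# The DoubleTensor buffer was cast to match the FloatTensor parameter ...
assert buf.type() == w_float.data.type()
# ... while the 'step' counter kept its value untouched.
assert optimizer_float.state[w_float]['step'] == 1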