
Conversation

@Jiaming-Liu (Contributor) commented Mar 27, 2018

This might cause `list(params)` to return parameters in a random order. In that case, in `load_state_dict()`, the keys and values of `id_map` would not be matched correctly.
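
For context, the pairing that `load_state_dict()` relies on looks roughly like this (a paraphrased sketch on my part, not the exact PyTorch source; `saved_groups` are the param groups from the checkpoint, `groups` are the optimizer's current ones):

from itertools import chain

# Sketch of the id_map construction: saved parameter ids are zipped against
# the current parameters purely by iteration order.
def build_id_map(saved_groups, groups):
    return {old_id: p for old_id, p in
            zip(chain.from_iterable(g['params'] for g in saved_groups),
                chain.from_iterable(g['params'] for g in groups))}

If `params` came from a set, the right-hand side of that zip can iterate in a different order than when the checkpoint was written, so saved state ends up attached to the wrong tensors.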
@Jiaming-Liu (Contributor, Author) commented Mar 27, 2018

Just to make it easier to understand:

optim = torch.optim.Adam(set(p for p in model.parameters()))

should be avoided.

A reasonable use case:

biases = set(param for name, param in model.named_parameters() if 'bias' in name)
weights = [p for p in model.parameters() if p not in biases]  # `biases` is a set to make the `in` test fast
groups = [
    dict(params=weights, lr=0.1, weight_decay=5e-4),
    dict(params=biases, lr=0.2, weight_decay=0),  # a set is passed as `params`: its iteration order is not stable
]
optim = torch.optim.Adam(groups)

This might raise an error after `optim.load_state_dict()`, and it is very hard to debug.

Traceback (most recent call last):
  File "xxxxxxxxxxxxxx.py", line 29, in <module>
    optim.step()
  File "/xxxxxxxxxxxxxx/site-packages/torch/optim/adam.py", line 69, in step
    exp_avg.mul_(beta1).add_(1 - beta1, grad)
RuntimeError: inconsistent tensor size, expected xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
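
For reference, an order-preserving variant of the grouping above could look like this (my own sketch, not from this PR): keep the set only for fast membership tests and hand the optimizer plain lists, whose order follows `model.parameters()` deterministically.

import torch
import torch.nn as nn

model = nn.Linear(10, 2)  # placeholder module for the sketch; use your own network

bias_set = set(param for name, param in model.named_parameters() if 'bias' in name)
biases = [p for p in model.parameters() if p in bias_set]       # ordered list
weights = [p for p in model.parameters() if p not in bias_set]  # ordered list

groups = [
    dict(params=weights, lr=0.1, weight_decay=5e-4),
    dict(params=biases, lr=0.2, weight_decay=0),
]
optim = torch.optim.Adam(groups)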

@ezyang (Contributor) commented Mar 27, 2018

@pytorchbot test this please

Looks reasonable to me

@Jiaming-Liu (Contributor, Author) commented Mar 27, 2018

Maybe add some suggestions to the error message?

@apaszke (Contributor) commented Mar 27, 2018

I'd suggest this: "Optimizer parameters need to be organized in ordered collections, but the ordering of tensors in sets will change between runs. Please use a list instead."

Also, can you please add a warning to the optim docs that `params` needs to yield a deterministically ordered iterator? Thanks!
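
For reference, such a guard could look roughly like this (my sketch of the kind of check being discussed; the exact placement and wording in the merged code may differ):

def check_params_type(params):
    # Reject sets up front, since their iteration order is not stable across runs.
    if isinstance(params, set):
        raise TypeError("optimizer parameters need to be organized in ordered "
                        "collections, but the ordering of tensors in sets will "
                        "change between runs. Please use a list instead.")
    return params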

@apaszke (Contributor) commented Mar 27, 2018

@pytorchbot test this please

@Jiaming-Liu (Contributor, Author):

Warning added, but I haven't had time to compile and check it yet.

@Jiaming-Liu (Contributor, Author):

@pytorchbot test this please

@codinfox:

Nice work.

One small suggestion for the PyTorch team: I think it would be better if each parameter were assigned a unique identifier (based on its hierarchy in the graph). In the current design, `optimizer.load_state_dict()` assumes that the order of the stored state_dict matches the order of the parameters currently defined in the network. This is fragile and error-prone. I would prefer this function to be implemented the way `module.load_state_dict()` is, which does not rely on ordering. Is there a reason why PyTorch does not assign identifiers to parameters?
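
To illustrate the idea, optimizer state could be re-keyed by parameter name, which is order-independent (a hypothetical helper, not part of the PyTorch API; it assumes every parameter the optimizer holds appears in `model.named_parameters()`):

def optim_state_by_name(model, optim):
    # Map each parameter object to its name, then key the optimizer's
    # per-parameter state by that name instead of by position.
    names = {id(p): name for name, p in model.named_parameters()}
    return {names[id(p)]: state for p, state in optim.state.items()}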

@Jiaming-Liu (Contributor, Author):

@codinfox I agree with that, but it's hard to find anything other than the name; even hierarchy can make things opaque and fragile.

@codinfox:

@Jiaming-Liu Yeah, name is a good identifier.

@apaszke (Contributor) commented Mar 28, 2018

@codinfox We don't do that simply because there's no good way to generate identifiers deterministically. I guess we could extend the optimizer API to accept named lists of parameters, but we also need to keep the current API.

@apaszke (Contributor) commented Mar 28, 2018

@pytorchbot test this please

@soumith merged commit 31c0e23 into pytorch:master on Mar 28, 2018
@soumith (Contributor) commented Mar 28, 2018

Thanks, @Jiaming-Liu!

@Whu-wxy commented Oct 4, 2019

I ran into this problem when loading state from a checkpoint. How can I solve it?

File "/xxxxxxxxxxxxxx/site-packages/torch/optim/adam.py", line 69, in step
    exp_avg.mul_(beta1).add_(1 - beta1, grad)
RuntimeError: The size of tensor a (256) must match the size of tensor b (1024) at non-singleton dimension 1
