
Conversation

ssnl (Collaborator) commented Oct 30, 2018

Problems with SN and DP after #12671:

  1. In eval mode, weight_orig does not get the correct gradient (SpectralNorm in eval doesn't connect grad to weight_orig, #12737).

    Fix: keep the v vector around as a buffer and always calculate W = W_orig / (u @ W_orig @ v), even in eval.

  2. In training mode, the weight buffer of the parallelized module is never updated if someone touches weight_orig and/or weight and makes them no longer share storage, so in eval the weight used is wrong.

    Fix: make weight no longer a buffer and always calculate it as above.

  3. #12671 (Fix SpectralNorm with DataParallel) changed SN to update u in-place to make DP work correctly, but that breaks backward through two forwards (e.g., the common GAN loss D(real) - D(fake)), because the vectors needed to backprop the 1st forward are changed by the 2nd forward.

    Fix: this PR clones u and v before using them.

To maintain BC, I added a hook interface for producing and loading state_dicts. This is ugly, and we should really have a better interface for spectral_norm, but for the purpose of fixing this issue I am making this patch. Even with a better interface, a BC mechanism for loading legacy state_dicts would still be needed.

cc @t-vi @crcrpar
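
For concreteness, a minimal sketch of the recomputation described in the fixes above (an illustration using the module's `weight_orig`/`u`/`v` naming, not the exact code in this PR):

```python
import torch

def compute_weight(weight_orig, u, v):
    # Clone u and v so a second forward (e.g. D(real) - D(fake)) does not
    # overwrite the tensors saved for the first forward's backward (fix 3).
    u, v = u.clone(), v.clone()
    weight_mat = weight_orig.flatten(1)              # view the weight as a 2-D matrix
    sigma = torch.dot(u, torch.mv(weight_mat, v))    # sigma = u^T @ W_orig @ v
    # weight is no longer a buffer; it is recomputed from weight_orig on
    # every call, in both train and eval mode (fixes 1 and 2).
    return weight_orig / sigma
```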

ssnl (Collaborator, Author) commented Oct 31, 2018

@YaoshengFu This PR will fix the spectral norm bug you see.

ssnl changed the title from "Fix spectral norm with data parallel" to "Fix more spectral norm bugs" on Oct 31, 2018.


crcrpar (Collaborator) commented Nov 1, 2018

I really always appreciate you guys.

I have one question as to my understanding.

u and v can/should be updated in an in-place manner, while weight_orig should not.
This is why weight_orig is registered as an attribute.

Is this correct?

ssnl (Collaborator, Author) commented Nov 1, 2018

@crcrpar Almost! weight_orig is still a parameter, so updates to it should be done on the parallelized module (the one you pass to DataParallel) via things like optimizers explicitly run by the user. Even if it is not updated in-place, that is fine, because the next time DataParallel copies it to the devices, the new weight_orig will be taken from the parallelized module and broadcast over.

u and v, however, are a bit different, because they are updated when the module runs its forward pass, not under the user's control. So it is our job to update them automatically, and to make such updates work with DataParallel, they need to be done in-place.
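
As an illustration of that last point, a rough sketch of what an in-place power-iteration update of `u` and `v` might look like (assumed helper name and shapes; not this PR's exact implementation):

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def update_u_v_(weight_mat, u, v, eps=1e-12):
    # copy_ writes into the existing buffers, so the new values live in the
    # very tensors that DataParallel broadcast to this replica.
    v.copy_(F.normalize(torch.mv(weight_mat.t(), u), dim=0, eps=eps))
    u.copy_(F.normalize(torch.mv(weight_mat, v), dim=0, eps=eps))
```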

colesbury (Member) left a comment

lgtm, but I'm not very familiar with the spectral norm code and feel a bit lost.

If @t-vi has time and is familiar with this, his review may be helpful.

ssnl (Collaborator, Author) commented Nov 1, 2018

@t-vi and @crcrpar It would be great if one or both of you could take a look at the SN changes and let me know if they look good :)


crcrpar (Collaborator) commented Nov 5, 2018

Sorry for my late response.

The changes are good, but I have some questions about the test code. Could you answer them when you have the time?

ssnl (Collaborator, Author) commented Nov 5, 2018

Thank you for your comments @crcrpar :)

zmurez commented Nov 5, 2018

I am still having a problem with spectral norm complaining about in-place changes when it is followed by batch normalization and distributed across 2 or more GPUs using DataParallel. This architecture arises in BigGAN (https://arxiv.org/abs/1809.11096). Interestingly, this is not an issue with group normalization, instance normalization, or no normalization. It is also not an issue on a single GPU.

t-vi (Collaborator) commented Nov 5, 2018

So to get this a bit sorted, how many cases do we have? with/without DP * eval/training * weight.requires_grad=True/False * ???.
I'd feel much more comfortable if I had a list of everything that is expected to work.

ssnl (Collaborator, Author) commented Nov 5, 2018

@zmurez even with this patch?

ssnl (Collaborator, Author) commented Nov 5, 2018

@t-vi Yeah, I think those 8 cases are all we need to consider.
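
For reference, the eight configurations under discussion, enumerated (just the combination count, not a test file):

```python
from itertools import product

cases = list(product(
    (False, True),        # wrapped in DataParallel?
    ("train", "eval"),    # module mode
    (False, True),        # weight.requires_grad
))
assert len(cases) == 8
```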

zmurez commented Nov 5, 2018

> @zmurez even with this patch?

I think so. It complained regardless of normalization and the number of GPUs prior to adding this patch. However, I just copied the relevant lines into my own spectral_norm.py file instead of pulling the entire branch... so it is possible I missed something... but it seems unlikely, since all the other cases are fixed.

Note, I am still using the latest stable release. To get this patch to work, I also grabbed a copy of the normalize function and implemented my own chain_matmul (a single if statement in this case, with 3 matrices).

Does this bug exist for you? If not I guess I will have to consider updating to the unstable version.

Thanks
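
For what it's worth, a three-matrix chain_matmul fallback like the one described above could look roughly like this (a sketch of the "single if statement" idea, assuming 2-D tensors; not the actual workaround code):

```python
import torch

def chain_matmul3(a, b, c):
    # Pick the association with fewer scalar multiplications:
    # (a @ b) @ c costs m*k*n + m*n*p; a @ (b @ c) costs k*n*p + m*k*p.
    m, k = a.shape
    n, p = b.shape[1], c.shape[1]
    return (a @ b) @ c if m * k * n + m * n * p <= k * n * p + m * k * p else a @ (b @ c)
```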

ssnl (Collaborator, Author) commented Nov 5, 2018

@zmurez It is possible that other changes are needed. Do you have a small repro script I can try on my build?

zmurez commented Nov 5, 2018

> @zmurez It is possible that other changes are needed. Do you have a small repro script I can try on my build?

```python
import torch
import torch.nn as nn

# from spectral_norm import spectral_norm     # my local spectral_norm patch
from torch.nn.utils import spectral_norm      # torch implementation

dim = 5
batchsize = 10
net = nn.DataParallel(nn.Sequential(
    # nn.Conv2d(dim, dim, 1),
    nn.BatchNorm2d(dim),
    # nn.LeakyReLU(0.2, inplace=True),
    spectral_norm(nn.Conv2d(dim, dim, 1)),
)).cuda()
noise = torch.randn(batchsize, dim, 1, 1).cuda()
net.zero_grad()
out = net(noise).sum()
out.backward()
print('no bug')
```

ssnl (Collaborator, Author) commented Nov 5, 2018

@zmurez Thanks! This is actually a bug elsewhere. I have submitted a fix at #13594. With these two patches together, I can run your script successfully.

t-vi (Collaborator) left a comment

Thanks for adding all the new tests! They seem to be comprehensive and I think the patch is good with them.


ssnl (Collaborator, Author) commented Nov 5, 2018

@t-vi @crcrpar Thank you for your reviews! I appreciate them. :)

facebook-github-bot (Contributor) left a comment

@ssnl is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

zmurez commented Nov 5, 2018

It seems that when spectral norm is applied to a conv module, u is randomly initialized, and then v is computed such that the invariant holds.

  1. This initial solve is very, very slow (over tens of minutes) for large networks.
    a) Is this really necessary, or is a random initialization OK?
    b) Can this solve be replaced by a sequence of power iterations?

  2. If the weights of the conv module are later initialized using nn.init, won't this invariant be broken?

ssnl (Collaborator, Author) commented Nov 5, 2018

@zmurez Good point. I'll think about ways to fix it.

t-vi (Collaborator) commented Nov 6, 2018

For the initialization, I think it should not be terribly important whether we init u or v randomly, so that could likely be changed. For loading the state dict: this only happens when you load an "old version" state_dict. How about warning about the performance cost and (potentially) offering a switch to re-compute u in the forward pass instead of solving for v? (I think it might be safer to "do the right, slow thing" by default rather than the other way round.)
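
A minimal sketch of randomly initializing both vectors instead of solving for v (illustrative function name; the normalization mirrors what the power iteration uses):

```python
import torch
import torch.nn.functional as F

def init_u_v(weight, eps=1e-12):
    # Random init for both u and v; a few power iterations during the first
    # forward passes make them consistent, so no expensive solve is needed.
    weight_mat = weight.flatten(1)
    h, w = weight_mat.shape
    u = F.normalize(weight.new_empty(h).normal_(0, 1), dim=0, eps=eps)
    v = F.normalize(weight.new_empty(w).normal_(0, 1), dim=0, eps=eps)
    return u, v
```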

crcrpar (Collaborator) commented Nov 6, 2018

I agree with t-vi. Random initialization of v is not harmful.

ssnl (Collaborator, Author) commented Nov 6, 2018

@t-vi @crcrpar @zmurez Thank you. I have removed the solving part at initialization in the new commit :)

facebook-github-bot (Contributor) left a comment

@ssnl has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

facebook-github-bot (Contributor) left a comment

@ssnl is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

ssnl deleted the sn_eval_dp branch on November 7, 2018, at 17:21.
facebook-github-bot pushed a commit that referenced this pull request Nov 8, 2018
Summary:
In `broadcast_coalesced`, since multiple variables can be "views" of a big flattened tensor, they can share the same version counter. However, this base flat tensor is not exposed and the views don't share any memory locations, so sharing a version counter is not necessary. Furthermore, it can cause problems, e.g., when two buffers are broadcast together in `DataParallel` and one of them is modified in-place during `forward` while the other is needed in backward: the autograd engine will complain.

Fixing the bug discovered at #13350 (comment)

edit: This is a very real problem. E.g., consider using Spectral Norm + Batch Norm together.
Pull Request resolved: #13594

Differential Revision: D12967311

Pulled By: SsnL

fbshipit-source-id: 52998dbabe149f575cf0fb79e7016f0b95e4b9e5
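
A small illustration of the shared-version-counter behavior the commit message describes (a sketch of the symptom, not of the fix):

```python
import torch

flat = torch.zeros(4)
a, b = flat[:2], flat[2:]     # two views of one flat tensor share its version counter
w = torch.ones(2, requires_grad=True)
out = (w * b).sum()           # b is saved for computing w's gradient
a.add_(1)                     # an in-place write to a bumps the counter that b shares
out.backward()                # RuntimeError: a saved variable was modified by an inplace operation
```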
facebook-github-bot pushed a commit that referenced this pull request Feb 27, 2019
Summary:
This was causing a problem with spectral norm, although SN won't use that anymore after #13350.
Pull Request resolved: #13352

Differential Revision: D14209562

Pulled By: ezyang

fbshipit-source-id: f5e3183e1e7050ac5a66d203de6f8cf56e775134
