
Conversation

@janeyx99
Contributor

@janeyx99 janeyx99 commented Jan 20, 2023

Attempts to fix #92656

BC-breaking! This changes the default of zero_grad in optim and in nn so that grads are set to None instead of zero tensors. We are changing the default because there are proven perf wins and existing code has typically not regressed due to this change. (This note will probably need to be fleshed out more.)

Stack from ghstack (oldest at bottom):


BC-breaking note

Gradients are now set to None instead of zeros by default in torch.optim.*.zero_grad() and torch.nn.Module.zero_grad()

This changes the default behavior of zero_grad() so that it clears the grads by setting them to None instead of filling them with zero tensors. In other words, the set_to_none kwarg now defaults to True instead of False. Setting grads to None reduces peak memory usage and improves performance. This will break code that directly accesses grad data or does computation on the grads after calling zero_grad(), as they will now be None. To revert to the old behavior, call zero_grad(set_to_none=False).

1.13

>>> import torch
>>> from torch import nn
>>> module = nn.Linear(5, 5)
>>> i = torch.randn(2, 5, requires_grad=True)
>>> module(i).sum().backward()
>>> module.zero_grad()
>>> module.weight.grad == None
False
>>> module.weight.grad.data
tensor([[0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]])
>>> module.weight.grad + 1.0
tensor([[1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.]])

2.0

>>> import torch
>>> from torch import nn
>>> module = nn.Linear(5, 5)
>>> i = torch.randn(2, 5, requires_grad=True)
>>> module(i).sum().backward()
>>> module.zero_grad()
>>> module.weight.grad == None
True
>>> module.weight.grad.data
AttributeError: 'NoneType' object has no attribute 'data'
>>> module.weight.grad + 1.0
TypeError: unsupported operand type(s) for +: 'NoneType' and 'float'
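
For reference, here is a minimal migration sketch (not part of the PR itself) showing the two ways to cope with the new default: opting back into the old behavior with set_to_none=False, or keeping the new default and guarding any code that reads grads after zero_grad().

import torch
from torch import nn

module = nn.Linear(5, 5)
inp = torch.randn(2, 5, requires_grad=True)

# Option 1: keep the pre-2.0 behavior explicitly.
module(inp).sum().backward()
module.zero_grad(set_to_none=False)   # grads become zero tensors, not None
assert module.weight.grad is not None

# Option 2: keep the new default and guard against None grads.
module(inp).sum().backward()
module.zero_grad()                    # set_to_none=True is now the default
for p in module.parameters():
    g = p.grad if p.grad is not None else torch.zeros_like(p)
    # downstream code can use g safely here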

@pytorch-bot

pytorch-bot bot commented Jan 20, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/92731

Note: Links to docs will display an error until the docs builds have been completed.

⏳ No Failures, 1 Pending

As of commit b589af5:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the release notes: cpp release notes category label Jan 20, 2023
janeyx99 added a commit that referenced this pull request Jan 20, 2023
ghstack-source-id: 2ab94b1
Pull Request resolved: #92731
@janeyx99 janeyx99 added release notes: nn release notes category topic: bc breaking topic category and removed release notes: cpp release notes category labels Jan 20, 2023
Collaborator

@albanD albanD left a comment

Quite a few tests need updating. But YES!

@janeyx99 janeyx99 added the keep-going Don't stop on first failure, keep running tests until the end label Jan 24, 2023
janeyx99 added a commit that referenced this pull request Jan 24, 2023
ghstack-source-id: 20668e6
Pull Request resolved: #92731
@janeyx99 janeyx99 added ciflow/trunk Trigger trunk jobs on your pull request and removed keep-going Don't stop on first failure, keep running tests until the end labels Jan 24, 2023
janeyx99 added a commit that referenced this pull request Jan 25, 2023
ghstack-source-id: ccf55da
Pull Request resolved: #92731
@pytorch-bot pytorch-bot bot added the ciflow/mps Run MPS tests (subset of trunk) label Jan 25, 2023
@ngimel
Collaborator

ngimel commented Jan 26, 2023

Can you please add a bc-breaking note here?

Collaborator

@ngimel ngimel left a comment

LG, if tests pass

@janeyx99
Contributor Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@soumith
Contributor

soumith commented Feb 3, 2023

ooooh exciting. this is a big change :)

@FindHao
Member

FindHao commented Feb 9, 2023

This PR leads to widespread speedups for many models in torchbench: 22 models obtain over a 1.03x speedup on A100! Thanks for your work!
But I just found that yolov3 has about a 17% slowdown for training on A100. The nightly regression test of torchbench currently runs on T4, so it may have missed this case. I'm working on finding the root cause now.
Somehow, I can't reproduce the same results in the profiling trace generated by the PyTorch profiler.

@janeyx99
Contributor Author

@FindHao I recently came back from PTO, so sorry for the delayed response. Thanks for this callout! I'm curious about the yolov3 slowdown: have you been able to root-cause it thus far? The simple workaround is to pass set_to_none=False directly to regain the perf, but I would like to help figure out the cause here.

@FindHao
Member

FindHao commented Feb 23, 2023

Hi @janeyx99, we found it is caused by torch.cuda.empty_cache(), which takes longer than in the original version. Since torchbench only tests one training iteration and we don't need to empty the cache, we removed this call as a workaround. But we still don't know why the function takes longer. Do you have any ideas?

@janeyx99
Contributor Author

@FindHao Ah, I spoke with @albanD and this is not surprising. When set_to_none was False, the same grad tensor was allocated once and kept alive throughout the iterations (it would be filled with 0s, then filled with values, and so forth).

Now, because we set the grad to None, the number of allocations increases as we alternate between None -> real values -> None -> real values, and so forth. This is typically not a problem except in certain configurations (e.g., a particular batch size on a particular GPU) where the stars align just right and an allocation incurs actual communication with the GPU instead of being serviced from already-allocated memory.
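
As a rough illustration (a sketch of mine, not from this thread; it assumes a CUDA device is available), the caching allocator's counters show the extra allocation requests when grads flip between None and real tensors each iteration:

import torch
from torch import nn

def count_allocations(set_to_none, iters=10):
    # Counts allocation requests made to the CUDACachingAllocator
    # (not cudaMalloc calls) across a few training iterations.
    model = nn.Linear(1024, 1024, device="cuda")
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    x = torch.randn(64, 1024, device="cuda")
    torch.cuda.synchronize()
    start = torch.cuda.memory_stats()["allocation.all.allocated"]
    for _ in range(iters):
        opt.zero_grad(set_to_none=set_to_none)
        model(x).sum().backward()
        opt.step()
    torch.cuda.synchronize()
    return torch.cuda.memory_stats()["allocation.all.allocated"] - start

print("set_to_none=False:", count_allocations(False))
print("set_to_none=True: ", count_allocations(True))

Most of those extra requests are served from the cache, so they are normally cheap; the slowdown only shows up when the cache cannot service them and has to go back to the GPU (or when empty_cache() has more free blocks to return).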

@FindHao
Member

FindHao commented Feb 23, 2023

@janeyx99 Thanks for your explanation! It makes sense. I have another question. If I understand correctly, setting the grads to None marks the memory allocated to those tensors as to-be-freed, and that memory is deallocated from the memory pool by torch.cuda.empty_cache(), right? If so, does that mean memory usage would keep increasing until we call empty_cache()?

@ngimel
Collaborator

ngimel commented Feb 23, 2023

No, set_to_none=True decreases memory usage: it frees the grad memory when called and doesn't allocate it again until the grads are recomputed (which will likely be after the high memory watermark is reached). @robieta had plots showing how set_to_none decreases memory usage.

@janeyx99
Contributor Author

Haha, I will attempt to answer this question by setting down some terminology. PyTorch has a CUDACachingAllocator, which reserves and manages memory for the duration of a PyTorch program. In a sense, you can imagine that it reserves a chunk of memory from the GPU and sits on top of the actual GPU interface, so that every time the program releases or requests memory, we don't have to talk to the GPU. For example, if the PyTorch program releases memory, our caching allocator will hold onto it instead of immediately returning it to the GPU, so that later on, when the program wants memory again, it can lend that memory out. This saves time because communication with the GPU is avoided entirely.

Thus, we have the concepts of memory reserved and memory allocated. The memory reserved is the total memory managed by the CUDACachingAllocator, and the memory allocated is the memory taken up by actual PyTorch tensors. Setting grads to None here will "free" those tensors, so memory allocated goes down immediately, BUT memory reserved remains the same. Calling torch.cuda.empty_cache() releases the cached memory back to the GPU so that memory reserved approaches memory allocated.
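
A small sketch (mine, assuming a CUDA device) that makes the allocated-vs-reserved distinction concrete:

import torch
from torch import nn

model = nn.Linear(4096, 4096, device="cuda")
opt = torch.optim.SGD(model.parameters(), lr=0.1)
model(torch.randn(128, 4096, device="cuda")).sum().backward()

print("allocated:", torch.cuda.memory_allocated())  # bytes held by live tensors
print("reserved: ", torch.cuda.memory_reserved())   # bytes held by the caching allocator

opt.zero_grad(set_to_none=True)   # frees the grad tensors
print("allocated:", torch.cuda.memory_allocated())  # drops immediately
print("reserved: ", torch.cuda.memory_reserved())   # unchanged: the cache keeps the blocks

torch.cuda.empty_cache()          # return unused cached blocks to the GPU
print("reserved: ", torch.cuda.memory_reserved())   # now approaches allocated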

@FindHao
Member

FindHao commented Feb 23, 2023

@ngimel @janeyx99 Thanks for all your explanation!
