Allow torch.cuda.amp.GradScaler to support sparse gradients #36786
Conversation
💊 CI failures summary and remediations (as of commit 0c7bc4f; more details on the Dr. CI page): 💚 Looks good so far! There are no failures yet. 💚 This comment was automatically generated by Dr. CI.
torch/cuda/amp/grad_scaler.py
Outdated
# coalesce() deduplicates indices and adds all values that have the same index.
# For scaled fp16 values, there's a good chance coalescing will cause overflow,
# so we should double check the coalesced _values().
torch._amp_non_finite_check_and_unscale_(g.coalesce()._values(),
should you just replace g with its coalesced version in this case, and not do the check twice?
I like that, but my original thinking was that I didn't want to replace param.grad with a new reference, because unscale_ advertises itself as in-place (last paragraph of the PR wall of text). If you think it's ok for unscale_ to replace sparse .grads, it's an easy change.
0c7bc4f replaces param.grad with the coalesced, unscaled version if param.grad was fp16 and uncoalesced.
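For readers skimming the thread, here is a minimal sketch of what that single-check variant looks like in spirit. The helper name, the argument plumbing, and the use of public ops (mul_, torch.isfinite) in place of torch._amp_non_finite_check_and_unscale_ are illustrative assumptions, not the actual diff:

```python
import torch

def coalesce_then_unscale_(grad, inv_scale, found_inf):
    # Sketch (names illustrative): an fp16, uncoalesced sparse grad is replaced
    # by its coalesced version up front, so only one unscale / non-finite check
    # is needed, and it runs on the summed (coalesced) values.
    if grad.is_sparse:
        if grad.dtype is torch.float16 and not grad.is_coalesced():
            grad = grad.coalesce()
        to_unscale = grad._values()
    else:
        to_unscale = grad

    # Public-op stand-in for the internal unscale + non-finite check.
    to_unscale.mul_(inv_scale)
    if not torch.isfinite(to_unscale).all():
        found_inf.fill_(1.0)

    return grad  # the caller assigns this back to param.grad
```

The trade-off is the one discussed above: param.grad may end up pointing at a new tensor, which bends unscale_'s in-place contract for this one case.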
facebook-github-bot
left a comment
@ngimel has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Should close #35810.
I decided to keep sparse handling on the Python side for clarity, although it could be moved to the C++ side (into _amp_non_finite_check_and_unscale_) without much trouble.

For non-fp16 sparse grads the logic is simple (call _amp_non_finite_check_and_unscale_ on grad._values() instead of grad itself). At least I hope it's that easy.

For fp16 sparse grads, it's trickier. Sparse tensors can be uncoalesced. From the Note:

An uncoalesced scaled fp16 grad may have values at duplicate coordinates that are all finite but large, such that adding them to make the coalesced version WOULD cause overflows.** If I checked _values() on the uncoalesced version, it might not report overflows, but I think it should.

So, if the grad is sparse, fp16, and uncoalesced, I still call _amp_non_finite_check_and_unscale_ to unscale grad._values() in-place, but I also double-check the coalesced version by calling a second _amp_non_finite_check_and_unscale_ on grad.coalesce()._values(). coalesce() is out-of-place, so this call doesn't redundantly affect grad._values(), but it does have the power to populate the same found_inf tensor. The is_coalesced() check and coalesce() probably aren't great for performance, but if someone needs a giant embedding table in FP16, they're better than nothing, and memory-wise they'll only create a copy of nnz gradient values + indices, which is still way better than changing the whole table to FP32.

An unscale variant with the liberty to create unscaled grads out-of-place, and replace param.grad instead of writing through it, could get away with just one _amp_non_finite_check_and_unscale_. It could say coalesced = grad.coalesce(), do only the stronger _amp_non_finite_check_and_unscale_ on coalesced._values(), and set param.grad = coalesced. I could even avoid replacing param.grad itself by going one level deeper and setting param.grad's indices and values to coalesced's, but that seems brittle and still isn't truly "in place".

** You could whiteboard an uncoalesced fp32 grad with the same property, but fp32's range is big enough that I don't think it's realistic.
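For concreteness, a rough sketch of the two-check path described above, with public ops (mul_, torch.isfinite) standing in for the internal _amp_non_finite_check_and_unscale_ kernel. The function name, the found_inf / inv_scale plumbing, and the ordering of the two checks are illustrative, not the actual grad_scaler.py internals:

```python
import torch

def unscale_sparse_grad_(grad, inv_scale, found_inf):
    # grad: a sparse (COO) gradient; inv_scale / found_inf: 1-element float tensors.

    # fp16 only: values at duplicate coordinates that are each finite may still
    # overflow when coalesce() sums them, so check a coalesced copy as well.
    # coalesce() is out-of-place, so grad._values() is untouched; the only side
    # effect of this branch is (possibly) setting found_inf.
    if grad.dtype is torch.float16 and not grad.is_coalesced():
        if not torch.isfinite(grad.coalesce()._values()).all():
            found_inf.fill_(1.0)

    # Unscale the (possibly uncoalesced) values in place, preserving unscale_'s
    # in-place contract: param.grad itself is never replaced.
    values = grad._values()
    values.mul_(inv_scale)
    if not torch.isfinite(values).all():
        found_inf.fill_(1.0)
```

The coalesced copy is discarded after its check, so the extra memory is bounded by the nnz values + indices mentioned above.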