
Conversation

@mcarilli
Collaborator

mcarilli commented Apr 17, 2020

Should close #35810.

I decided to keep sparse handling on the Python side for clarity, although it could be moved to the C++ side (into _amp_non_finite_check_and_unscale_) without much trouble.

For non-fp16 sparse grads the logic is simple: call _amp_non_finite_check_and_unscale_ on grad._values() instead of on grad itself. At least I hope it's that easy.

For fp16 sparse grads, it's trickier. Sparse tensors can be uncoalesced. From the Note:

Our sparse tensor format permits uncoalesced sparse tensors, where there may be duplicate coordinates in the indices; in this case, the interpretation is that the value at that index is the sum of all duplicate value entries.

An uncoalesced scaled fp16 grad may have values at duplicate coordinates that are all finite but large, such that adding them to make the coalesced version WOULD cause overflows.** If I checked _values() on the uncoalesced version, it might not report overflows, but I think it should.
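For concreteness, here's a minimal sketch of that failure mode (illustrative only, not from this PR; it assumes a CUDA device and uses made-up values near the fp16 max of ~65504):

import torch

# Two entries at the same coordinate, each finite in fp16, whose sum is not.
indices = torch.tensor([[0, 0], [1, 1]])  # the coordinate (0, 1), twice
values = torch.tensor([40000.0, 40000.0], dtype=torch.float16, device="cuda")
g = torch.sparse_coo_tensor(indices, values, (4, 4))

print(torch.isinf(g._values()).any())             # False: each uncoalesced value is finite
print(torch.isinf(g.coalesce()._values()).any())  # True: 40000 + 40000 = 80000 overflows fp16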

So, if the grad is sparse, fp16, and uncoalesced, I still call _amp_non_finite_check_and_unscale_ to unscale grad._values() in place, but I also double-check the coalesced version by calling a second _amp_non_finite_check_and_unscale_ on grad.coalesce()._values(). coalesce() is out-of-place, so this call doesn't redundantly affect grad._values(), but it does have the power to populate the same found_inf tensor. The is_coalesced() check and coalesce() probably aren't great for performance, but if someone needs a giant embedding table in FP16, they're better than nothing. Memory-wise, they only create a copy of the nnz gradient values and indices, which is still far better than converting the whole table to FP32.
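In rough pseudocode, that scheme looks something like the sketch below (illustrative only; the function and tensor names are made up, and public ops stand in for the fused _amp_non_finite_check_and_unscale_ kernel):

import torch

def unscale_sparse_grad_(grad, inv_scale, found_inf):
    # Unscale the (possibly uncoalesced) values in place and record non-finites.
    vals = grad._values()
    vals.mul_(inv_scale)
    if not torch.isfinite(vals).all():
        found_inf.fill_(1.0)
    if grad.dtype is torch.float16 and not grad.is_coalesced():
        # coalesce() is out-of-place, so grad._values() is untouched, but summing
        # duplicate fp16 entries can overflow; let that populate the same found_inf.
        if not torch.isfinite(grad.coalesce()._values()).all():
            found_inf.fill_(1.0)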

An unscale variant with the liberty to create unscaled grads out-of-place, and to replace param.grad instead of writing through it, could get away with just one _amp_non_finite_check_and_unscale_. It could do coalesced = grad.coalesce(), run only the stronger _amp_non_finite_check_and_unscale_ on coalesced._values(), and set param.grad = coalesced. I could even avoid replacing param.grad itself by going one level deeper and setting param.grad's indices and values to coalesced's, but that seems brittle and still isn't truly "in place".
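For illustration, a sketch of that variant under the same caveats (made-up names, public ops standing in for the private kernel):

import torch

def unscale_sparse_grad_out_of_place(param, inv_scale, found_inf):
    coalesced = param.grad.coalesce()  # duplicates are summed while still scaled
    vals = coalesced._values()
    vals.mul_(inv_scale)               # a single unscale-and-check pass
    if not torch.isfinite(vals).all():
        found_inf.fill_(1.0)
    param.grad = coalesced             # replaces param.grad rather than writing through it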

** You could whiteboard an uncoalesced fp32 grad with the same property, but fp32's range is big enough that I don't think it's realistic.

mcarilli requested a review from ngimel April 17, 2020 03:19
mcarilli changed the title from "Allow gradient scaling to support sparse gradients" to "Allow GradScaler to support sparse gradients" Apr 17, 2020
mcarilli changed the title from "Allow GradScaler to support sparse gradients" to "Allow torch.cuda.amp.GradScaler to support sparse gradients" Apr 17, 2020
@dr-ci

dr-ci bot commented Apr 17, 2020

💊 CI failures summary and remediations

As of commit 0c7bc4f (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚



ngimel added the triaged label (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) Apr 18, 2020
mruberry added the module: cuda label (Related to torch.cuda, and CUDA support in general) Apr 18, 2020
# coalesce() deduplicates indices and adds all values that have the same index.
# For scaled fp16 values, there's a good chance coalescing will cause overflow,
# so we should double check the coalesced _values().
torch._amp_non_finite_check_and_unscale_(g.coalesce()._values(),
Collaborator

should you just replace g with its coalesced version in this case, and not do the check twice?

Collaborator Author

mcarilli Jun 22, 2020

I like that, but my original thinking was that I don't want to replace param.grad with a new reference, because unscale_ advertises itself as in-place (last paragraph of the PR wall of text). If you think it's OK for unscale_ to replace sparse .grads, it's an easy change.

Collaborator Author

mcarilli Jun 23, 2020

0c7bc4f replaces param.grad with the coalesced, unscaled version if param.grad was fp16 and uncoalesced.

Contributor

facebook-github-bot left a comment

@ngimel has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot
Contributor

@ngimel merged this pull request in b4ccdef.


Labels

Merged
module: amp (automated mixed precision) - autocast
module: cuda - Related to torch.cuda, and CUDA support in general
open source
triaged - This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Lack of AMP support for sparse gradients

6 participants