Add max pooling support to EmbeddingBag #5725
Conversation
@pytorchbot test this please
@ezyang Oops. It looks like there were some compilation issues on some of the platforms (unfortunately, my local Linux CUDA 9 test environment doesn't catch everything). I believe I just fixed them. Would it be possible for you to trigger another test run? (I guess I'll try to trigger a test here, but I don't think I have permission. @pytorchbot test this please)
@pytorchbot test this please
@pytorchbot add to whitelist
It looks like this pull request passes all the tests. Let me know if you want me to make any changes!
@pytorchbot retest this please
@pytorchbot retest this please
@pytorchbot retest this please
@goldsborough Can you trigger another test run of this? I believe the latest CI changes caused a bunch of failures, and it would be nice to double check that this code works.
@lalaland can you rebase your commits on top of current master? That might help with the failures.
Force-pushed 284382a to 43df74f.
@apaszke Done.
@pytorchbot test this please
@pytorchbot retest this please
Force-pushed 43df74f to 48b0d07.
There was a merge conflict due to another pull request being merged, so I rebased the code and fixed the conflict.
Looks like the CI had some spurious failures. @pytorchbot retest this please
Another spurious failure? @pytorchbot retest this please
@pytorchbot retest this please
?? @pytorchbot retest this please
That's well within the margin of error of our floating point operations. That testing code should probably switch over to using a maximum relative error. Also, it's failing on the sparse operations, which I did not touch. @pytorchbot retest this please
@pytorchbot retest this please
@lalaland You can adjust the desired precision on the failing test. However, it is kind of weird that only the CUDA 8 test is failing, and not the CUDA 9 tests.
@ezyang It seems to be failing somewhat randomly. One issue here is that there is a lot of non-determinism in the order in which things get summed up. Note that it's failing in the sparse operations, which is not even code that I changed. I think the core issue here is that static epsilon values are a bad idea: larger floating point numbers need larger epsilons, so the epsilon should be measured relative to the magnitude of the values being compared. There are three main options here as I see them:
Which do you want me to do?
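To make the suggestion concrete, here is a minimal sketch of a relative-error comparison (the helper name and tolerance values are made up for illustration; this mirrors the standard `abs_tol + rel_tol * |expected|` bound, not code from this PR):

```python
import torch

def assert_close_rel(actual, expected, rel_tol=1e-5, abs_tol=1e-8):
    """Check |actual - expected| <= abs_tol + rel_tol * |expected|.

    The absolute term keeps the bound meaningful near zero, where a
    purely relative bound would demand near-exact equality.
    """
    diff = (actual - expected).abs()
    bound = abs_tol + rel_tol * expected.abs()
    assert bool((diff <= bound).all()), "tensors differ beyond tolerance"

# A fixed epsilon like 1e-5 rejects this pair even though the relative
# error is only one part in a million; a relative bound accepts it.
expected = torch.tensor([1e6])
actual = expected + 1.0
assert_close_rel(actual, expected)  # passes: relative error ~1e-6
```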
@ezyang I increased the epsilon for the embedding bag tests, and all the tests are now passing. Do let me know if you want me to change anything.
I just merged the branch and resolved some merge conflicts.
thanks for patiently waiting @lalaland. Now that the 0.4 release is done, we have more bandwidth. I'll review the PR tomorrow.
soumith left a comment:
The PR looks great, pretty well done. We really apologize for the late review.
I'm requesting minor changes in the naming of the CUDA kernels, and some other minor API usage changes; once they are pushed, this is good to merge.
@soumith Thanks for the review. I implemented your changes and updated this PR. (Well, all of them except the deterministic backward GPU pass; I'll do that in a separate PR.)
will merge once tests pass
@soumith Just wanted to give you a heads up that the tests have passed.
thanks @lalaland!
Commits:
* Add max mode support to EmbeddingBag
* Lint fix
* Fix compilation issue on other platforms
* Rebase + don't waste memory when not in max mode
* Oops, missed a spot
* Fix whitespace from merge
* Less precision
* Lower precision to avoid spurious failures
* Minor typo
* Switch to size()
This pull request adds max pooling support to EmbeddingBag. Max pooling is a very common way of aggregating embeddings, and it is quite useful to have it built into EmbeddingBag for both performance and ergonomic reasons.
This particular implementation of EmbeddingBag max pooling does not support sparse matrices or the scale_grad_by_freq feature. Those can be added in follow-up pull requests if necessary.
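For illustration, a small usage sketch of the new mode (the tensor values here are made up; `nn.EmbeddingBag` and its `mode` argument are the API this PR extends):

```python
import torch
import torch.nn as nn

# 10 embeddings of dimension 3, aggregated per bag with an
# element-wise max instead of the default sum/mean.
bag = nn.EmbeddingBag(num_embeddings=10, embedding_dim=3, mode='max')

indices = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])  # flat list of indices
offsets = torch.tensor([0, 4])  # bag 0 = indices[0:4], bag 1 = indices[4:8]

out = bag(indices, offsets)
print(out.shape)  # torch.Size([2, 3]): one max-pooled vector per bag
```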
This code has been tested using the test_embedding_bag and test_embedding_bag_cuda unit tests in test_nn.py.
This closes #4762.