C++ APIs Transformer NN Module Top Layer #44333
Conversation
This is to provide C++ APIs for the Transformer NN module, which has been implemented in Python here: https://github.com/pytorch/pytorch/blob/eace0533985641d9c2f36e43e3de694aca886bd9/torch/nn/modules/transformer.py
It also adjusts the TransformerDecoderLayer's module registration order to 100% match the Python impl.
Differential Revision: [D23584010](https://our.internmc.facebook.com/intern/diff/D23584010)
  return output;
}

Tensor TransformerImpl::generate_square_subsequent_mask(int64_t sz) {
I'm not sure if we want to have the generate_square_subsequent_mask func within the transformer. Although the Python version has it, it's really not part of the paper "Attention Is All You Need".
@zhangguanheng66 do you know why generate_square_subsequent_mask is exposed as a part of the Python API?
@zhangguanheng66 if I could start from scratch, I would not include it in the transformer.
@zou3519 @zhangguanheng66
If we know someone already uses this function in Python, then I think we'd better keep it here; otherwise this would be an unexpected BC-breaking change for users.
It's not really BC-breaking if the functionality doesn't exist in C++ yet. It sounds like we might want to deprecate generate_square_subsequent_mask on the Python side?
@zhangguanheng66
Should I deprecate this? I saw this function is in the Python transformer test driver as well.
Or I can probably move this to the util namespace.
But I need to know the following:
- Is this function used specifically for the transformer? (If so, probably leave it here; otherwise, it can go in nn::util.)
- Is this a generic method to generate square subsequent masks?
- Is there any other place such a mask will be used within our codebase?
Thanks
This func is mainly used to generate a square mask for the transformer application. In the word language modeling task, a word in a sentence is only allowed to see the tokens before it, so we need an attention mask here. Can we leave this func in the transformer namespace but as an independent function, rather than part of the transformer model?
@zou3519 @zhangguanheng66
I am going to make this a static method within the TransformerImpl class. People can just call TransformerImpl::generate_square_subsequent_mask() instead of creating an object and invoking it from the object.
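A minimal sketch of the resulting call pattern, mirroring the Python implementation of the mask; the exact C++ signature is an assumption until this PR lands:

```cpp
#include <torch/torch.h>
#include <iostream>

int main() {
  // No Transformer object needed: the mask generator is a static method.
  // For sz = 5 this yields a (5, 5) mask with 0 on and below the diagonal
  // and -inf above it, so each position can only attend to earlier tokens.
  torch::Tensor mask =
      torch::nn::TransformerImpl::generate_square_subsequent_mask(5);
  std::cout << mask << std::endl;
}
```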
// check IEEE754 support here since -inf is not guaranteed to be valid on non IEEE754 platform
if (std::numeric_limits<float>::is_iec559) {
  mask = mask.masked_fill(mask == 0, -std::numeric_limits<float>::infinity()).masked_fill(mask == 1, 0.f);
}
// if IEEE754 is not supported, we use the smallest float number in current platform
This seems fine, but does the Python API have the same behavior?
Python doesn't have this issue.
I checked with @zhangguanheng66; using the smallest float number in this situation is fine.
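For reference, a hedged sketch of both branches as described above; the TORCH_WARN_ONCE call and warning text are assumptions, not necessarily the merged code:

```cpp
#include <torch/torch.h>
#include <limits>

torch::Tensor fill_subsequent_mask(torch::Tensor mask) {
  if (std::numeric_limits<float>::is_iec559) {
    // IEEE754 available: -inf is a valid float for masked-out positions.
    mask = mask.masked_fill(mask == 0, -std::numeric_limits<float>::infinity())
               .masked_fill(mask == 1, 0.f);
  } else {
    // No IEEE754: fall back to the lowest representable float and warn once.
    TORCH_WARN_ONCE(
        "IEEE754 not supported on this platform; using the lowest "
        "representable float instead of -inf for the attention mask.");
    mask = mask.masked_fill(mask == 0, std::numeric_limits<float>::lowest())
               .masked_fill(mask == 1, 0.f);
  }
  return mask;
}
```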
zou3519
left a comment
Overall this looks good to me. The two main questions I have are:
- Should generate_square_subsequent_mask be a part of the Transformer API? I don't know what it's for.
- Is there a better way to test this? It would be nice if we could use the parity harness; it's hard for a code reader to figure out correctness from reading the sample inputs/outputs.
I will address other comments one by one.
|
This update fixes all previous comments except the …
- // NOTE: reset() is for initializing the model only, call reset() after the model is created
- // will cause throwing exceptions. Call reset_parameter() if the created model need a reset
+ // NOTE: reset() is for initializing the model only, calling reset() after the model is created
+ // will cause throwing exceptions. Call reset_parameter() if the created model needs a reset
nit: "will cause throwing exceptions" -> "will cause exceptions to be thrown", or "will throw exceptions"
| "norm3", | ||
| LayerNorm(LayerNormOptions(std::vector<int64_t> {options.d_model()}))); | ||
| // NOTE: reset() is for initializing the model only, calling reset() after the model is created | ||
| // will cause throwing exceptions. Call reset_parameter() if the created model needs a reset. |
nit: "will cause throwing exceptions" -> "will cause exceptions to be thrown", or "will throw exceptions"
/// 2. This function requires the platform support IEEE754, since `-inf` is guaranteed to
/// be valid only when IEEE754 is supported. If the platform doesn't support IEEE754,
/// this function will fill the mask with the smallest float number instead of `-inf`,
/// a one time warning will be pop up as well.
nit: "will be pop up as well" -> "will pop up as well"
/// model with corresponding parameters.
///
/// See https://pytorch.org/docs/stable/generated/torch.nn.Transformer.html to
/// learn abouut the exact behavior of this transformer model
nit: "abouut" -> "about"
| "src and tgt should have 3 dimensions, but got ", src.dim(), " and ", tgt.dim()); | ||
|
|
||
| TORCH_CHECK(src.size(1) == tgt.size(1), | ||
| "src and tgt should have equal batch number (at dim 1), but got ", src.size(1), " and ", tgt.size(1)); |
nit: "batch size" or "batch dim size" sounds more pytorch-like than "batch number"
| "src and tgt should have equal batch number (at dim 1), but got ", src.size(1), " and ", tgt.size(1)); | ||
|
|
||
| TORCH_CHECK(src.size(2) == options.d_model() && tgt.size(2) == options.d_model(), | ||
| "src and tgt should have same feature number as d_model (at dim 2), but got ", |
nit: "feature size" sounds more pytorch-like than "feature number"
ASSERT_EQ(result.sizes(), ref_output.sizes());
ASSERT_TRUE(result.equal(result_cus));
ASSERT_TRUE(torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));
}
I'd also test the error conditions (e.g., src and target don't have the same size in the batch dim) with an ASSERT_RAISES to check that it doesn't crash in an unfortunate way
will add these tests
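A sketch of what such a test could look like; ASSERT_THROWS_WITH is assumed from the C++ API test helpers, and the exact error text is illustrative:

```cpp
TEST(TransformerTest, ForwardShapeChecks) {
  torch::nn::Transformer model(
      torch::nn::TransformerOptions().d_model(4).nhead(2));
  auto src = torch::rand({3, 2, 4});
  auto tgt = torch::rand({3, 5, 4});  // batch dims differ: 2 vs 5
  // Expect a clean exception, not a crash, on mismatched batch sizes.
  ASSERT_THROWS_WITH(
      model->forward(src, tgt),
      "src and tgt should have equal batch");
}
```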
zou3519
left a comment
LGTM, some very minor comments
will address
@glaringlee merged this pull request in 77cc7d1.
Summary: Pull Request resolved: #44333
Test Plan: Imported from OSS
Reviewed By: zou3519
Differential Revision: D23584010
Pulled By: glaringlee
fbshipit-source-id: 990026e3f1b5ae276776e344ea981386cb7528fe