[jit] Support MultiheadedAttention module #24204
Conversation
```diff
         L is the target sequence length, S is the source sequence length.
         """
-        if hasattr(self, '_qkv_same_embed_dim') and self._qkv_same_embed_dim is False:
+        if not self._qkv_same_embed_dim:
```
The above check is there to avoid a backward compatibility problem: some models may have been trained with an older version of nn.MHA that has no _qkv_same_embed_dim attribute.
However, since we have released PyTorch 1.2, can we finally retire this support? @cpuhrsch
Does this block JIT-ing? If so, why?
The hasattr call does, and this version determination shouldn't be happening in forward anyway; it's something that can be determined in _load_state_dict(), as was suggested here.
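For context, a minimal sketch of why the `hasattr` form is a problem for scripting. The toy module below is hypothetical (not the PR's code), and the claim that `hasattr` cannot be compiled reflects TorchScript as of the time of this PR:

```python
import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self, qkv_same_embed_dim: bool = True):
        super().__init__()
        # a plain bool attribute set in __init__ is visible to the scripting compiler
        self._qkv_same_embed_dim = qkv_same_embed_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # scripts fine: the attribute and its type are known statically
        if not self._qkv_same_embed_dim:
            return 2 * x
        return x
        # a dynamic check such as
        #   if hasattr(self, '_qkv_same_embed_dim') and self._qkv_same_embed_dim is False:
        # could not be compiled by torch.jit.script at the time

scripted = torch.jit.script(Toy())
print(scripted(torch.ones(2)))  # tensor([1., 1.])
```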
Since we added the warning message for the BC-breaking change in the PyTorch 1.2 release, IMO we could drop the support now. How do you feel about it? @cpuhrsch
@zhangguanheng66 - let's move this mechanism into _load_state_dict as was originally suggested and keep it around for now. The BC-breaking change should be a separate PR so we can easily keep track of it.
In the old model, there is no _qkv_same_embed_dim attribute.
However, I cannot use the _load_from_state_dict() function to add _qkv_same_embed_dim, because _qkv_same_embed_dim is not in the state_dict of the new model.
Is there a way to add a new attribute to an old model?
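One possible direction (a hedged sketch of a workaround, not necessarily what was ultimately done): the flag is not a tensor, so it never appears in a state_dict, but it can be backfilled when an old module is unpickled via __setstate__, or inferred inside _load_from_state_dict from which projection weights were saved. A hypothetical subclass to illustrate the idea:

```python
import torch.nn as nn

class PatchedMHA(nn.MultiheadAttention):
    # Hypothetical subclass, only to illustrate the idea.

    def __setstate__(self, state):
        super().__setstate__(state)
        # Old pickled modules predate the flag; default it to the legacy behavior.
        if '_qkv_same_embed_dim' not in self.__dict__:
            self.__dict__['_qkv_same_embed_dim'] = True

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        # The flag itself is never stored in a state_dict, but its value can be
        # inferred from whether the separate projection weights were saved.
        self._qkv_same_embed_dim = (prefix + 'q_proj_weight') not in state_dict
        super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
                                      missing_keys, unexpected_keys, error_msgs)
```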
Never mind, I figured out a workaround. After some tests, I will submit a PR first.
```python
                q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
                v_proj_weight=self.v_proj_weight)
        else:
            if not hasattr(self, '_qkv_same_embed_dim'):
```
Same as above. Will this block JIT?
```python
            self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
            self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
        else:
            self.register_parameter('q_proj_weight', None)
```
Does JIT also require registering all the parameters, even if they are not used?
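Roughly, yes: the scripting compiler needs every attribute that forward touches to exist with a single static type on every instance, so the unused projection weights are registered as None rather than left undefined. A hedged sketch with a hypothetical module (the explicit class-level Optional annotation is for illustration):

```python
import torch
import torch.nn as nn
from typing import Optional

class Proj(nn.Module):
    # the attribute may be a Parameter or None, so declare it Optional
    q_proj_weight: Optional[torch.Tensor]

    def __init__(self, embed_dim: int, separate_proj: bool):
        super().__init__()
        if separate_proj:
            self.q_proj_weight = nn.Parameter(torch.empty(embed_dim, embed_dim))
        else:
            # keeps the attribute present (as None) so forward still type-checks
            self.register_parameter('q_proj_weight', None)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        w = self.q_proj_weight
        if w is not None:  # TorchScript refines Optional[Tensor] -> Tensor here
            return x @ w.t()
        return x

# both configurations compile, whether or not the parameter is actually used
torch.jit.script(Proj(4, separate_proj=True))
torch.jit.script(Proj(4, separate_proj=False))
```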
```python
        >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
        >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
    """
    __annotations__ = {
```
I guess annotations and constants are required for JIT-ing?
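That's the gist, as I understand it (a sketch of the pattern, not verified against this PR's exact changes): __annotations__ declares static types for attributes the compiler cannot infer from __init__ alone (e.g. ones that may be None on some instances), and __constants__ marks plain Python values to be baked in as compile-time constants.

```python
import torch
import torch.nn as nn
from typing import Optional

class Attn(nn.Module):
    # equivalent to class-level variable annotations; gives the TorchScript
    # compiler a static type for attributes that may be None
    __annotations__ = {
        'bias_k': Optional[torch.Tensor],
        'bias_v': Optional[torch.Tensor],
    }
    # plain Python values treated as compile-time constants
    __constants__ = ['num_heads']

    def __init__(self, embed_dim: int, num_heads: int, add_bias_kv: bool = False):
        super().__init__()
        self.num_heads = num_heads
        if add_bias_kv:
            self.bias_k = nn.Parameter(torch.empty(1, 1, embed_dim))
            self.bias_v = nn.Parameter(torch.empty(1, 1, embed_dim))
        else:
            self.bias_k = self.bias_v = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.bias_k is not None:
            x = x + self.bias_k
        return x

torch.jit.script(Attn(8, 2))  # scripts with or without the optional biases
```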
@driazati Just wondering if we are still going to add

Closing in favor of #28555
This changes up `nn.MultiheadedAttention` so that it can be compiled with TorchScript and adds a test that it compiles.

Fixes #24173
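For readers landing here later, a sketch of what "compiles with TorchScript" means in practice, assuming a PyTorch release in which the follow-up work (#28555) has landed; this is illustrative usage, not the PR's actual test:

```python
import torch
import torch.nn as nn

embed_dim, num_heads = 16, 4
mha = nn.MultiheadAttention(embed_dim, num_heads)
scripted = torch.jit.script(mha)  # should compile without errors

# shapes: query is (target_len, batch, embed_dim); key/value are (source_len, batch, embed_dim)
query = torch.randn(5, 2, embed_dim)
key = value = torch.randn(7, 2, embed_dim)
attn_output, attn_weights = scripted(query, key, value)
print(attn_output.shape)   # torch.Size([5, 2, 16])
print(attn_weights.shape)  # torch.Size([2, 5, 7])
```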