
Conversation

@driazati (Contributor) commented Aug 12, 2019

This changes nn.MultiheadAttention so that it can be compiled with TorchScript and adds a test that it compiles.

Fixes #24173

@pytorchbot added the oncall: jit (Add this issue/PR to JIT oncall triage queue) and module: nn (Related to torch.nn) labels on Aug 12, 2019
L is the target sequence length, S is the source sequence length.
"""
-        if hasattr(self, '_qkv_same_embed_dim') and self._qkv_same_embed_dim is False:
+        if not self._qkv_same_embed_dim:
Contributor
The above check is there to avoid a backward-compatibility problem: there may be models trained with an old version of nn.MHA that have no _qkv_same_embed_dim attribute.
However, since we have released PyTorch 1.2, can we finally retire this support? @cpuhrsch

Contributor
Does this block JIT-ing? If so, why?

Contributor Author
@driazati Aug 13, 2019

The hasattr call does, and this version determination shouldn't be happening in forward anyway; it's something that can be determined in _load_from_state_dict(), as was suggested here.
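A minimal sketch of that idea, for context (illustrative only, not the code from this PR; the subclass name and the key-based heuristic are assumptions):

```python
import torch.nn as nn

class PatchedMHA(nn.MultiheadAttention):
    # Hypothetical subclass for illustration; the real change would live in
    # nn.MultiheadAttention itself.
    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        # Assumed heuristic: a packed 'in_proj_weight' entry implies the layout
        # where q/k/v share one projection; otherwise separate weights are used.
        self._qkv_same_embed_dim = (prefix + 'in_proj_weight') in state_dict
        super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
                                      missing_keys, unexpected_keys, error_msgs)
```

This keeps forward() free of hasattr, which is the part TorchScript could not compile.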

Contributor
Since we added a warning message for the BC-breaking change in the PyTorch 1.2 release, IMO we could drop the support now. How do you feel about it? @cpuhrsch

Contributor
@zhangguanheng66 - let's move this mechanism into _load_from_state_dict as was originally suggested, and keep it around for now. The BC-breaking change should be a separate PR so we can easily keep track of it.

Contributor
@zhangguanheng66 Aug 14, 2019

In the old model, there is no _qkv_same_embed_dim attribute.

However, I cannot use the _load_from_state_dict() hook to add _qkv_same_embed_dim, because _qkv_same_embed_dim is not in the state_dict of the new model either (it is a plain attribute rather than a parameter or buffer, so it is never saved).

Is there a way to add a new attribute to an old model?

Contributor

Never mind, I figured out a workaround. After some tests, I will submit a PR first.
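For reference, one shape such a workaround could take (a sketch under the assumption that old models are full pickles that go through __setstate__; not necessarily the change that was landed):

```python
import torch.nn as nn

class PatchedMHA(nn.MultiheadAttention):
    # Hypothetical subclass for illustration; the real change would live in
    # nn.MultiheadAttention itself.
    def __setstate__(self, state):
        # Old fully-pickled modules skip __init__, so give the missing attribute
        # the only value those old models could have had.
        if '_qkv_same_embed_dim' not in state:
            state['_qkv_same_embed_dim'] = True
        super().__setstate__(state)
```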

Contributor

@driazati I chatted about this with Soumith and we feel that the fix in #24404 may actually confuse our users.

If the PR is blocked only by the hasattr call, could you add hasattr support to your JIT codebase? We have seen hasattr applied in many places (like quantization), so it will be required anyway in the future.

        q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
        v_proj_weight=self.v_proj_weight)
else:
    if not hasattr(self, '_qkv_same_embed_dim'):
Contributor

Same as above. Will this block JIT?

    self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
    self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
else:
    self.register_parameter('q_proj_weight', None)
Contributor

Does the JIT also require registering all the parameters, even when they are not used?
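For context, a small illustration of why registering the unused parameters as None helps with scripting (this reflects my understanding of TorchScript's attribute typing, not text from this PR; the module and names are made up):

```python
import torch
import torch.nn as nn

class Projections(nn.Module):
    def __init__(self, embed_dim: int, separate: bool):
        super().__init__()
        if separate:
            self.q_proj_weight = nn.Parameter(torch.empty(embed_dim, embed_dim))
        else:
            # Registering None keeps the attribute present with type Optional[Tensor],
            # so the scripted module type does not depend on which branch ran.
            self.register_parameter('q_proj_weight', None)

    def forward(self) -> bool:
        return self.q_proj_weight is not None

scripted = torch.jit.script(Projections(8, separate=False))
```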

>>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)
"""
__annotations__ = {
Contributor

I guess annotations and constants are required for JIT-ing?
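For context, a toy example of the kind of declarations scripting relies on (illustrative only; not the exact annotations or constants added in this PR):

```python
from typing import Optional

import torch
import torch.nn as nn

class Toy(nn.Module):
    # Values the compiler treats as compile-time constants.
    __constants__ = ['num_heads']
    # Attribute types TorchScript cannot infer from an initial value of None.
    __annotations__ = {'bias_k': Optional[torch.Tensor]}

    def __init__(self, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.bias_k = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.bias_k is not None:
            x = x + self.bias_k
        return x * self.num_heads

scripted = torch.jit.script(Toy(4))
```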

@zhangguanheng66 (Contributor)

@driazati Just wondering if we are still going to make hasattr scriptable and land this PR. Thanks.

@driazati (Contributor Author)

Closing in favor of #28555

@driazati closed this on Oct 24, 2019
@facebook-github-bot deleted the driazati/multi branch on July 13, 2020

Labels

module: nn (Related to torch.nn), oncall: jit (Add this issue/PR to JIT oncall triage queue)

Development

Successfully merging this pull request may close these issues.

Transformer model seems not supported in TorchScript?
