
Commit 6e82b1c

jamarshon authored and facebook-github-bot committed
Split nn.MultiHeadAttention into Module + functional (#20415)
Summary: Moving functions from torch/nn/modules/activation.py to torch/nn/functional.py. For functions not implemented (_get_input_buffer and _set_input_buffer), a TODO is added.

Pull Request resolved: #20415
Differential Revision: D15318078
Pulled By: jamarshon
fbshipit-source-id: 5ca698e2913821442cf8609cc61ac8190496a3c6
1 parent b46a630 commit 6e82b1c

File tree

2 files changed: +205 -124 lines changed


torch/nn/functional.py

Lines changed: 191 additions & 0 deletions
@@ -3077,3 +3077,194 @@ def _pad_circular(input, padding):
         input = torch.cat([input[:, :, :, :, -(padding[-5] + padding[-6]):-padding[-5]], input], dim=4)
 
     return input
+
+
+@weak_script
+def multi_head_attention_forward(query,                  # type: Tensor
+                                 key,                    # type: Tensor
+                                 value,                  # type: Tensor
+                                 embed_dim_to_check,     # type: int
+                                 num_heads,              # type: int
+                                 in_proj_weight,         # type: Tensor
+                                 in_proj_bias,           # type: Tensor
+                                 bias_k,                 # type: Tensor
+                                 bias_v,                 # type: Tensor
+                                 add_zero_attn,          # type: bool
+                                 dropout_p,              # type: float
+                                 out_proj,               # type: Tensor
+                                 training=True,          # type: bool
+                                 key_padding_mask=None,  # type: Optional[Tensor]
+                                 need_weights=True,      # type: bool
+                                 attn_mask=None          # type: Optional[Tensor]
+                                 ):
+    # type: (...) -> Tuple[Tensor, Tensor]
+    r"""
+    Args:
+        query, key, value: map a query and a set of key-value pairs to an output.
+            See "Attention Is All You Need" for more details.
+        embed_dim_to_check: total dimension of the model.
+        num_heads: parallel attention heads.
+        in_proj_weight, in_proj_bias: input projection weight and bias.
+        bias_k, bias_v: bias of the key and value sequences to be added at dim=0.
+        add_zero_attn: add a new batch of zeros to the key and
+                       value sequences at dim=1.
+        dropout_p: probability of an element to be zeroed.
+        out_proj: the output projection.
+        training: apply dropout if is ``True``.
+        key_padding_mask: if provided, specified padding elements in the key will
+            be ignored by the attention.
+        need_weights: output attn_output_weights.
+        attn_mask: mask that prevents attention to certain positions.
+
+    Shape:
+        Inputs:
+        - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
+          the embedding dimension.
+        - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
+          the embedding dimension.
+        - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
+          the embedding dimension.
+        - key_padding_mask: :math:`(N, S)`, ByteTensor, where N is the batch size, S is the source sequence length.
+        - attn_mask: :math:`(L, L)` where L is the target sequence length.
+
+        Outputs:
+        - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
+          E is the embedding dimension.
+        - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
+          L is the target sequence length, S is the source sequence length.
+    """
+
+    @weak_script
+    def _in_proj(input, weight, bias, start=0, end=None):
+        # type: (Tensor, Tensor, Optional[Tensor], int, Optional[int]) -> Tensor
+        weight = weight[start:end, :]
+        if bias is not None:
+            bias = bias[start:end]
+        return linear(input, weight, bias)
+
+
+    @weak_script
+    def _in_proj_qkv(weight, bias, query):
+        # type: (Tensor, Tensor, Tensor) -> Tensor
+        return _in_proj(query, weight, bias).chunk(3, dim=-1)
+
+
+    @weak_script
+    def _in_proj_kv(weight, bias, embed_dim, key):
+        # type: (Tensor, Tensor, int, Tensor) -> Tensor
+        return _in_proj(key, weight, bias, start=embed_dim).chunk(2, dim=-1)
+
+
+    @weak_script
+    def _in_proj_q(weight, bias, embed_dim, query):
+        # type: (Tensor, Tensor, int, Tensor) -> Tensor
+        return _in_proj(query, weight, bias, end=embed_dim)
+
+
+    @weak_script
+    def _in_proj_k(weight, bias, embed_dim, key):
+        # type: (Tensor, Tensor, int, Tensor) -> Tensor
+        return _in_proj(key, weight, bias, start=embed_dim, end=2 * embed_dim)
+
+
+    @weak_script
+    def _in_proj_v(weight, bias, embed_dim, value):
+        # type: (Tensor, Tensor, int, Tensor) -> Tensor
+        return _in_proj(value, weight, bias, start=2 * embed_dim)
+
+
+    qkv_same = query.data_ptr() == key.data_ptr() == value.data_ptr()
+    kv_same = key.data_ptr() == value.data_ptr()
+
+    tgt_len, bsz, embed_dim = query.size()
+    assert embed_dim == embed_dim_to_check
+    assert list(query.size()) == [tgt_len, bsz, embed_dim]
+    assert key.size() == value.size()
+
+    head_dim = embed_dim // num_heads
+    assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
+    scaling = head_dim ** -0.5
+
+    if qkv_same:
+        # self-attention
+        q, k, v = _in_proj_qkv(in_proj_weight, in_proj_bias, query)
+    elif kv_same:
+        # encoder-decoder attention
+        q = _in_proj_q(in_proj_weight, in_proj_bias, embed_dim, query)
+        if key is None:
+            assert value is None
+            k = v = None
+        else:
+            k, v = _in_proj_kv(in_proj_weight, in_proj_bias, embed_dim, key)
+    else:
+        q = _in_proj_q(in_proj_weight, in_proj_bias, embed_dim, query)
+        k = _in_proj_k(in_proj_weight, in_proj_bias, embed_dim, key)
+        v = _in_proj_v(in_proj_weight, in_proj_bias, embed_dim, value)
+    q *= scaling
+
+    if bias_k is not None:
+        assert bias_v is not None
+        k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
+        v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
+        if attn_mask is not None:
+            attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
+        if key_padding_mask is not None:
+            key_padding_mask = torch.cat(
+                [key_padding_mask, key_padding_mask.new_zeros(key_padding_mask.size(0), 1)], dim=1)
+
+    q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
+    if k is not None:
+        k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
+    if v is not None:
+        v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
+
+    src_len = k.size(1)
+
+    if key_padding_mask is not None:
+        assert key_padding_mask.size(0) == bsz
+        assert key_padding_mask.size(1) == src_len
+
+    if add_zero_attn:
+        src_len += 1
+        k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
+        v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
+        if attn_mask is not None:
+            attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
+        if key_padding_mask is not None:
+            key_padding_mask = torch.cat(
+                [key_padding_mask, torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask)], dim=1)
+
+    attn_output_weights = torch.bmm(q, k.transpose(1, 2))
+    assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len]
+
+    if attn_mask is not None:
+        attn_mask = attn_mask.unsqueeze(0)
+        attn_output_weights += attn_mask
+
+    if key_padding_mask is not None:
+        attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
+        attn_output_weights = attn_output_weights.masked_fill(
+            key_padding_mask.unsqueeze(1).unsqueeze(2),
+            float('-inf'),
+        )
+        attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len)
+
+    attn_output_weights = softmax(
+        attn_output_weights.float(), dim=-1,
+        dtype=torch.float32 if attn_output_weights.dtype == torch.float16 else attn_output_weights.dtype)
+    attn_output_weights = dropout(attn_output_weights, p=dropout_p, training=training)
+
+    attn_output = torch.bmm(attn_output_weights, v)
+    assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
+    attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
+    attn_output = out_proj(attn_output)

+    if need_weights:
+        # average attention weights over heads
+        attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
+        attn_output_weights = attn_output_weights.sum(dim=1) / num_heads
+    else:
+        attn_output_weights = None
+
+    return attn_output, attn_output_weights
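
For reference, a minimal sketch of calling the new functional entry point directly. The tensor sizes and the ad-hoc projection parameters below are illustrative only and not part of the commit; in practice they come from an nn.MultiheadAttention instance, which now forwards them to this function:

import torch
import torch.nn as nn
import torch.nn.functional as F

# Illustrative sizes: L (target len) = 5, S (source len) = 7, N (batch) = 3, E (embed dim) = 8
embed_dim, num_heads = 8, 2
query = torch.rand(5, 3, embed_dim)   # (L, N, E)
key = torch.rand(7, 3, embed_dim)     # (S, N, E)
value = torch.rand(7, 3, embed_dim)   # (S, N, E)

# Ad-hoc parameters for this sketch; normally these are the module's
# in_proj_weight / in_proj_bias / out_proj.
in_proj_weight = torch.rand(3 * embed_dim, embed_dim)
in_proj_bias = torch.zeros(3 * embed_dim)
out_proj = nn.Linear(embed_dim, embed_dim)

attn_output, attn_output_weights = F.multi_head_attention_forward(
    query, key, value, embed_dim, num_heads,
    in_proj_weight, in_proj_bias, None, None, False,   # bias_k, bias_v, add_zero_attn
    0.0, out_proj)                                     # dropout_p, out_proj

print(attn_output.shape)          # torch.Size([5, 3, 8])  -> (L, N, E)
print(attn_output_weights.shape)  # torch.Size([3, 5, 7])  -> (N, L, S)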

torch/nn/modules/activation.py

Lines changed: 14 additions & 124 deletions
@@ -694,7 +694,7 @@ class MultiheadAttention(Module):
         dropout: a Dropout layer on attn_output_weights. Default: 0.0.
         bias: add bias as module parameter. Default: True.
         add_bias_kv: add bias to the key and value sequences at dim=0.
-        add_zero_attn: add a new batch of zeros to the key and
+        add_zero_attn: add a new batch of zeros to the key and
                        value sequences at dim=1.
 
     Examples::
@@ -708,9 +708,6 @@ def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=Fals
         self.embed_dim = embed_dim
         self.num_heads = num_heads
         self.dropout = dropout
-        self.head_dim = embed_dim // num_heads
-        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
-        self.scaling = self.head_dim ** -0.5
 
         self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))
         if bias:
@@ -748,143 +745,36 @@ def forward(self, query, key, value, key_padding_mask=None,
                 need_weights=True, attn_mask=None):
         r"""
         Args:
-            query, key, value: map a query and a set of key-value pairs to an output.
-                See "Attention Is All You Need" for more details.
-            key_padding_mask: if provided, specified padding elements in the key will
+            query, key, value: map a query and a set of key-value pairs to an output.
+                See "Attention Is All You Need" for more details.
+            key_padding_mask: if provided, specified padding elements in the key will
                 be ignored by the attention.
             need_weights: output attn_output_weights.
             attn_mask: mask that prevents attention to certain positions.
 
+
         Shape:
             Inputs:
-            - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
+            - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
               the embedding dimension.
-            - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
+            - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
              the embedding dimension.
-            - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
+            - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
              the embedding dimension.
             - key_padding_mask: :math:`(N, S)`, ByteTensor, where N is the batch size, S is the source sequence length.
             - attn_mask: :math:`(L, L)` where L is the target sequence length.
 
             Outputs:
-            - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
+            - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
               E is the embedding dimension.
             - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
               L is the target sequence length, S is the source sequence length.
         """
-        qkv_same = query.data_ptr() == key.data_ptr() == value.data_ptr()
-        kv_same = key.data_ptr() == value.data_ptr()
-
-        tgt_len, bsz, embed_dim = query.size()
-        assert embed_dim == self.embed_dim
-        assert list(query.size()) == [tgt_len, bsz, embed_dim]
-        assert key.size() == value.size()
-
-        if qkv_same:
-            # self-attention
-            q, k, v = self._in_proj_qkv(query)
-        elif kv_same:
-            # encoder-decoder attention
-            q = self._in_proj_q(query)
-            if key is None:
-                assert value is None
-                k = v = None
-            else:
-                k, v = self._in_proj_kv(key)
-        else:
-            q = self._in_proj_q(query)
-            k = self._in_proj_k(key)
-            v = self._in_proj_v(value)
-        q *= self.scaling
-
-        if self.bias_k is not None:
-            assert self.bias_v is not None
-            k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
-            v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
-            if attn_mask is not None:
-                attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
-            if key_padding_mask is not None:
-                key_padding_mask = torch.cat(
-                    [key_padding_mask, key_padding_mask.new_zeros(key_padding_mask.size(0), 1)], dim=1)
-
-        q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
-        if k is not None:
-            k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
-        if v is not None:
-            v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
-
-        src_len = k.size(1)
-
-        if key_padding_mask is not None:
-            assert key_padding_mask.size(0) == bsz
-            assert key_padding_mask.size(1) == src_len
-
-        if self.add_zero_attn:
-            src_len += 1
-            k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
-            v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
-            if attn_mask is not None:
-                attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
-            if key_padding_mask is not None:
-                key_padding_mask = torch.cat(
-                    [key_padding_mask, torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask)], dim=1)
-
-        attn_output_weights = torch.bmm(q, k.transpose(1, 2))
-        assert list(attn_output_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
-
-        if attn_mask is not None:
-            attn_mask = attn_mask.unsqueeze(0)
-            attn_output_weights += attn_mask
-
-        if key_padding_mask is not None:
-            attn_output_weights = attn_output_weights.view(bsz, self.num_heads, tgt_len, src_len)
-            attn_output_weights = attn_output_weights.masked_fill(
-                key_padding_mask.unsqueeze(1).unsqueeze(2),
-                float('-inf'),
-            )
-            attn_output_weights = attn_output_weights.view(bsz * self.num_heads, tgt_len, src_len)
-
-        attn_output_weights = F.softmax(
-            attn_output_weights.float(), dim=-1,
-            dtype=torch.float32 if attn_output_weights.dtype == torch.float16 else attn_output_weights.dtype)
-        attn_output_weights = F.dropout(attn_output_weights, p=self.dropout, training=self.training)
-
-        attn_output = torch.bmm(attn_output_weights, v)
-        assert list(attn_output.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
-        attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
-        attn_output = self.out_proj(attn_output)
-
-        if need_weights:
-            # average attention weights over heads
-            attn_output_weights = attn_output_weights.view(bsz, self.num_heads, tgt_len, src_len)
-            attn_output_weights = attn_output_weights.sum(dim=1) / self.num_heads
-        else:
-            attn_output_weights = None
-
-        return attn_output, attn_output_weights
-
-    def _in_proj_qkv(self, query):
-        return self._in_proj(query).chunk(3, dim=-1)
-
-    def _in_proj_kv(self, key):
-        return self._in_proj(key, start=self.embed_dim).chunk(2, dim=-1)
-
-    def _in_proj_q(self, query):
-        return self._in_proj(query, end=self.embed_dim)
-
-    def _in_proj_k(self, key):
-        return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim)
-
-    def _in_proj_v(self, value):
-        return self._in_proj(value, start=2 * self.embed_dim)
-
-    def _in_proj(self, input, start=0, end=None):
-        weight = self.in_proj_weight
-        bias = self.in_proj_bias
-        weight = weight[start:end, :]
-        if bias is not None:
-            bias = bias[start:end]
-        return F.linear(input, weight, bias)
+        return F.multi_head_attention_forward(
+            query, key, value, self.embed_dim, self.num_heads,
+            self.in_proj_weight, self.in_proj_bias, self.bias_k, self.bias_v, self.add_zero_attn,
+            self.dropout, self.out_proj, training=self.training,
+            key_padding_mask=key_padding_mask, need_weights=need_weights, attn_mask=attn_mask)
 
 
 @weak_module
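
The module-level API is unchanged by this split: MultiheadAttention.forward keeps its signature and now simply delegates to F.multi_head_attention_forward. A minimal usage sketch (sizes chosen arbitrarily for illustration):

import torch
import torch.nn as nn

embed_dim, num_heads = 8, 2
multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)

query = torch.rand(5, 3, embed_dim)   # (L, N, E)
key = torch.rand(7, 3, embed_dim)     # (S, N, E)
value = torch.rand(7, 3, embed_dim)   # (S, N, E)

# Same outputs as before the refactor: attn_output is (L, N, E),
# attn_output_weights is (N, L, S), averaged over heads.
attn_output, attn_output_weights = multihead_attn(query, key, value)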

0 commit comments
