8 changes: 6 additions & 2 deletions test/nn/test_multihead_attention.py
@@ -183,7 +183,9 @@ def _multihead_attn_test_helper(add_key_padding_mask=False, add_bias_kv=False, a
multihead_attn_module.out_proj.weight, multihead_attn_module.out_proj.bias,
multihead_attn_module.training, key_padding_mask_tensor, True, attn_mask_tensor,
static_k=saved_k_tensor, static_v=saved_v_tensor,
average_attn_weights=average_attn_weights)
average_attn_weights=average_attn_weights,
is_causal=False,
)
else:
result, result_weight = torch.nn.functional.multi_head_attention_forward(
_Q, _K, _V,
@@ -196,7 +198,9 @@ def _multihead_attn_test_helper(add_key_padding_mask=False, add_bias_kv=False, a
True, multihead_attn_module.q_proj_weight,
multihead_attn_module.k_proj_weight, multihead_attn_module.v_proj_weight,
static_k=saved_k_tensor, static_v=saved_v_tensor,
average_attn_weights=average_attn_weights)
average_attn_weights=average_attn_weights,
is_causal=False,
)

result = result.squeeze(0).detach().numpy()
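For readers of the test above: a minimal standalone sketch (not part of the PR's test suite) of the behavior it exercises. It assumes this PR's semantics, under which `is_causal=True` is intended to match an explicit upper-triangular boolean `attn_mask`; shapes and the tolerance are illustrative.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
L, N, E, H = 5, 2, 8, 2
mha = nn.MultiheadAttention(E, H)        # dropout defaults to 0.0, so outputs are comparable
q = k = v = torch.randn(L, N, E)         # (seq, batch, embed) with batch_first=False

# Boolean causal mask: True marks positions that must NOT be attended to.
causal_mask = torch.triu(torch.ones(L, L, dtype=torch.bool), diagonal=1)

with torch.no_grad():
    out_mask, _ = mha(q, k, v, attn_mask=causal_mask)
    out_flag, _ = mha(q, k, v, is_causal=True)

print(torch.allclose(out_mask, out_flag, atol=1e-6))  # expected True under this PR's semantics
```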

17 changes: 14 additions & 3 deletions torch/ao/nn/quantizable/modules/activation.py
@@ -240,7 +240,8 @@ def forward(self,
key_padding_mask: Optional[Tensor] = None,
need_weights: bool = True,
attn_mask: Optional[Tensor] = None,
average_attn_weights: bool = True) -> Tuple[Tensor, Optional[Tensor]]:
average_attn_weights: bool = True,
is_causal: bool = False) -> Tuple[Tensor, Optional[Tensor]]:
r"""
Note::
Please, refer to :func:`~torch.nn.MultiheadAttention.forward` for more
@@ -277,6 +278,8 @@ def forward(self,
while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
is provided, it will be added to the attention weight.
- is_causal: If specified, applies a causal mask as attention mask. Mutually exclusive with providing attn_mask.
Default: ``False``.
- average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across
heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an
effect when ``need_weights=True``. Default: True (i.e. average weights across heads)
@@ -290,7 +293,8 @@ def forward(self,
head of shape :math:`(N, num_heads, L, S)`.
"""
return self._forward_impl(query, key, value, key_padding_mask,
need_weights, attn_mask, average_attn_weights)
need_weights, attn_mask, average_attn_weights,
is_causal)

def _forward_impl(self,
query: Tensor,
@@ -299,7 +303,8 @@ def _forward_impl(self,
key_padding_mask: Optional[Tensor] = None,
need_weights: bool = True,
attn_mask: Optional[Tensor] = None,
average_attn_weights: bool = True) -> Tuple[Tensor, Optional[Tensor]]:
average_attn_weights: bool = True,
is_causal: bool = False) -> Tuple[Tensor, Optional[Tensor]]:
# This version will not deal with the static key/value pairs.
# Keeping it here for future changes.
#
@@ -308,6 +313,12 @@ def _forward_impl(self,
static_k = None
static_v = None

if attn_mask is not None and is_causal:
raise AssertionError("Only allow causal mask or attn_mask")

if is_causal:
raise AssertionError("causal mask not supported by AO MHA module")

if self.batch_first:
query, key, value = [x.transpose(0, 1) for x in (query, key, value)]

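Because the quantizable module rejects `is_causal=True` (see the AssertionError added above), a caller that needs causal attention here would materialize the mask explicitly. A hedged sketch; the `torch.ao.nn.quantizable` import path and the float-mask convention follow the docstring but are assumptions about usage, not part of the PR.

```python
import torch
from torch.ao.nn.quantizable import MultiheadAttention as QuantizableMHA

L, N, E, H = 4, 1, 8, 2
mha = QuantizableMHA(E, H)
x = torch.randn(L, N, E)

# Additive float mask: 0.0 where attention is allowed, -inf above the diagonal.
causal_mask = torch.full((L, L), float("-inf")).triu(diagonal=1)

out, _ = mha(x, x, x, attn_mask=causal_mask)   # explicit causal mask is accepted
# mha(x, x, x, is_causal=True)                 # would raise AssertionError, per the change above
```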
6 changes: 5 additions & 1 deletion torch/nn/functional.py
@@ -4929,6 +4929,7 @@ def multi_head_attention_forward(
static_k: Optional[Tensor] = None,
static_v: Optional[Tensor] = None,
average_attn_weights: bool = True,
is_causal: bool = False,
) -> Tuple[Tensor, Optional[Tensor]]:
r"""
Args:
@@ -4949,6 +4950,8 @@ def multi_head_attention_forward(
need_weights: output attn_output_weights.
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
the batches while a 3D mask allows to specify a different mask for the entries of each batch.
is_causal: If specified, applies a causal mask as attention mask. Mutually exclusive with providing attn_mask.
Default: ``False``.
use_separate_proj_weight: the function accept the proj. weights for query, key,
and value in different forms. If false, in_proj_weight will be used, which is
a combination of q_proj_weight, k_proj_weight, v_proj_weight.
@@ -5014,6 +5017,7 @@ def multi_head_attention_forward(
key_padding_mask=key_padding_mask,
need_weights=need_weights,
attn_mask=attn_mask,
is_causal=is_causal,
use_separate_proj_weight=use_separate_proj_weight,
q_proj_weight=q_proj_weight,
k_proj_weight=k_proj_weight,
@@ -5184,7 +5188,7 @@ def multi_head_attention_forward(
v = v.view(bsz, num_heads, src_len, head_dim)

attn_output, attn_output_weights = _scaled_dot_product_attention(
q, k, v, attn_mask, dropout_p, need_weights, False)
q, k, v, attn_mask, dropout_p, need_weights, is_causal)
attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)

attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
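Conceptually, `is_causal=True` asks the kernel to behave as if the additive mask below had been passed as `attn_mask`; the PR itself only forwards the flag to `_scaled_dot_product_attention`. The helper name is illustrative, not an API added by this change.

```python
import torch

def illustrative_causal_mask(tgt_len: int, src_len: int) -> torch.Tensor:
    # 0.0 on and below the diagonal (attend), -inf above it (do not attend).
    return torch.full((tgt_len, src_len), float("-inf")).triu(diagonal=1)

print(illustrative_causal_mask(3, 3))
# tensor([[0., -inf, -inf],
#         [0., 0., -inf],
#         [0., 0., 0.]])
```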
3 changes: 2 additions & 1 deletion torch/nn/functional.pyi.in
@@ -379,7 +379,8 @@ def multi_head_attention_forward(query: Tensor,
v_proj_weight: Optional[Tensor] = None,
static_k: Optional[Tensor] = None,
static_v: Optional[Tensor] = None,
average_attn_weights: bool = True
average_attn_weights: bool = True,
is_causal: bool = False
) -> Tuple[Tensor, Optional[Tensor]]: ...


32 changes: 25 additions & 7 deletions torch/nn/modules/activation.py
@@ -1009,9 +1009,16 @@ def __setstate__(self, state):

super(MultiheadAttention, self).__setstate__(state)

def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: Optional[Tensor] = None,
need_weights: bool = True, attn_mask: Optional[Tensor] = None,
average_attn_weights: bool = True) -> Tuple[Tensor, Optional[Tensor]]:
def forward(
self,
query: Tensor,
key: Tensor,
value: Tensor,
key_padding_mask: Optional[Tensor] = None,
need_weights: bool = True,
attn_mask: Optional[Tensor] = None,
average_attn_weights: bool = True,
is_causal: bool = False) -> Tuple[Tensor, Optional[Tensor]]:
r"""
Args:
query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False``
@@ -1042,6 +1049,8 @@ def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: O
corresponding position is not allowed to attend. For a byte mask, a non-zero value indicates that the
corresponding position is not allowed to attend. For a float mask, the mask values will be added to
the attention weight.
is_causal: If specified, applies a causal mask as attention mask. Mutually exclusive with providing attn_mask.
Default: ``False``.
average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across
heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an
effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads)
@@ -1060,6 +1069,9 @@ def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: O
.. note::
`batch_first` argument is ignored for unbatched inputs.
"""
if attn_mask is not None and is_causal:
raise AssertionError("Only allow causal mask or attn_mask")

is_batched = query.dim() == 3
if key_padding_mask is not None:
_kpm_dtype = key_padding_mask.dtype
@@ -1157,18 +1169,24 @@ def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: O
self.dropout, self.out_proj.weight, self.out_proj.bias,
training=self.training,
key_padding_mask=key_padding_mask, need_weights=need_weights,
attn_mask=attn_mask, use_separate_proj_weight=True,
attn_mask=attn_mask,
use_separate_proj_weight=True,
q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
v_proj_weight=self.v_proj_weight, average_attn_weights=average_attn_weights)
v_proj_weight=self.v_proj_weight,
average_attn_weights=average_attn_weights,
is_causal=is_causal)
else:
attn_output, attn_output_weights = F.multi_head_attention_forward(
query, key, value, self.embed_dim, self.num_heads,
self.in_proj_weight, self.in_proj_bias,
self.bias_k, self.bias_v, self.add_zero_attn,
self.dropout, self.out_proj.weight, self.out_proj.bias,
training=self.training,
key_padding_mask=key_padding_mask, need_weights=need_weights,
attn_mask=attn_mask, average_attn_weights=average_attn_weights)
key_padding_mask=key_padding_mask,
need_weights=need_weights,
attn_mask=attn_mask,
average_attn_weights=average_attn_weights,
is_causal=is_causal)
if self.batch_first and is_batched:
return attn_output.transpose(1, 0), attn_output_weights
else:
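A usage sketch of the new keyword on `nn.MultiheadAttention` under this PR's semantics, where `attn_mask` and `is_causal` are mutually exclusive; shapes are illustrative.

```python
import torch
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=8, num_heads=2)
x = torch.randn(5, 1, 8)  # (L, N, E) with batch_first=False
mask = torch.triu(torch.ones(5, 5, dtype=torch.bool), diagonal=1)

mha(x, x, x, is_causal=True)    # causal attention without materializing a mask
mha(x, x, x, attn_mask=mask)    # explicit mask; is_causal defaults to False
try:
    mha(x, x, x, attn_mask=mask, is_causal=True)
except AssertionError as err:
    print(err)                  # "Only allow causal mask or attn_mask"
```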
60 changes: 47 additions & 13 deletions torch/nn/modules/transformer.py
@@ -194,12 +194,18 @@ def __init__(self, encoder_layer, num_layers, norm=None, enable_nested_tensor=Tr
self.enable_nested_tensor = enable_nested_tensor
self.mask_check = mask_check

def forward(self, src: Tensor, mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
def forward(
self,
src: Tensor,
mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
is_causal: bool = False) -> Tensor:
r"""Pass the input through the encoder layers in turn.

Args:
src: the sequence to the encoder (required).
mask: the mask for the src sequence (optional).
is_causal: If specified, applies a causal mask as the src mask (optional). Mutually exclusive with providing ``mask``. Default: ``False``.
src_key_padding_mask: the mask for the src keys per batch (optional).

Shape:
@@ -278,8 +284,18 @@ def forward(self, src: Tensor, mask: Optional[Tensor] = None, src_key_padding_ma
output = torch._nested_tensor_from_mask(output, src_key_padding_mask.logical_not(), mask_check=False)
src_key_padding_mask_for_layers = None

# Prevent type refinement
make_causal = False
if mask is not None:
if is_causal:
raise RuntimeError("specify either mask or is_causal, but not both")

if make_causal:
is_causal = True
mask = None

for mod in self.layers:
output = mod(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask_for_layers)
output = mod(output, src_mask=mask, is_causal=is_causal, src_key_padding_mask=src_key_padding_mask_for_layers)

if convert_to_nested:
output = output.to_padded_tensor(0.)
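A hedged sketch of driving the encoder causally after this change: either pass the standard additive causal mask, or set `is_causal=True` with no mask (supplying both raises, per the check above). Layer sizes are illustrative.

```python
import torch
import torch.nn as nn

layer = nn.TransformerEncoderLayer(d_model=16, nhead=2, dropout=0.0)
encoder = nn.TransformerEncoder(layer, num_layers=2)
src = torch.randn(6, 3, 16)  # (S, N, E)

# Standard additive causal mask: 0.0 on/below the diagonal, -inf above it.
causal = torch.full((6, 6), float("-inf")).triu(diagonal=1)

out_masked = encoder(src, mask=causal)
out_causal = encoder(src, is_causal=True)   # same intent, no mask built by the caller
# encoder(src, mask=causal, is_causal=True) # would raise RuntimeError, per the forward above
```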
@@ -437,13 +453,19 @@ def __setstate__(self, state):
self.activation = F.relu


def forward(self, src: Tensor, src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
def forward(
self,
src: Tensor,
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
is_causal: bool = False) -> Tensor:
r"""Pass the input through the encoder layer.

Args:
src: the sequence to the encoder layer (required).
src_mask: the mask for the src sequence (optional).
is_causal: If specified, applies a causal mask as src_mask. Mutually exclusive with providing src_mask.
Default: ``False``.
src_key_padding_mask: the mask for the src keys per batch (optional).

Shape:
@@ -623,8 +645,17 @@ def __setstate__(self, state):
state['activation'] = F.relu
super(TransformerDecoderLayer, self).__setstate__(state)

def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None, memory_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None) -> Tensor:
def forward(
self,
tgt: Tensor,
memory: Tensor,
tgt_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
tgt_is_causal: bool = False,
memory_is_causal: bool = False,
) -> Tensor:
r"""Pass the inputs (and mask) through the decoder layer.

Args:
@@ -634,39 +665,42 @@ def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None
memory_mask: the mask for the memory sequence (optional).
tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
memory_key_padding_mask: the mask for the memory keys per batch (optional).

tgt_is_causal: If specified, applies a causal mask as tgt mask. Mutually exclusive with providing tgt_mask. Default: ``False``.
memory_is_causal: If specified, applies a causal mask as memory mask. Mutually exclusive with providing memory_mask. Default: ``False``.
Shape:
see the docs in Transformer class.
"""
# see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf

x = tgt
if self.norm_first:
x = x + self._sa_block(self.norm1(x), tgt_mask, tgt_key_padding_mask)
x = x + self._mha_block(self.norm2(x), memory, memory_mask, memory_key_padding_mask)
x = x + self._sa_block(self.norm1(x), tgt_mask, tgt_key_padding_mask, tgt_is_causal)
x = x + self._mha_block(self.norm2(x), memory, memory_mask, memory_key_padding_mask, memory_is_causal)
x = x + self._ff_block(self.norm3(x))
else:
x = self.norm1(x + self._sa_block(x, tgt_mask, tgt_key_padding_mask))
x = self.norm2(x + self._mha_block(x, memory, memory_mask, memory_key_padding_mask))
x = self.norm1(x + self._sa_block(x, tgt_mask, tgt_key_padding_mask, tgt_is_causal))
x = self.norm2(x + self._mha_block(x, memory, memory_mask, memory_key_padding_mask, memory_is_causal))
x = self.norm3(x + self._ff_block(x))

return x

# self-attention block
def _sa_block(self, x: Tensor,
attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor]) -> Tensor:
attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], is_causal: bool = False) -> Tensor:
x = self.self_attn(x, x, x,
attn_mask=attn_mask,
key_padding_mask=key_padding_mask,
is_causal=is_causal,
need_weights=False)[0]
return self.dropout1(x)

# multihead attention block
def _mha_block(self, x: Tensor, mem: Tensor,
attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor]) -> Tensor:
attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], is_causal: bool = False) -> Tensor:
x = self.multihead_attn(x, mem, mem,
attn_mask=attn_mask,
key_padding_mask=key_padding_mask,
is_causal=is_causal,
need_weights=False)[0]
return self.dropout2(x)
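A hedged sketch of the decoder-layer arguments added above: `tgt_is_causal` requests causal masking of the target self-attention block, while `memory_mask` / `memory_is_causal` govern the cross-attention block. Sizes are illustrative, and the sketch assumes this PR's mutually exclusive mask semantics.

```python
import torch
import torch.nn as nn

layer = nn.TransformerDecoderLayer(d_model=16, nhead=2, dropout=0.0)
tgt = torch.randn(7, 3, 16)     # (T, N, E)
memory = torch.randn(5, 3, 16)  # (S, N, E)

out = layer(tgt, memory, tgt_is_causal=True)            # causal self-attention over tgt
# The same intent expressed with an explicit mask:
tgt_mask = torch.full((7, 7), float("-inf")).triu(diagonal=1)
out_explicit = layer(tgt, memory, tgt_mask=tgt_mask)
```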

2 changes: 1 addition & 1 deletion torch/overrides.py
@@ -827,7 +827,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
lambda query, key, value, embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias, bias_k, bias_v,
add_zero_attn, dropout_p, out_proj_weight, out_proj_bias, training=True, key_padding_mask=None,
need_weights=True, attn_mask=None, use_separate_proj_weight=False, q_proj_weight=None, k_proj_weight=None,
v_proj_weight=None, static_k=None, static_v=None, average_attn_weights=None: -1),
v_proj_weight=None, static_k=None, static_v=None, average_attn_weights=None, is_causal=False: -1),
torch.nn.functional.multi_margin_loss: (lambda input, target, p=1, margin=1.0, weight=None, size_average=None,
reduce=None, reduction='mean': -1),
torch.nn.functional.multilabel_margin_loss: (lambda input, target, size_average=None, reduce=None,