Commit 5560eb4

Michael Gschwind authored and facebook-github-bot committed
Introduce causal mask (#90508)
Summary: Pull Request resolved: #90508. Introduce causal mask.
Test Plan: sandcastle & github ci/cd
Reviewed By: albanD, mleshen
Differential Revision: D41723137
fbshipit-source-id: 5cb222d672e8f174a1447a91414068ba87490dc3
1 parent 01e7f46 commit 5560eb4

File tree

7 files changed (+100, -28 lines)

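For context, a causal (autoregressive) mask lets each query position attend only to itself and to earlier key positions. The new ``is_causal`` flag introduced below is a shorthand for passing such a mask explicitly, which is why the two options are mutually exclusive. A minimal sketch of the mask the flag stands in for (sizes are illustrative, not part of this commit):

    import torch

    L = 5  # sequence length

    # Boolean causal mask: True marks key positions a query may NOT attend to.
    bool_mask = torch.triu(torch.ones(L, L, dtype=torch.bool), diagonal=1)

    # Equivalent float mask: -inf above the diagonal, 0 elsewhere; it is added
    # to the attention scores so masked positions get zero weight after softmax.
    float_mask = torch.zeros(L, L).masked_fill(bool_mask, float("-inf"))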

test/nn/test_multihead_attention.py

Lines changed: 6 additions & 2 deletions
@@ -183,7 +183,9 @@ def _multihead_attn_test_helper(add_key_padding_mask=False, add_bias_kv=False, a
 multihead_attn_module.out_proj.weight, multihead_attn_module.out_proj.bias,
 multihead_attn_module.training, key_padding_mask_tensor, True, attn_mask_tensor,
 static_k=saved_k_tensor, static_v=saved_v_tensor,
-average_attn_weights=average_attn_weights)
+average_attn_weights=average_attn_weights,
+is_causal=False,
+)
 else:
 result, result_weight = torch.nn.functional.multi_head_attention_forward(
 _Q, _K, _V,
@@ -196,7 +198,9 @@ def _multihead_attn_test_helper(add_key_padding_mask=False, add_bias_kv=False, a
 True, multihead_attn_module.q_proj_weight,
 multihead_attn_module.k_proj_weight, multihead_attn_module.v_proj_weight,
 static_k=saved_k_tensor, static_v=saved_v_tensor,
-average_attn_weights=average_attn_weights)
+average_attn_weights=average_attn_weights,
+is_causal=False,
+)

 result = result.squeeze(0).detach().numpy()

torch/ao/nn/quantizable/modules/activation.py

Lines changed: 14 additions & 3 deletions
@@ -240,7 +240,8 @@ def forward(self,
 key_padding_mask: Optional[Tensor] = None,
 need_weights: bool = True,
 attn_mask: Optional[Tensor] = None,
-average_attn_weights: bool = True) -> Tuple[Tensor, Optional[Tensor]]:
+average_attn_weights: bool = True,
+is_causal: bool = False) -> Tuple[Tensor, Optional[Tensor]]:
 r"""
 Note::
 Please, refer to :func:`~torch.nn.MultiheadAttention.forward` for more
@@ -277,6 +278,8 @@ def forward(self,
 while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
 is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
 is provided, it will be added to the attention weight.
+- is_causal: If specified, applies a causal mask as attention mask. Mutually exclusive with providing attn_mask.
+Default: ``False``.
 - average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across
 heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an
 effect when ``need_weights=True.``. Default: True (i.e. average weights across heads)
@@ -290,7 +293,8 @@ def forward(self,
 head of shape :math:`(N, num_heads, L, S)`.
 """
 return self._forward_impl(query, key, value, key_padding_mask,
-need_weights, attn_mask, average_attn_weights)
+need_weights, attn_mask, average_attn_weights,
+is_causal)

 def _forward_impl(self,
 query: Tensor,
@@ -299,7 +303,8 @@ def _forward_impl(self,
 key_padding_mask: Optional[Tensor] = None,
 need_weights: bool = True,
 attn_mask: Optional[Tensor] = None,
-average_attn_weights: bool = True) -> Tuple[Tensor, Optional[Tensor]]:
+average_attn_weights: bool = True,
+is_causal: bool = False) -> Tuple[Tensor, Optional[Tensor]]:
 # This version will not deal with the static key/value pairs.
 # Keeping it here for future changes.
 #
@@ -308,6 +313,12 @@ def _forward_impl(self,
 static_k = None
 static_v = None

+if attn_mask is not None and is_causal:
+raise AssertionError("Only allow causal mask or attn_mask")
+
+if is_causal:
+raise AssertionError("causal mask not supported by AO MHA module")
+
 if self.batch_first:
 query, key, value = [x.transpose(0, 1) for x in (query, key, value)]

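The quantizable module accepts the new keyword for API parity with ``torch.nn.MultiheadAttention`` but rejects it at run time. A minimal sketch of the resulting behavior, assuming the class is imported from the file patched above (shapes are illustrative):

    import torch
    from torch.ao.nn.quantizable.modules.activation import MultiheadAttention as QuantizableMHA

    mha = QuantizableMHA(embed_dim=8, num_heads=2)
    x = torch.randn(5, 1, 8)  # (L, N, E); batch_first defaults to False

    try:
        mha(x, x, x, is_causal=True)
    except AssertionError as e:
        print(e)  # "causal mask not supported by AO MHA module"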
torch/nn/functional.py

Lines changed: 5 additions & 1 deletion
@@ -4929,6 +4929,7 @@ def multi_head_attention_forward(
 static_k: Optional[Tensor] = None,
 static_v: Optional[Tensor] = None,
 average_attn_weights: bool = True,
+is_causal: bool = False,
 ) -> Tuple[Tensor, Optional[Tensor]]:
 r"""
 Args:
@@ -4949,6 +4950,8 @@ def multi_head_attention_forward(
 need_weights: output attn_output_weights.
 attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
 the batches while a 3D mask allows to specify a different mask for the entries of each batch.
+is_causal: If specified, applies a causal mask as attention mask. Mutually exclusive with providing attn_mask.
+Default: ``False``.
 use_separate_proj_weight: the function accept the proj. weights for query, key,
 and value in different forms. If false, in_proj_weight will be used, which is
 a combination of q_proj_weight, k_proj_weight, v_proj_weight.
@@ -5014,6 +5017,7 @@ def multi_head_attention_forward(
 key_padding_mask=key_padding_mask,
 need_weights=need_weights,
 attn_mask=attn_mask,
+is_causal=is_causal,
 use_separate_proj_weight=use_separate_proj_weight,
 q_proj_weight=q_proj_weight,
 k_proj_weight=k_proj_weight,
@@ -5184,7 +5188,7 @@ def multi_head_attention_forward(
 v = v.view(bsz, num_heads, src_len, head_dim)

 attn_output, attn_output_weights = _scaled_dot_product_attention(
-q, k, v, attn_mask, dropout_p, need_weights, False)
+q, k, v, attn_mask, dropout_p, need_weights, is_causal)
 attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)

 attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
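Inside ``multi_head_attention_forward`` the flag is now forwarded to the internal ``_scaled_dot_product_attention`` call instead of a hard-coded ``False``. Functionally, causal masking amounts to adding ``-inf`` above the diagonal of the attention scores before the softmax; a small free-standing sketch of that computation (not the internal kernel itself):

    import torch
    import torch.nn.functional as F

    L, E = 4, 8
    q = torch.randn(1, L, E)
    k = torch.randn(1, L, E)

    scores = q @ k.transpose(-2, -1) / (E ** 0.5)

    # -inf above the diagonal so future positions receive zero attention weight.
    causal_bias = torch.zeros(L, L).masked_fill(
        torch.triu(torch.ones(L, L, dtype=torch.bool), diagonal=1), float("-inf"))

    attn = F.softmax(scores + causal_bias, dim=-1)  # upper triangle is all zeros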

torch/nn/functional.pyi.in

Lines changed: 2 additions & 1 deletion
@@ -379,7 +379,8 @@ def multi_head_attention_forward(query: Tensor,
 v_proj_weight: Optional[Tensor] = None,
 static_k: Optional[Tensor] = None,
 static_v: Optional[Tensor] = None,
-average_attn_weights: bool = True
+average_attn_weights: bool = True,
+is_causal: bool = False
 ) -> Tuple[Tensor, Optional[Tensor]]: ...

torch/nn/modules/activation.py

Lines changed: 25 additions & 7 deletions
@@ -1009,9 +1009,16 @@ def __setstate__(self, state):

 super(MultiheadAttention, self).__setstate__(state)

-def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: Optional[Tensor] = None,
-need_weights: bool = True, attn_mask: Optional[Tensor] = None,
-average_attn_weights: bool = True) -> Tuple[Tensor, Optional[Tensor]]:
+def forward(
+self,
+query: Tensor,
+key: Tensor,
+value: Tensor,
+key_padding_mask: Optional[Tensor] = None,
+need_weights: bool = True,
+attn_mask: Optional[Tensor] = None,
+average_attn_weights: bool = True,
+is_causal : bool = False) -> Tuple[Tensor, Optional[Tensor]]:
 r"""
 Args:
 query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False``
@@ -1042,6 +1049,8 @@ def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: O
 corresponding position is not allowed to attend. For a byte mask, a non-zero value indicates that the
 corresponding position is not allowed to attend. For a float mask, the mask values will be added to
 the attention weight.
+is_causal: If specified, applies a causal mask as attention mask. Mutually exclusive with providing attn_mask.
+Default: ``False``.
 average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across
 heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an
 effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads)
@@ -1060,6 +1069,9 @@ def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: O
 .. note::
 `batch_first` argument is ignored for unbatched inputs.
 """
+if attn_mask is not None and is_causal:
+raise AssertionError("Only allow causal mask or attn_mask")
+
 is_batched = query.dim() == 3
 if key_padding_mask is not None:
 _kpm_dtype = key_padding_mask.dtype
@@ -1157,18 +1169,24 @@ def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: O
 self.dropout, self.out_proj.weight, self.out_proj.bias,
 training=self.training,
 key_padding_mask=key_padding_mask, need_weights=need_weights,
-attn_mask=attn_mask, use_separate_proj_weight=True,
+attn_mask=attn_mask,
+use_separate_proj_weight=True,
 q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
-v_proj_weight=self.v_proj_weight, average_attn_weights=average_attn_weights)
+v_proj_weight=self.v_proj_weight,
+average_attn_weights=average_attn_weights,
+is_causal=is_causal)
 else:
 attn_output, attn_output_weights = F.multi_head_attention_forward(
 query, key, value, self.embed_dim, self.num_heads,
 self.in_proj_weight, self.in_proj_bias,
 self.bias_k, self.bias_v, self.add_zero_attn,
 self.dropout, self.out_proj.weight, self.out_proj.bias,
 training=self.training,
-key_padding_mask=key_padding_mask, need_weights=need_weights,
-attn_mask=attn_mask, average_attn_weights=average_attn_weights)
+key_padding_mask=key_padding_mask,
+need_weights=need_weights,
+attn_mask=attn_mask,
+average_attn_weights=average_attn_weights,
+is_causal=is_causal)
 if self.batch_first and is_batched:
 return attn_output.transpose(1, 0), attn_output_weights
 else:
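From the user's side, the module now takes ``is_causal=True`` in place of an explicit causal ``attn_mask``, and asserts if both are given. A minimal sketch of the intended usage; the numerical agreement shown is an expectation based on the masking semantics, not a test taken from this commit:

    import torch
    import torch.nn as nn

    mha = nn.MultiheadAttention(embed_dim=8, num_heads=2, batch_first=True)
    x = torch.randn(2, 5, 8)  # (N, L, E)

    # Explicit float causal mask, the pre-existing way to get causal attention.
    mask = torch.zeros(5, 5).masked_fill(
        torch.triu(torch.ones(5, 5, dtype=torch.bool), diagonal=1), float("-inf"))

    out_mask, _ = mha(x, x, x, attn_mask=mask, need_weights=False)
    out_causal, _ = mha(x, x, x, is_causal=True, need_weights=False)
    print(torch.allclose(out_mask, out_causal, atol=1e-6))  # expected: True

    # Passing both is rejected by the new check at the top of forward():
    # mha(x, x, x, attn_mask=mask, is_causal=True)  -> AssertionError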

torch/nn/modules/transformer.py

Lines changed: 47 additions & 13 deletions
@@ -194,12 +194,18 @@ def __init__(self, encoder_layer, num_layers, norm=None, enable_nested_tensor=Tr
 self.enable_nested_tensor = enable_nested_tensor
 self.mask_check = mask_check

-def forward(self, src: Tensor, mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
+def forward(
+self,
+src: Tensor,
+mask: Optional[Tensor] = None,
+src_key_padding_mask: Optional[Tensor] = None,
+is_causal: bool = False) -> Tensor:
 r"""Pass the input through the encoder layers in turn.

 Args:
 src: the sequence to the encoder (required).
 mask: the mask for the src sequence (optional).
+is_causal: If specified, applies a causal mask as mask (optional). Mutually exclusive with providing mask. Default: ``False``.
 src_key_padding_mask: the mask for the src keys per batch (optional).

 Shape:
@@ -278,8 +284,18 @@ def forward(self, src: Tensor, mask: Optional[Tensor] = None, src_key_padding_ma
 output = torch._nested_tensor_from_mask(output, src_key_padding_mask.logical_not(), mask_check=False)
 src_key_padding_mask_for_layers = None

+# Prevent type refinement
+make_causal = False
+if mask is not None:
+if is_causal:
+raise RuntimeError("specify either mask or is_causal, but not both")
+
+if make_causal:
+is_causal = True
+mask = None
+
 for mod in self.layers:
-output = mod(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask_for_layers)
+output = mod(output, src_mask=mask, is_causal=is_causal, src_key_padding_mask=src_key_padding_mask_for_layers)

 if convert_to_nested:
 output = output.to_padded_tensor(0.)
@@ -437,13 +453,19 @@ def __setstate__(self, state):
 self.activation = F.relu


-def forward(self, src: Tensor, src_mask: Optional[Tensor] = None,
-src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
+def forward(
+self,
+src: Tensor,
+src_mask: Optional[Tensor] = None,
+src_key_padding_mask: Optional[Tensor] = None,
+is_causal: bool = False) -> Tensor:
 r"""Pass the input through the encoder layer.

 Args:
 src: the sequence to the encoder layer (required).
 src_mask: the mask for the src sequence (optional).
+is_causal: If specified, applies a causal mask as src_mask. Mutually exclusive with providing src_mask.
+Default: ``False``.
 src_key_padding_mask: the mask for the src keys per batch (optional).

 Shape:
@@ -623,8 +645,17 @@ def __setstate__(self, state):
 state['activation'] = F.relu
 super(TransformerDecoderLayer, self).__setstate__(state)

-def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None, memory_mask: Optional[Tensor] = None,
-tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None) -> Tensor:
+def forward(
+self,
+tgt: Tensor,
+memory: Tensor,
+tgt_mask: Optional[Tensor] = None,
+memory_mask: Optional[Tensor] = None,
+tgt_key_padding_mask: Optional[Tensor] = None,
+memory_key_padding_mask: Optional[Tensor] = None,
+tgt_is_causal: bool = False,
+memory_is_causal: bool = False,
+) -> Tensor:
 r"""Pass the inputs (and mask) through the decoder layer.

 Args:
@@ -634,39 +665,42 @@ def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None
 memory_mask: the mask for the memory sequence (optional).
 tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
 memory_key_padding_mask: the mask for the memory keys per batch (optional).
-
+tgt_is_causal: If specified, applies a causal mask as tgt mask. Mutually exclusive with providing tgt_mask. Default: ``False``.
+memory_is_causal: If specified, applies a causal mask as tgt mask. Mutually exclusive with providing memory_mask. Default: ``False``.
 Shape:
 see the docs in Transformer class.
 """
 # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf

 x = tgt
 if self.norm_first:
-x = x + self._sa_block(self.norm1(x), tgt_mask, tgt_key_padding_mask)
-x = x + self._mha_block(self.norm2(x), memory, memory_mask, memory_key_padding_mask)
+x = x + self._sa_block(self.norm1(x), tgt_mask, tgt_key_padding_mask, tgt_is_causal)
+x = x + self._mha_block(self.norm2(x), memory, memory_mask, memory_key_padding_mask, memory_is_causal)
 x = x + self._ff_block(self.norm3(x))
 else:
-x = self.norm1(x + self._sa_block(x, tgt_mask, tgt_key_padding_mask))
-x = self.norm2(x + self._mha_block(x, memory, memory_mask, memory_key_padding_mask))
+x = self.norm1(x + self._sa_block(x, tgt_mask, tgt_key_padding_mask, tgt_is_causal))
+x = self.norm2(x + self._mha_block(x, memory, memory_mask, memory_key_padding_mask, memory_is_causal))
 x = self.norm3(x + self._ff_block(x))

 return x

 # self-attention block
 def _sa_block(self, x: Tensor,
-attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor]) -> Tensor:
+attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], is_causal: bool = False) -> Tensor:
 x = self.self_attn(x, x, x,
 attn_mask=attn_mask,
 key_padding_mask=key_padding_mask,
+is_causal=is_causal,
 need_weights=False)[0]
 return self.dropout1(x)

 # multihead attention block
 def _mha_block(self, x: Tensor, mem: Tensor,
-attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor]) -> Tensor:
+attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], is_causal: bool = False) -> Tensor:
 x = self.multihead_attn(x, mem, mem,
 attn_mask=attn_mask,
 key_padding_mask=key_padding_mask,
+is_causal=is_causal,
 need_weights=False)[0]
 return self.dropout2(x)
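In the decoder layer, ``tgt_is_causal`` and ``memory_is_causal`` are forwarded through ``_sa_block``/``_mha_block`` into the attention modules as ``is_causal``. A minimal usage sketch of the new keyword (sizes are illustrative):

    import torch
    import torch.nn as nn

    layer = nn.TransformerDecoderLayer(d_model=16, nhead=4, dropout=0.0)
    tgt = torch.randn(5, 2, 16)     # (T, N, E)
    memory = torch.randn(7, 2, 16)  # (S, N, E)

    # Causal self-attention over tgt; cross-attention to memory stays unmasked.
    # tgt_mask must be omitted when tgt_is_causal=True (they are mutually exclusive).
    out = layer(tgt, memory, tgt_is_causal=True)
    print(out.shape)  # torch.Size([5, 2, 16])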
torch/overrides.py

Lines changed: 1 addition & 1 deletion
@@ -827,7 +827,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
 lambda query, key, value, embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias, bias_k, bias_v,
 add_zero_attn, dropout_p, out_proj_weight, out_proj_bias, training=True, key_padding_mask=None,
 need_weights=True, attn_mask=None, use_separate_proj_weight=False, q_proj_weight=None, k_proj_weight=None,
-v_proj_weight=None, static_k=None, static_v=None, average_attn_weights=None: -1),
+v_proj_weight=None, static_k=None, static_v=None, average_attn_weights=None, is_causal=False: -1),
 torch.nn.functional.multi_margin_loss: (lambda input, target, p=1, margin=1.0, weight=None, size_average=None,
 reduce=None, reduction='mean': -1),
 torch.nn.functional.multilabel_margin_loss: (lambda input, target, size_average=None, reduce=None,
