@@ -194,12 +194,18 @@ def __init__(self, encoder_layer, num_layers, norm=None, enable_nested_tensor=Tr
         self.enable_nested_tensor = enable_nested_tensor
         self.mask_check = mask_check
 
-    def forward(self, src: Tensor, mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
+    def forward(
+            self,
+            src: Tensor,
+            mask: Optional[Tensor] = None,
+            src_key_padding_mask: Optional[Tensor] = None,
+            is_causal: bool = False) -> Tensor:
         r"""Pass the input through the encoder layers in turn.
 
         Args:
             src: the sequence to the encoder (required).
             mask: the mask for the src sequence (optional).
+            is_causal: If specified, applies a causal mask as mask (optional). Mutually exclusive with providing mask. Default: ``False``.
             src_key_padding_mask: the mask for the src keys per batch (optional).
 
         Shape:
@@ -278,8 +284,18 @@ def forward(self, src: Tensor, mask: Optional[Tensor] = None, src_key_padding_ma
             output = torch._nested_tensor_from_mask(output, src_key_padding_mask.logical_not(), mask_check=False)
             src_key_padding_mask_for_layers = None
 
+        # Prevent type refinement
+        make_causal = False
+        if mask is not None:
+            if is_causal:
+                raise RuntimeError("specify either mask or is_causal, but not both")
+
+        if make_causal:
+            is_causal = True
+            mask = None
+
         for mod in self.layers:
-            output = mod(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask_for_layers)
+            output = mod(output, src_mask=mask, is_causal=is_causal, src_key_padding_mask=src_key_padding_mask_for_layers)
 
         if convert_to_nested:
             output = output.to_padded_tensor(0.)
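
Not part of the diff: a minimal usage sketch of the encoder-level change above, assuming the `is_causal` semantics introduced in this hunk (pass either an explicit mask or the flag, never both). Module sizes and tensor shapes are illustrative only.

import torch
import torch.nn as nn

# Illustrative sizes; batch_first=True gives (batch, seq, d_model) inputs.
layer = nn.TransformerEncoderLayer(d_model=32, nhead=4, batch_first=True)
encoder = nn.TransformerEncoder(layer, num_layers=2)
src = torch.rand(8, 10, 32)

# Before this change: build and pass an explicit causal mask.
causal_mask = nn.Transformer.generate_square_subsequent_mask(10)
out_explicit = encoder(src, mask=causal_mask)

# With this change: request causal masking via the flag instead.
out_causal = encoder(src, is_causal=True)

# Per the new check above, passing both mask and is_causal raises
# RuntimeError("specify either mask or is_causal, but not both").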
@@ -437,13 +453,19 @@ def __setstate__(self, state):
             self.activation = F.relu
 
 
-    def forward(self, src: Tensor, src_mask: Optional[Tensor] = None,
-                src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
+    def forward(
+            self,
+            src: Tensor,
+            src_mask: Optional[Tensor] = None,
+            src_key_padding_mask: Optional[Tensor] = None,
+            is_causal: bool = False) -> Tensor:
         r"""Pass the input through the encoder layer.
 
         Args:
             src: the sequence to the encoder layer (required).
             src_mask: the mask for the src sequence (optional).
+            is_causal: If specified, applies a causal mask as src_mask. Mutually exclusive with providing src_mask.
+                Default: ``False``.
             src_key_padding_mask: the mask for the src keys per batch (optional).
 
         Shape:
@@ -623,8 +645,17 @@ def __setstate__(self, state):
             state['activation'] = F.relu
         super(TransformerDecoderLayer, self).__setstate__(state)
 
-    def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None, memory_mask: Optional[Tensor] = None,
-                tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None) -> Tensor:
+    def forward(
+        self,
+        tgt: Tensor,
+        memory: Tensor,
+        tgt_mask: Optional[Tensor] = None,
+        memory_mask: Optional[Tensor] = None,
+        tgt_key_padding_mask: Optional[Tensor] = None,
+        memory_key_padding_mask: Optional[Tensor] = None,
+        tgt_is_causal: bool = False,
+        memory_is_causal: bool = False,
+    ) -> Tensor:
         r"""Pass the inputs (and mask) through the decoder layer.
 
         Args:
@@ -634,39 +665,42 @@ def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None
             memory_mask: the mask for the memory sequence (optional).
             tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
             memory_key_padding_mask: the mask for the memory keys per batch (optional).
-
+            tgt_is_causal: If specified, applies a causal mask as tgt mask. Mutually exclusive with providing tgt_mask. Default: ``False``.
+            memory_is_causal: If specified, applies a causal mask as memory mask. Mutually exclusive with providing memory_mask. Default: ``False``.
         Shape:
             see the docs in Transformer class.
         """
         # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf
 
         x = tgt
         if self.norm_first:
-            x = x + self._sa_block(self.norm1(x), tgt_mask, tgt_key_padding_mask)
-            x = x + self._mha_block(self.norm2(x), memory, memory_mask, memory_key_padding_mask)
+            x = x + self._sa_block(self.norm1(x), tgt_mask, tgt_key_padding_mask, tgt_is_causal)
+            x = x + self._mha_block(self.norm2(x), memory, memory_mask, memory_key_padding_mask, memory_is_causal)
             x = x + self._ff_block(self.norm3(x))
         else:
-            x = self.norm1(x + self._sa_block(x, tgt_mask, tgt_key_padding_mask))
-            x = self.norm2(x + self._mha_block(x, memory, memory_mask, memory_key_padding_mask))
+            x = self.norm1(x + self._sa_block(x, tgt_mask, tgt_key_padding_mask, tgt_is_causal))
+            x = self.norm2(x + self._mha_block(x, memory, memory_mask, memory_key_padding_mask, memory_is_causal))
             x = self.norm3(x + self._ff_block(x))
 
         return x
 
     # self-attention block
     def _sa_block(self, x: Tensor,
-                  attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor]) -> Tensor:
+                  attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], is_causal: bool = False) -> Tensor:
         x = self.self_attn(x, x, x,
                            attn_mask=attn_mask,
                            key_padding_mask=key_padding_mask,
+                           is_causal=is_causal,
                            need_weights=False)[0]
         return self.dropout1(x)
 
     # multihead attention block
     def _mha_block(self, x: Tensor, mem: Tensor,
-                   attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor]) -> Tensor:
+                   attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], is_causal: bool = False) -> Tensor:
         x = self.multihead_attn(x, mem, mem,
                                 attn_mask=attn_mask,
                                 key_padding_mask=key_padding_mask,
+                                is_causal=is_causal,
                                 need_weights=False)[0]
         return self.dropout2(x)
 
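
Also not part of the diff: a sketch of the decoder-layer change, assuming tgt_is_causal/memory_is_causal are forwarded into _sa_block/_mha_block exactly as shown above and that the underlying MultiheadAttention.forward accepts the is_causal keyword this commit relies on. Shapes are illustrative.

import torch
import torch.nn as nn

dec_layer = nn.TransformerDecoderLayer(d_model=32, nhead=4, batch_first=True)
tgt = torch.rand(8, 10, 32)     # (batch, tgt_seq, d_model)
memory = torch.rand(8, 12, 32)  # (batch, src_seq, d_model)

# Causal self-attention over tgt without constructing tgt_mask by hand;
# cross-attention over memory stays unmasked.
out = dec_layer(tgt, memory, tgt_is_causal=True)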