@@ -3077,3 +3077,194 @@ def _pad_circular(input, padding):
         input = torch.cat([input[:, :, :, :, -(padding[-5] + padding[-6]):-padding[-5]], input], dim=4)
 
     return input
+
+
+@weak_script
+def multi_head_attention_forward(query,                   # type: Tensor
+                                 key,                     # type: Tensor
+                                 value,                   # type: Tensor
+                                 embed_dim_to_check,      # type: int
+                                 num_heads,               # type: int
+                                 in_proj_weight,          # type: Tensor
+                                 in_proj_bias,            # type: Tensor
+                                 bias_k,                  # type: Tensor
+                                 bias_v,                  # type: Tensor
+                                 add_zero_attn,           # type: bool
+                                 dropout_p,               # type: float
+                                 out_proj,                # type: Tensor
+                                 training=True,           # type: bool
+                                 key_padding_mask=None,   # type: Optional[Tensor]
+                                 need_weights=True,       # type: bool
+                                 attn_mask=None           # type: Optional[Tensor]
+                                 ):
+    # type: (...) -> Tuple[Tensor, Tensor]
+    r"""
+    Args:
+        query, key, value: map a query and a set of key-value pairs to an output.
+            See "Attention Is All You Need" for more details.
+        embed_dim_to_check: total dimension of the model.
+        num_heads: parallel attention heads.
+        in_proj_weight, in_proj_bias: input projection weight and bias.
+        bias_k, bias_v: bias of the key and value sequences to be added at dim=0.
+        add_zero_attn: add a new batch of zeros to the key and
+                       value sequences at dim=1.
+        dropout_p: probability of an element to be zeroed.
+        out_proj: the output projection.
+        training: apply dropout if ``training`` is ``True``.
+        key_padding_mask: if provided, specified padding elements in the key will
+            be ignored by the attention.
+        need_weights: output attn_output_weights.
+        attn_mask: mask that prevents attention to certain positions.
+
+
+    Shape:
+        Inputs:
+        - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
+          the embedding dimension.
+        - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
+          the embedding dimension.
+        - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
+          the embedding dimension.
+        - key_padding_mask: :math:`(N, S)`, ByteTensor, where N is the batch size, S is the source sequence length.
+        - attn_mask: :math:`(L, L)` where L is the target sequence length.
+
+        Outputs:
+        - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
+          E is the embedding dimension.
+        - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
+          L is the target sequence length, S is the source sequence length.
+    """
+
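+    # NOTE: in_proj_weight packs the query/key/value projections into a single
+    # (3 * embed_dim, embed_dim) matrix; the helpers below slice out the rows
+    # for q ([0, E)), k ([E, 2E)) and v ([2E, 3E)) and apply linear().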
+    @weak_script
+    def _in_proj(input, weight, bias, start=0, end=None):
+        # type: (Tensor, Tensor, Optional[Tensor], int, Optional[int]) -> Tensor
+        weight = weight[start:end, :]
+        if bias is not None:
+            bias = bias[start:end]
+        return linear(input, weight, bias)
+
+
+    @weak_script
+    def _in_proj_qkv(weight, bias, query):
+        # type: (Tensor, Tensor, Tensor) -> Tensor
+        return _in_proj(query, weight, bias).chunk(3, dim=-1)
+
+
+    @weak_script
+    def _in_proj_kv(weight, bias, embed_dim, key):
+        # type: (Tensor, Tensor, int, Tensor) -> Tensor
+        return _in_proj(key, weight, bias, start=embed_dim).chunk(2, dim=-1)
+
+
+    @weak_script
+    def _in_proj_q(weight, bias, embed_dim, query):
+        # type: (Tensor, Tensor, int, Tensor) -> Tensor
+        return _in_proj(query, weight, bias, end=embed_dim)
+
+
+    @weak_script
+    def _in_proj_k(weight, bias, embed_dim, key):
+        # type: (Tensor, Tensor, int, Tensor) -> Tensor
+        return _in_proj(key, weight, bias, start=embed_dim, end=2 * embed_dim)
+
+
+    @weak_script
+    def _in_proj_v(weight, bias, embed_dim, value):
+        # type: (Tensor, Tensor, int, Tensor) -> Tensor
+        return _in_proj(value, weight, bias, start=2 * embed_dim)
+
+
+    qkv_same = query.data_ptr() == key.data_ptr() == value.data_ptr()
+    kv_same = key.data_ptr() == value.data_ptr()
+
+    tgt_len, bsz, embed_dim = query.size()
+    assert embed_dim == embed_dim_to_check
+    assert list(query.size()) == [tgt_len, bsz, embed_dim]
+    assert key.size() == value.size()
+
+    head_dim = embed_dim // num_heads
+    assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
+    scaling = head_dim ** -0.5
+
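+    # Choose the projection path from pointer identity: self-attention when
+    # query, key and value share storage, encoder-decoder attention when only
+    # key and value do, and three separate projections otherwise.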
+    if qkv_same:
+        # self-attention
+        q, k, v = _in_proj_qkv(in_proj_weight, in_proj_bias, query)
+    elif kv_same:
+        # encoder-decoder attention
+        q = _in_proj_q(in_proj_weight, in_proj_bias, embed_dim, query)
+        if key is None:
+            assert value is None
+            k = v = None
+        else:
+            k, v = _in_proj_kv(in_proj_weight, in_proj_bias, embed_dim, key)
+    else:
+        q = _in_proj_q(in_proj_weight, in_proj_bias, embed_dim, query)
+        k = _in_proj_k(in_proj_weight, in_proj_bias, embed_dim, key)
+        v = _in_proj_v(in_proj_weight, in_proj_bias, embed_dim, value)
+    q *= scaling
+
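+    # bias_k / bias_v append one extra position to the key and value
+    # sequences, so the masks grow by one trailing column to match.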
+    if bias_k is not None:
+        assert bias_v is not None
+        k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
+        v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
+        if attn_mask is not None:
+            attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
+        if key_padding_mask is not None:
+            key_padding_mask = torch.cat(
+                [key_padding_mask, key_padding_mask.new_zeros(key_padding_mask.size(0), 1)], dim=1)
+
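+    # Fold the heads into the batch dimension: (len, bsz, embed_dim) ->
+    # (bsz * num_heads, len, head_dim), so a single bmm handles every head at once.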
+    q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
+    if k is not None:
+        k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
+    if v is not None:
+        v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
+
+    src_len = k.size(1)
+
+    if key_padding_mask is not None:
+        assert key_padding_mask.size(0) == bsz
+        assert key_padding_mask.size(1) == src_len
+
+    if add_zero_attn:
+        src_len += 1
+        k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
+        v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
+        if attn_mask is not None:
+            attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
+        if key_padding_mask is not None:
+            key_padding_mask = torch.cat(
+                [key_padding_mask, torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask)], dim=1)
+
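+    # Raw attention scores: attn_mask is added to the logits, and padded key
+    # positions are filled with -inf so they vanish after the softmax.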
+    attn_output_weights = torch.bmm(q, k.transpose(1, 2))
+    assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len]
+
+    if attn_mask is not None:
+        attn_mask = attn_mask.unsqueeze(0)
+        attn_output_weights += attn_mask
+
+    if key_padding_mask is not None:
+        attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
+        attn_output_weights = attn_output_weights.masked_fill(
+            key_padding_mask.unsqueeze(1).unsqueeze(2),
+            float('-inf'),
+        )
+        attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len)
+
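+    # Softmax over the source dimension; half-precision inputs are promoted to
+    # float32 here to avoid overflow in the normalization.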
+    attn_output_weights = softmax(
+        attn_output_weights.float(), dim=-1,
+        dtype=torch.float32 if attn_output_weights.dtype == torch.float16 else attn_output_weights.dtype)
+    attn_output_weights = dropout(attn_output_weights, p=dropout_p, training=training)
+
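+    # Weighted sum of the values, then undo the head folding and apply the
+    # output projection to get (tgt_len, bsz, embed_dim).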
+    attn_output = torch.bmm(attn_output_weights, v)
+    assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
+    attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
+    attn_output = out_proj(attn_output)
+
+    if need_weights:
+        # average attention weights over heads
+        attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
+        attn_output_weights = attn_output_weights.sum(dim=1) / num_heads
+    else:
+        attn_output_weights = None
+
+    return attn_output, attn_output_weights