@@ -1,18 +1,17 @@
-# mypy: allow-untyped-decorators
-# mypy: allow-untyped-defs
 import math
-from typing import Any, Callable, Dict, Sequence, Tuple, Union
+from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
 
 import torch
 import torch.utils._pytree as pytree
+from torch import Tensor
 from torch._C import DispatchKey
 from torch._higher_order_ops.utils import (
     _has_potential_branch_input_mutation,
     autograd_not_implemented,
     reenter_make_fx,
     UnsupportedAliasMutationException,
 )
-from torch._ops import HigherOrderOperator
+from torch._ops import HigherOrderOperator, OpOverload
 from torch._subclasses import FakeTensorMode
 from torch.fx.experimental.proxy_tensor import (
     make_fx,
@@ -77,7 +76,13 @@ class TransformGetItemToIndex(TorchFunctionMode):
     # scalar and create a view. We do not want that behavior in this case, so we
     # use this torchfunctionmode to override that behavior for score_mod
     # wherever we're running it.
-    def __torch_function__(self, func, types, args=(), kwargs=None):
+    def __torch_function__(
+        self,
+        func: OpOverload,
+        types: Tuple[torch._C._TensorMeta, ...],
+        args: Tuple[object, ...] = (),
+        kwargs: Optional[Dict[str, object]] = None,
+    ) -> object:
         if func == torch.Tensor.__getitem__:
             index_args = pytree.tree_leaves(args[1])
             if all(isinstance(x, torch.Tensor) for x in index_args):
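For readers unfamiliar with torch function modes, the snippet below is not part of the diff; it is a minimal sketch (class name and prints invented for illustration) of the mechanism the comment above describes: a TorchFunctionMode subclass sees every torch call made under its with-block, so it can special-case torch.Tensor.__getitem__ the way TransformGetItemToIndex does for score_mod.

import torch
from torch.overrides import TorchFunctionMode

class LogGetItem(TorchFunctionMode):
    # Intercepts every torch call made inside the `with` block below.
    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is torch.Tensor.__getitem__:
            print("indexing with", args[1])
        return func(*args, **kwargs)

with LogGetItem():
    t = torch.arange(6).reshape(2, 3)
    _ = t[torch.tensor(1)]  # routed through __torch_function__ above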
@@ -485,7 +490,11 @@ def flex_attention_fake_tensor_mode(
 
 
 # ---------------------------- Autograd Implementation ----------------------------
-def create_fw_bw_graph(score_mod, index_values, other_buffers):
+def create_fw_bw_graph(
+    score_mod: Callable,
+    index_values: Tuple[Tensor, Tensor, Tensor, Tensor, Tensor],
+    other_buffers: Tuple[Tensor, ...],
+) -> Tuple[Callable, Callable]:
     # See Note:[HOP create fw_bw graph]
 
     # All of these imports need to be here in order to avoid circular dependencies
@@ -508,7 +517,7 @@ def create_fw_bw_graph(score_mod, index_values, other_buffers):
     with suspend_functionalization(), disable_functional_mode():
         with disable_proxy_modes_tracing():
 
-            def _from_fun(t):
+            def _from_fun(t: Tensor) -> Tensor:
                 return torch.empty_strided(
                     t.size(),
                     t.stride(),
@@ -544,8 +553,18 @@ def _from_fun(t):
             )
             example_grad = _from_fun(example_flat_out)
 
-            def joint_f(score, b, h, m, n, example_grad, *other_buffers):
-                def fw_with_masks(*args):
+            def joint_f(
+                score: Tensor,
+                b: Tensor,
+                h: Tensor,
+                m: Tensor,
+                n: Tensor,
+                example_grad: Tensor,
+                *other_buffers: Tuple[Tensor, ...],
+            ) -> Tuple[Tensor, ...]:
+                def fw_with_masks(
+                    *args: Tuple[Tensor, ...]
+                ) -> Tuple[Tuple[Tensor], Tuple[bool]]:
                     fw_out = score_mod(*args)
                     out_requires_grad = fw_out.requires_grad
                     return ((fw_out,), (out_requires_grad,))
@@ -566,17 +585,17 @@ def fw_with_masks(*args):
 class FlexAttentionAutogradOp(torch.autograd.Function):
     @staticmethod
     def forward(
-        ctx,
-        query,
-        key,
-        value,
-        fw_graph,
-        joint_graph,
-        block_mask,
-        scale,
-        kernel_options,
-        score_mod_other_buffers,
-        mask_mod_other_buffers,
+        ctx: Any,
+        query: Tensor,
+        key: Tensor,
+        value: Tensor,
+        fw_graph: Callable,
+        joint_graph: Callable,
+        block_mask: Tuple[Any, ...],
+        scale: float,
+        kernel_options: Dict[str, Any],
+        score_mod_other_buffers: Tuple[Any, ...],
+        mask_mod_other_buffers: Tuple[Any, ...],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         any_buffer_requires_grad = any(
             buffer.requires_grad
@@ -620,7 +639,7 @@ def forward(
         return out, logsumexp
 
     @staticmethod
-    def backward(ctx, grad_out, grad_logsumexp):
+    def backward(ctx: Any, grad_out: Tensor, grad_logsumexp: Tensor) -> Tuple[Optional[Tensor], ...]:  # type: ignore[override]
         fw_args = ctx.saved_tensors
         (
             query,
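The forward/backward annotations above follow the usual torch.autograd.Function typing pattern, including the "# type: ignore[override]" on backward (the base class declares a variadic backward(ctx, *grad_outputs)). A standalone sketch on a toy op, with hypothetical names, assuming nothing beyond public PyTorch API:

from typing import Any, Optional, Tuple

import torch
from torch import Tensor

class ScaleOp(torch.autograd.Function):
    @staticmethod
    def forward(ctx: Any, x: Tensor, factor: float) -> Tensor:
        ctx.factor = factor
        return x * factor

    @staticmethod
    def backward(ctx: Any, grad_out: Tensor) -> Tuple[Optional[Tensor], ...]:  # type: ignore[override]
        # One gradient per forward input; non-Tensor inputs (factor) get None.
        return grad_out * ctx.factor, None

out = ScaleOp.apply(torch.ones(3, requires_grad=True), 2.0)
out.sum().backward()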
@@ -693,15 +712,19 @@ def flex_attention_autograd(
     block_mask: Tuple,
     scale: float,
     kernel_options: Dict[str, Any],
-    score_mod_other_buffers: Tuple = (),
-    mask_mod_other_buffers: Tuple = (),
+    score_mod_other_buffers: Tuple[Tensor, ...] = (),
+    mask_mod_other_buffers: Tuple[Tensor, ...] = (),
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     with TransformGetItemToIndex():
         input_requires_grad = any(t.requires_grad for t in (query, key, value))
         if torch.is_grad_enabled() and input_requires_grad:
-            example_vals = [
-                torch.zeros((), dtype=query.dtype, requires_grad=input_requires_grad)
-            ] + [torch.zeros((), dtype=torch.int) for _ in range(4)]
+            example_vals = (
+                torch.zeros((), dtype=query.dtype, requires_grad=input_requires_grad),
+                torch.zeros((), dtype=torch.int),
+                torch.zeros((), dtype=torch.int),
+                torch.zeros((), dtype=torch.int),
+                torch.zeros((), dtype=torch.int),
+            )
             fw_graph, bw_graph = create_fw_bw_graph(
                 score_mod, example_vals, score_mod_other_buffers
             )
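The example_vals tuple built above holds one scalar in the query dtype (the score) followed by four int scalars, matching the (score, b, h, m, n) arguments seen in joint_f earlier in the diff. Purely as an illustration (rel_bias is an invented score_mod, not code from this file), such placeholders can be exercised like this:

import torch
from torch import Tensor

def rel_bias(score: Tensor, b: Tensor, h: Tensor, q_idx: Tensor, kv_idx: Tensor) -> Tensor:
    # Adds a relative-position bias to the attention score.
    return score + (q_idx - kv_idx)

example_vals = (
    torch.zeros((), dtype=torch.float32, requires_grad=True),
    torch.zeros((), dtype=torch.int),
    torch.zeros((), dtype=torch.int),
    torch.zeros((), dtype=torch.int),
    torch.zeros((), dtype=torch.int),
)
out = rel_bias(*example_vals)  # scalar tensor, differentiable w.r.t. the score input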