
Commit 2854d15

Chillee authored and pytorchmergebot committed
Add type annotations for higher order ops/flex_attention (#137065)
Pull Request resolved: #137065
Approved by: https://github.com/drisspg, https://github.com/Skylion007
ghstack dependencies: #136826, #137043, #137049
1 parent 3b8511d commit 2854d15

4 files changed: +71 additions, -35 deletions

torch/_higher_order_ops/auto_functionalize.py

Lines changed: 6 additions & 2 deletions
@@ -586,7 +586,9 @@ def auto_functionalized_fake(
     **kwargs: Any,
 ) -> Tuple[Any, Tuple[Tensor, ...]]:
     with mode:
-        result = auto_functionalized_dense(_mutable_op, **kwargs)
+        result = auto_functionalized_dense(
+            _mutable_op, _only_clone_these_tensors=None, **kwargs
+        )
         return result

@@ -681,7 +683,9 @@ def auto_functionalized_v2_fake(
     **kwargs: Dict[str, Any],
 ) -> Tuple[Any, Tuple[Tensor, ...]]:
     with mode:
-        result = auto_functionalized_v2_dense(_mutable_op, **kwargs)
+        result = auto_functionalized_v2_dense(
+            _mutable_op, _only_clone_these_bases=None, **kwargs
+        )
         return result
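Both fake implementations now pin the optional argument (_only_clone_these_tensors / _only_clone_these_bases) to None at the call site rather than relying on the default, keeping the call unambiguous once the dense signatures are fully annotated. A minimal sketch of the same call shape; dense_impl and fake_impl are illustrative names, not the real PyTorch functions:

from typing import Any, Optional, Tuple


def dense_impl(
    op: Any,
    _only_clone_these_tensors: Optional[Tuple[str, ...]] = None,
    **kwargs: Any,
) -> Tuple[Any, ...]:
    # Stand-in for auto_functionalized_dense: an optional parameter sits
    # between the op and the arbitrary keyword arguments.
    return (op, _only_clone_these_tensors, kwargs)


def fake_impl(op: Any, **kwargs: Any) -> Tuple[Any, ...]:
    # Spelling the optional argument out, as the PR does, keeps the
    # forwarded **kwargs from ever being confused with it.
    return dense_impl(op, _only_clone_these_tensors=None, **kwargs)


print(fake_impl("my_op", x=1))  # ('my_op', None, {'x': 1})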

torch/_higher_order_ops/flex_attention.py

Lines changed: 49 additions & 26 deletions
@@ -1,18 +1,17 @@
-# mypy: allow-untyped-decorators
-# mypy: allow-untyped-defs
 import math
-from typing import Any, Callable, Dict, Sequence, Tuple, Union
+from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union

 import torch
 import torch.utils._pytree as pytree
+from torch import Tensor
 from torch._C import DispatchKey
 from torch._higher_order_ops.utils import (
     _has_potential_branch_input_mutation,
     autograd_not_implemented,
     reenter_make_fx,
     UnsupportedAliasMutationException,
 )
-from torch._ops import HigherOrderOperator
+from torch._ops import HigherOrderOperator, OpOverload
 from torch._subclasses import FakeTensorMode
 from torch.fx.experimental.proxy_tensor import (
     make_fx,

@@ -77,7 +76,13 @@ class TransformGetItemToIndex(TorchFunctionMode):
     # scalar and create a view. We do not want that behavior in this case, so we
     # use this torchfunctionmode to override that behavior for score_mod
     # wherever we're running it.
-    def __torch_function__(self, func, types, args=(), kwargs=None):
+    def __torch_function__(
+        self,
+        func: OpOverload,
+        types: Tuple[torch._C._TensorMeta, ...],
+        args: Tuple[object, ...] = (),
+        kwargs: Optional[Dict[str, object]] = None,
+    ) -> object:
         if func == torch.Tensor.__getitem__:
             index_args = pytree.tree_leaves(args[1])
             if all(isinstance(x, torch.Tensor) for x in index_args):

@@ -485,7 +490,11 @@ def flex_attention_fake_tensor_mode(


 # ---------------------------- Autograd Implementation ----------------------------
-def create_fw_bw_graph(score_mod, index_values, other_buffers):
+def create_fw_bw_graph(
+    score_mod: Callable,
+    index_values: Tuple[Tensor, Tensor, Tensor, Tensor, Tensor],
+    other_buffers: Tuple[Tensor, ...],
+) -> Tuple[Callable, Callable]:
     # See Note:[HOP create fw_bw graph]

     # All of these imports need to be here in order to avoid circular dependencies

@@ -508,7 +517,7 @@ def create_fw_bw_graph(score_mod, index_values, other_buffers):
     with suspend_functionalization(), disable_functional_mode():
         with disable_proxy_modes_tracing():

-            def _from_fun(t):
+            def _from_fun(t: Tensor) -> Tensor:
                 return torch.empty_strided(
                     t.size(),
                     t.stride(),

@@ -544,8 +553,18 @@ def _from_fun(t):
             )
             example_grad = _from_fun(example_flat_out)

-            def joint_f(score, b, h, m, n, example_grad, *other_buffers):
-                def fw_with_masks(*args):
+            def joint_f(
+                score: Tensor,
+                b: Tensor,
+                h: Tensor,
+                m: Tensor,
+                n: Tensor,
+                example_grad: Tensor,
+                *other_buffers: Tuple[Tensor, ...],
+            ) -> Tuple[Tensor, ...]:
+                def fw_with_masks(
+                    *args: Tuple[Tensor, ...]
+                ) -> Tuple[Tuple[Tensor], Tuple[bool]]:
                     fw_out = score_mod(*args)
                     out_requires_grad = fw_out.requires_grad
                     return ((fw_out,), (out_requires_grad,))

@@ -566,17 +585,17 @@ def fw_with_masks(*args):
 class FlexAttentionAutogradOp(torch.autograd.Function):
     @staticmethod
     def forward(
-        ctx,
-        query,
-        key,
-        value,
-        fw_graph,
-        joint_graph,
-        block_mask,
-        scale,
-        kernel_options,
-        score_mod_other_buffers,
-        mask_mod_other_buffers,
+        ctx: Any,
+        query: Tensor,
+        key: Tensor,
+        value: Tensor,
+        fw_graph: Callable,
+        joint_graph: Callable,
+        block_mask: Tuple[Any, ...],
+        scale: float,
+        kernel_options: Dict[str, Any],
+        score_mod_other_buffers: Tuple[Any, ...],
+        mask_mod_other_buffers: Tuple[Any, ...],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         any_buffer_requires_grad = any(
             buffer.requires_grad

@@ -620,7 +639,7 @@ def forward(
         return out, logsumexp

     @staticmethod
-    def backward(ctx, grad_out, grad_logsumexp):
+    def backward(ctx: Any, grad_out: Tensor, grad_logsumexp: Tensor) -> Tuple[Optional[Tensor], ...]:  # type: ignore[override]
         fw_args = ctx.saved_tensors
         (
             query,

@@ -693,15 +712,19 @@ def flex_attention_autograd(
     block_mask: Tuple,
     scale: float,
     kernel_options: Dict[str, Any],
-    score_mod_other_buffers: Tuple = (),
-    mask_mod_other_buffers: Tuple = (),
+    score_mod_other_buffers: Tuple[Tensor, ...] = (),
+    mask_mod_other_buffers: Tuple[Tensor, ...] = (),
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     with TransformGetItemToIndex():
         input_requires_grad = any(t.requires_grad for t in (query, key, value))
         if torch.is_grad_enabled() and input_requires_grad:
-            example_vals = [
-                torch.zeros((), dtype=query.dtype, requires_grad=input_requires_grad)
-            ] + [torch.zeros((), dtype=torch.int) for _ in range(4)]
+            example_vals = (
+                torch.zeros((), dtype=query.dtype, requires_grad=input_requires_grad),
+                torch.zeros((), dtype=torch.int),
+                torch.zeros((), dtype=torch.int),
+                torch.zeros((), dtype=torch.int),
+                torch.zeros((), dtype=torch.int),
+            )
             fw_graph, bw_graph = create_fw_bw_graph(
                 score_mod, example_vals, score_mod_other_buffers
             )
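For context on the newly typed __torch_function__ override, here is a stand-alone TorchFunctionMode with the same annotation shape. The class name LogGetItem is illustrative, and func is simplified to Any where the PR uses OpOverload and torch._C._TensorMeta:

from typing import Any, Dict, Optional, Tuple

import torch
from torch.overrides import TorchFunctionMode


class LogGetItem(TorchFunctionMode):
    # Toy mode: report tensor indexing, then run the original function.
    def __torch_function__(
        self,
        func: Any,
        types: Tuple[type, ...],
        args: Tuple[object, ...] = (),
        kwargs: Optional[Dict[str, object]] = None,
    ) -> object:
        if func == torch.Tensor.__getitem__:
            print("intercepted Tensor.__getitem__")
        return func(*args, **(kwargs or {}))


with LogGetItem():
    t = torch.arange(4)
    _ = t[1]  # routed through the mode above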

torch/_inductor/fx_passes/post_grad.py

Lines changed: 8 additions & 2 deletions
@@ -845,7 +845,9 @@ def _(match: Match, *args, **kwargs):
     # tracing a function with kwargs.
     def decomp(*flat_args):
         args, kwargs = pytree.tree_unflatten(flat_args, spec)
-        return auto_functionalized_dense(*args, only_clone_these_tensors, **kwargs)
+        assert len(args) == 1
+        mode = args[0]
+        return auto_functionalized_dense(mode, only_clone_these_tensors, **kwargs)

     match.replace_by_example(decomp, flat_args, run_functional_passes=False)

@@ -889,7 +891,11 @@ def _(match: Match, *args, **kwargs):
     # tracing a function with kwargs.
     def decomp(*flat_args):
         args, kwargs = pytree.tree_unflatten(flat_args, spec)
-        return auto_functionalized_v2_dense(*args, only_clone_these_bases, **kwargs)
+        assert len(args) == 1
+        mutable_op = args[0]
+        return auto_functionalized_v2_dense(
+            mutable_op, only_clone_these_bases, **kwargs
+        )

     match.replace_by_example(decomp, flat_args, run_functional_passes=False)
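The decomp pattern above relies on the pytree flatten/unflatten round-trip, and the added assert documents that exactly one positional argument (the op) survives it. A self-contained sketch of that round-trip, using only torch.utils._pytree; trace_with_kwargs is an illustrative name, not the PR's code:

import torch.utils._pytree as pytree


def trace_with_kwargs(fn, *args, **kwargs):
    # Flatten (args, kwargs) into a flat list plus a spec so a kwargs-taking
    # function can be driven through a purely positional wrapper, mirroring
    # the decomp pattern in post_grad.py.
    flat_args, spec = pytree.tree_flatten((args, kwargs))

    def decomp(*flat):
        args, kwargs = pytree.tree_unflatten(list(flat), spec)
        assert len(args) == 1  # exactly one positional argument expected
        return fn(args[0], **kwargs)

    return decomp, flat_args


decomp, flat_args = trace_with_kwargs(print, "hello", sep=" ")
decomp(*flat_args)  # prints: hello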

torch/_ops.py

Lines changed: 8 additions & 5 deletions
@@ -6,7 +6,7 @@
 import inspect
 import sys
 import types
-from typing import Any, Callable, Dict, List, Set, Type, Union
+from typing import Any, Callable, Dict, List, Set, Type, TypeVar, Union

 import torch
 import torch.utils._pytree as pytree

@@ -16,6 +16,9 @@
 from torch.utils._python_dispatch import TorchDispatchMode


+_F = TypeVar("_F", bound=Callable[..., Any])
+
+
 # Query `hasattr` only once.
 _SET_GLOBAL_FLAGS = hasattr(sys, "getdlopenflags") and hasattr(sys, "setdlopenflags")

@@ -99,8 +102,8 @@ def has_kernel_for_any_dispatch_key(self, ks):
                 return True
         return False

-    def py_impl(self, k):
-        def inner(fn):
+    def py_impl(self, k: Any) -> Callable[[_F], _F]:
+        def inner(fn: _F) -> _F:
             if inspect.isclass(k) and (
                 issubclass(k, TorchDispatchMode) or issubclass(k, torch.Tensor)
             ):

@@ -141,7 +144,7 @@ def inner(fn):
     #     with ctx.redispatch_to_next():
     #         out = ctx.functionalize(inner_f)(*args_unwrapped)
     #     return ctx.wrap_tensors(out)
-    def py_functionalize_impl(self, fn):
+    def py_functionalize_impl(self, fn: _F) -> _F:
         from torch._subclasses.functional_tensor import (
             CppFunctionalizeAPI as _CppFunctionalizeAPI,
             FunctorchFunctionalizeAPI as _FunctorchFunctionalizeAPI,

@@ -273,7 +276,7 @@ def __init__(self, name, *, cacheable=False):
     # it to next key. This is only safe to do when PreDispatch key stack has no
     # active modes.

-    def py_impl(self, k):
+    def py_impl(self, k: Any) -> Callable[[_F], _F]:
         if isinstance(k, DispatchKey) and not self.non_fallthrough_keys.has(k):
             self.non_fallthrough_keys = self.non_fallthrough_keys.add(k)
         return super().py_impl(k)
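The new _F TypeVar bound to Callable[..., Any] is the standard way to type a decorator (or decorator factory) so the wrapped function keeps its exact signature for type checkers instead of degrading to an untyped callable. A minimal sketch of the pattern; py_impl_like and registry are illustrative names, not the real torch._ops code:

from typing import Any, Callable, Dict, List, TypeVar

_F = TypeVar("_F", bound=Callable[..., Any])

registry: Dict[object, List[Callable[..., Any]]] = {}


def py_impl_like(key: object) -> Callable[[_F], _F]:
    # Decorator factory typed like py_impl: inner returns the function it was
    # given, so mypy still sees the original signature after decoration.
    def inner(fn: _F) -> _F:
        registry.setdefault(key, []).append(fn)
        return fn

    return inner


@py_impl_like("CompositeExplicitAutograd")
def double(x: int) -> int:
    return x * 2


print(double(3))  # 6 -- for type checkers, double is still (int) -> int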
