Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions docs/source/nn.functional.rst
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,15 @@ Pooling functions
fractional_max_pool2d
fractional_max_pool3d

Attention Mechanisms
-------------------------------

.. autosummary::
:toctree: generated
:nosignatures:

scaled_dot_product_attention

Non-linear activation functions
-------------------------------

Expand Down
1 change: 1 addition & 0 deletions tools/pyi/gen_pyi.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,7 @@ def gen_nn_functional(fm: FileManager) -> None:
"softplus",
"softshrink",
"one_hot",
"scaled_dot_product_attention",
]
import_code = ["from .. import {0} as {0}".format(_) for _ in imports]
# TODO make these types more precise
Expand Down
32 changes: 16 additions & 16 deletions torch/backends/cuda/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,9 +172,9 @@ def preferred_linalg_library(backend: Union[None, str, torch._C._LinalgBackend]
class SDPBackend(IntEnum):
r"""Enum class for the scaled dot product attention backends.

.. warning:: This flag is experimental and subject to change.'
.. warning:: This class is in beta and subject to change.

This class needs to stay inline with the enum defined in:
This class needs to stay aligned with the enum defined in:
pytorch/aten/src/ATen/native/transformers/sdp_utils_cpp.h
"""
ERROR = -1
Expand All @@ -185,62 +185,62 @@ class SDPBackend(IntEnum):

def flash_sdp_enabled():
r"""
.. warning:: This flag is experimental and subject to change.
.. warning:: This flag is beta and subject to change.

Returns whether flash sdp is enabled or not.
Returns whether flash scaled dot product attention is enabled or not.
"""
return torch._C._get_flash_sdp_enabled()


def enable_flash_sdp(enabled: bool):
r"""
.. warning:: This flag is experimental and subject to change.
.. warning:: This flag is beta and subject to change.

Enables or disables flash sdp.
Enables or disables flash scaled dot product attention.
"""
torch._C._set_sdp_use_flash(enabled)

def mem_efficient_sdp_enabled():
r"""
.. warning:: This flag is experimental and subject to change.
.. warning:: This flag is beta and subject to change.

Returns whether memory efficient sdp is enabled or not.
Returns whether memory efficient scaled dot product attention is enabled or not.
"""
return torch._C._get_mem_efficient_sdp_enabled()


def enable_mem_efficient_sdp(enabled: bool):
r"""
.. warning:: This flag is experimental and subject to change.
.. warning:: This flag is beta and subject to change.

Enables or disables memory efficient sdp.
Enables or disables memory efficient scaled dot product attention.
"""
torch._C._set_sdp_use_mem_efficient(enabled)

def math_sdp_enabled():
r"""
.. warning:: This flag is experimental and subject to change.
.. warning:: This flag is beta and subject to change.

Returns whether math sdp is enabled or not.
Returns whether math scaled dot product attention is enabled or not.
"""
return torch._C._get_math_sdp_enabled()


def enable_math_sdp(enabled: bool):
r"""
.. warning:: This flag is experimental and subject to change.
.. warning:: This flag is beta and subject to change.

Enables or disables math sdp.
Enables or disables math scaled dot product attention.
"""
torch._C._set_sdp_use_math(enabled)


@contextlib.contextmanager
def sdp_kernel(enable_flash: bool = True, enable_math: bool = True, enable_mem_efficient: bool = True):
r"""
.. warning:: This flag is experimental and subject to change.
.. warning:: This flag is beta and subject to change.

This context manager can be used to temporarily enable or disable flash/memory efficient sdp and math sdp.
This context manager can be used to temporarily enable or disable any of the three backends for scaled dot product attention.
Upon exiting the context manager, the previous state of the flags will be restored.
"""
previous_flash: bool = flash_sdp_enabled()
Expand Down
92 changes: 79 additions & 13 deletions torch/nn/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -4841,28 +4841,94 @@ def _in_projection(

scaled_dot_product_attention = _add_docstr(
torch._C._nn.scaled_dot_product_attention, r"""
scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False) -> Tensor:

Computes scaled dot product attention on query, key and value tensors, using
an optional attention mask if passed, and applying dropout if a probability
greater than 0.0 is specified.

.. code-block:: python

# Efficient implementation equivalent to the following:
attn_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) if is_causal else attn_mask
attn_mask = attn_mask.masked_fill(not attn_mask, -float('inf')) if attn_mask.dtype==torch.bool else attn_mask
attn_weight = torch.softmax((Q @ K.transpose(-2, -1) / math.sqrt(Q.size(-1))) + attn_mask, dim=-1)
attn_weight = torch.dropout(attn_weight, dropout_p)
return attn_weight @ V

.. warning:: This function is beta and subject to change.

Note:

There are currently three supported implementations of scaled dot product attention:

- `FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness`_
- `Memory-Efficient Attention`_
- A PyTorch implementation defined in C++ matching the above formulation

The function may call optimized kernels for improved performance when using the CUDA backend.
For all other backends, the PyTorch implementation will be used.

All implementations are enabled by default. Scaled dot product attention attempts to automatically select the
most optimal implementation based on the inputs. In order to provide more fine-grained control over what implementation
is used, the following functions are provided for enabling and disabling implementations.
The context manager is the preferred mechanism:

- :func:`torch.backends.cuda.sdp_kernel`: A context manager used to enable/disable any of the implementations.
- :func:`torch.backends.cuda.enable_flash_sdp`: Enables or Disables FlashAttention.
- :func:`torch.backends.cuda.enable_mem_efficient_sdp`: Enables or Disables Memory-Efficient Attention.
- :func:`torch.backends.cuda.enable_math_sdp`: Enables or Disables the PyTorch C++ implementation.

Each of the fused kernels has specific input limitations. If the user requires the use of a specific fused implementation,
disable the PyTorch C++ implementation using :func:`torch.backends.cuda.sdp_kernel`.
In the event that a fused implementation is not available, an error will be raised with the
reasons why the fused implementation cannot run.

Due to the nature of fusing floating point operations, the output of this function may be different
depending on what backend kernel is chosen.
The c++ implementation supports torch.float64 and can be used when higher precision is required.
For more information please see :doc:`/notes/numerical_accuracy`

Note:
{cudnn_reproducibility_note}
""".format(**reproducibility_notes)
+ r"""

Args:
query (Tensor): Query tensor; shape (N, ..., L, E)
key (Tensor): Key tensor; shape (N, ..., S, E)
value (Tensor): Value tensor; shape (N, ..., S, E)
attn_mask (optional Tensor): Attention mask; shape (N, ..., L, S) or (L, S). Currently, only a boolean mask
is supported, where a value of True indicates that the element *should* take part in attention.
dropout_p (float): Dropout probability; if greater than 0.0, dropout is applied
is_causal (bool): If true, assumes causal attention masking and ignores attn_mask.
query (Tensor): Query tensor; shape :math:`(N, ..., L, E)`.
key (Tensor): Key tensor; shape :math:`(N, ..., S, E)`.
value (Tensor): Value tensor; shape :math:`(N, ..., S, Ev)`.
attn_mask (optional Tensor): Attention mask; shape :math:`(N, ..., L, S)`. Two types of masks are supported.
A boolean mask where a value of True indicates that the element *should* take part in attention.
A float mask of the same type as query, key, value that is added to the attention score.
dropout_p (float): Dropout probability; if greater than 0.0, dropout is applied
is_causal (bool): If true, assumes causal attention masking and errors if both attn_mask and is_causal
are set.


Returns a tuple containing:
output (Tensor): Attention output; shape (N, ..., L, E)
Returns:
output (Tensor): Attention output; shape :math:`(N, ..., L, Ev)`.

Shape legend:
N: Batch size
...: Any number of other batch dimensions (optional)
S: Source sequence length
L: Target sequence lengthE: Embedding dimension
- :math:`N: \text{Batch size} ... : \text{Any number of other batch dimensions (optional)}`
- :math:`S: \text{Source sequence length}`
- :math:`L: \text{Target sequence length}`
- :math:`E: \text{Embedding dimension of the query and key}`
- :math:`Ev: \text{Embedding dimension of the value}`

Examples::

>>> # Optionally use the context manager to ensure one of the fused kerenels is run
>>> query = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
>>> key = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
>>> value = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
>>> with torch.backends.cuda.sdp_kernel(enable_math=False):
>>> F.scaled_dot_product_attention(query,key,value)

.. _FlashAttention\: Fast and Memory-Efficient Exact Attention with IO-Awareness:
https://arxiv.org/abs/2205.14135
.. _Memory-Efficient Attention:
https://github.com/facebookresearch/xformers

""")

Expand Down