Commit 5833a6d

few more comments
1 parent 690e670 commit 5833a6d


torch/nn/functional.py

Lines changed: 24 additions & 13 deletions
@@ -4859,29 +4859,31 @@ def _in_projection(
     .. warning:: This function is beta and subject to change.

     Note:
-        This function calls into one of three backends:
+        For the CUDA backend, this function can call into fused kernels for improved performance.
+        There are currently three supported backends:
         * `FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness`_
         * `Memory-Efficient Attention`_
-        * A pytorch implementation defined in c++ matching the above formulation
+        * A PyTorch implementation defined in c++ matching the above formulation

         The function defaults to selecting the highest-performing implementation based on the inputs provided.
         However, each of the fused kernels has specific input limitations.
         If you require a specific backend to be utilized, there exist functions to enable or disable specific backends.
         Please note that all backends are enabled by default.

+        The following functions can be used to enable and disable backends, with the context manager being the preferred mechanism:
+        * :func:`torch.backends.cuda.sdp_kernel`: A context manager used to enable/disable any of the backends.
+        * :func:`torch.backends.cuda.enable_flash_sdp`: Enables or disables FlashAttention.
+        * :func:`torch.backends.cuda.enable_mem_efficient_sdp`: Enables or disables Memory-Efficient Attention.
+        * :func:`torch.backends.cuda.enable_math_sdp`: Enables or disables the PyTorch c++ implementation.

-        For example :func:`~torch.backends.cuda.enable_flash_sdp` can be used to enable/disable FlashAttention.
-        The context manager :func:`~torch.backends.cuda.sdp_kernel` can be used to enable/disable the backends
-        for a specific scope.
+        If a user wants to enforce that one of the fused implementations is used, disable the PyTorch c++ implementation
+        using :func:`torch.backends.cuda.sdp_kernel`.
+        If for some reason a fused implementation is not available, the function will throw an error with the
+        reasons why the fused implementation was not used.

-        If a user wants to enforce that one of the fused implementations is used, disable the math fallback
-        using one of the above mechanisms. If for some reason a fused implementation is not available,
-        the function will throw an error with the reasons why the fused implementation was not used.
-
-        The numerical accuracy of the fused kernels has been tested but due to the nature of fusing floating point operations
-        the deviations from the infinite precision implementation may be significant. If that is the case we encourage users
-        to please file an issue. A work around would be disabiling the fused kernels and using the math fallback. For more
-        information please see :doc:`/notes/numerical_accuracy`.
+        Due to the nature of fusing floating point operations, the output of this function may differ depending on which backend kernel is chosen.
+        The c++ implementation supports torch.float64 and can be used when higher precision is required.
+        For more information please see :doc:`/notes/numerical_accuracy`.

     Note:
         {cudnn_reproducibility_note}
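
The hunk above lists the backend toggles this commit documents. The following is a minimal sketch of how they fit together, assuming a CUDA-capable GPU and a PyTorch build that ships scaled_dot_product_attention alongside the torch.backends.cuda.*_sdp helpers referenced in this commit; the tensor shapes simply mirror the doctest added in the second hunk below.

    import torch
    import torch.nn.functional as F

    # Random (batch, num_heads, seq_len, head_dim) inputs in half precision on the GPU.
    # Assumes a CUDA device is available; fused kernels target lower-precision dtypes.
    query = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
    key = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
    value = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")

    # Preferred mechanism: the context manager restricts backend selection for this scope only.
    # Disabling the math (c++) fallback enforces that one of the fused kernels is used.
    with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_mem_efficient=True, enable_math=False):
        out = F.scaled_dot_product_attention(query, key, value)

    # The enable_*_sdp functions flip the same switches globally rather than per-scope.
    torch.backends.cuda.enable_math_sdp(False)
    out = F.scaled_dot_product_attention(query, key, value)
    torch.backends.cuda.enable_math_sdp(True)  # restore the default (all backends enabled)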
@@ -4911,6 +4913,15 @@ def _in_projection(
     * :math:`Ev: \text{Embedding dimension of the value}`
     * :math:`\text{num\_heads}: \text{Number of heads}`

+    Examples::
+
+        >>> # Optionally use the context manager to ensure one of the fused kernels is run
+        >>> query = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
+        >>> key = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
+        >>> value = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
+        >>> with torch.backends.cuda.sdp_kernel(enable_math=False):
+        ...     F.scaled_dot_product_attention(query, key, value)
+
     .. _FlashAttention\: Fast and Memory-Efficient Exact Attention with IO-Awareness:
         https://arxiv.org/abs/2205.14135
     .. _Memory-Efficient Attention:
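
As a rough illustration of the two behaviors described in the new Note text, here is a hedged sketch, again assuming a CUDA device and the API shown in this commit: with every fused backend ruled out the call raises a RuntimeError explaining why, while the default configuration falls back to the c++ math implementation, which per the note supports torch.float64.

    import torch
    import torch.nn.functional as F

    # Small double-precision inputs; per the note above, only the c++ math backend handles float64.
    q = torch.rand(2, 4, 16, 8, dtype=torch.float64, device="cuda")
    k = torch.rand(2, 4, 16, 8, dtype=torch.float64, device="cuda")
    v = torch.rand(2, 4, 16, 8, dtype=torch.float64, device="cuda")

    # With the math fallback disabled, no backend can accept these inputs, so the call
    # raises a RuntimeError whose message gives the reasons the fused kernels were not used.
    try:
        with torch.backends.cuda.sdp_kernel(enable_math=False):
            F.scaled_dot_product_attention(q, k, v)
    except RuntimeError as err:
        print("No fused implementation available:", err)

    # With all backends enabled (the default), the c++ implementation runs in full precision.
    out = F.scaled_dot_product_attention(q, k, v)
    print(out.dtype)  # torch.float64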
