@@ -4859,29 +4859,31 @@ def _in_projection(
48594859.. warning:: This function is beta and subject to change.
48604860
48614861Note:
4862- This function calls into one of three backends:
4862+ For the CUDA backend, this function may call into fused kernels for improved performance.
4863+ There are currently three supported backends:
48634864 * `FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness`_
48644865 * `Memory-Efficient Attention`_
4865- * A pytorch implementation defined in c++ matching the above formulation
4866+ * A PyTorch implementation defined in C++ matching the above formulation
48664867
48674868 The function defaults to selecting the highest-performing implementation based on the inputs provided.
48684869 However, each of the fused kernels has specific input limitations.
48694870 If you require a specific backend to be utilized, there exist functions to enable or disable specific backends.
48704871 Please note that all backends are enabled by default.
48714872
4873+ The following functions can be used to enable and disable backends, with the context manager being the preferred mechanism (see the sketch after this list):
4874+ * :func:`torch.backends.cuda.sdp_kernel`: A context manager used to enable/disable any of the backends.
4875+ * :func:`torch.backends.cuda.enable_flash_sdp`: Enables or disables FlashAttention.
4876+ * :func:`torch.backends.cuda.enable_mem_efficient_sdp`: Enables or disables Memory-Efficient Attention.
4877+ * :func:`torch.backends.cuda.enable_math_sdp`: Enables or disables the PyTorch C++ implementation.
48724878
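As a rough sketch of how these toggles compose: the global enable/disable functions persist until changed, while the context manager only restricts the backends within its scope. The query helper ``torch.backends.cuda.flash_sdp_enabled()`` used below for illustration is assumed to exist alongside the setters listed above::

    import torch

    # Globally disable FlashAttention; the other two backends remain eligible.
    torch.backends.cuda.enable_flash_sdp(False)
    print(torch.backends.cuda.flash_sdp_enabled())  # False

    # Restore the default, then restrict backends only within a scope.
    torch.backends.cuda.enable_flash_sdp(True)
    with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=False):
        # Only Memory-Efficient Attention may be selected inside this block.
        pass
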
4873- For example :func:`~torch.backends.cuda.enable_flash_sdp` can be used to enable/disable FlashAttention.
4874- The context manager :func:`~torch.backends.cuda.sdp_kernel` can be used to enable/disable the backends
4875- for a specific scope.
4879+ To enforce that one of the fused implementations is used, disable the PyTorch C++ implementation
4880+ using :func:`torch.backends.cuda.sdp_kernel`.
4881+ If a fused implementation is not available for the given inputs, the function will throw an error
4882+ explaining why it could not be used.
48764883
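A minimal sketch of that failure mode, assuming float64 inputs are one configuration that neither fused kernel supports::

    import torch
    import torch.nn.functional as F

    q = torch.rand(2, 4, 16, 8, dtype=torch.float64, device="cuda")
    k, v = torch.rand_like(q), torch.rand_like(q)

    try:
        # With the C++ implementation disabled, only the fused kernels are eligible.
        with torch.backends.cuda.sdp_kernel(enable_math=False):
            F.scaled_dot_product_attention(q, k, v)
    except RuntimeError as err:
        # The error (and the accompanying warnings) explain why the fused kernels were rejected.
        print(err)
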
4877- If a user wants to enforce that one of the fused implementations is used, disable the math fallback
4878- using one of the above mechanisms. If for some reason a fused implementation is not available,
4879- the function will throw an error with the reasons why the fused implementation was not used.
4880-
4881- The numerical accuracy of the fused kernels has been tested but due to the nature of fusing floating point operations
4882- the deviations from the infinite precision implementation may be significant. If that is the case we encourage users
4883- to please file an issue. A work around would be disabiling the fused kernels and using the math fallback. For more
4884- information please see :doc:`/notes/numerical_accuracy`.
4884+ Due to the nature of fusing floating point operations, the output of this function may differ depending on which backend kernel is chosen.
4885+ The C++ implementation supports torch.float64 and can be used when higher precision is required.
4886+ For more information please see :doc:`/notes/numerical_accuracy`.
48854887
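A rough sketch of using the C++ implementation as a higher-precision, float64 reference for a fused result (the shapes below are illustrative, not required)::

    import torch
    import torch.nn.functional as F

    query = torch.rand(2, 4, 16, 8, dtype=torch.float16, device="cuda")
    key = torch.rand(2, 4, 16, 8, dtype=torch.float16, device="cuda")
    value = torch.rand(2, 4, 16, 8, dtype=torch.float16, device="cuda")

    # Result from whichever fused kernel is selected for these half-precision inputs.
    fused_out = F.scaled_dot_product_attention(query, key, value)

    # Reference: disable both fused kernels so the C++ implementation runs,
    # and feed it float64 copies of the same inputs.
    with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False):
        reference = F.scaled_dot_product_attention(
            query.double(), key.double(), value.double()
        )

    print((fused_out.double() - reference).abs().max())
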
48864888Note:
48874889 {cudnn_reproducibility_note}
@@ -4911,6 +4913,15 @@ def _in_projection(
49114913 * :math:`Ev: \text{Embedding dimension of the value}`
49124914 * :math:`\text{num\_heads}: \text{Number of heads}`
49134915
4916+ Examples::
4917+
4918+ >>> # Optionally use the context manager to ensure one of the fused kernels is run
4919+ >>> query = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
4920+ >>> key = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
4921+ >>> value = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
4922+ >>> with torch.backends.cuda.sdp_kernel(enable_math=False):
4923+ >>>     F.scaled_dot_product_attention(query, key, value)
4924+
49144925.. _FlashAttention\: Fast and Memory-Efficient Exact Attention with IO-Awareness:
49154926 https://arxiv.org/abs/2205.14135
49164927.. _Memory-Efficient Attention: