
Commit e64bf93

check in
1 parent 9f0b2c7 commit e64bf93

10 files changed (+95 lines, −4 lines)

aten/src/ATen/Context.cpp

Lines changed: 9 additions & 0 deletions
@@ -254,6 +254,15 @@ void Context::setAllowFP16ReductionCuBLAS(bool b) {
   allow_fp16_reduction_cublas = b;
 }
 
+bool Context::allowBF16ReductionCuBLAS() const {
+  return allow_bf16_reduction_cublas;
+}
+
+void Context::setAllowBF16ReductionCuBLAS(bool b) {
+  allow_bf16_reduction_cublas = b;
+}
+
 bool Context::hasMKL() {
 #if AT_MKL_ENABLED()
   return true;

aten/src/ATen/Context.h

Lines changed: 3 additions & 0 deletions
@@ -239,6 +239,8 @@ class TORCH_API Context {
   void setFloat32MatmulPrecision(Float32MatmulPrecision p);
   bool allowFP16ReductionCuBLAS() const;
   void setAllowFP16ReductionCuBLAS(bool);
+  bool allowBF16ReductionCuBLAS() const;
+  void setAllowBF16ReductionCuBLAS(bool);
   at::QEngine qEngine() const;
   void setQEngine(at::QEngine e);
   static const std::vector<at::QEngine>& supportedQEngines();
@@ -286,6 +288,7 @@ class TORCH_API Context {
   int benchmark_limit_cudnn = 10;
   bool allow_tf32_cudnn = true;
   bool allow_fp16_reduction_cublas = true;
+  bool allow_bf16_reduction_cublas = false;
   bool enabled_mkldnn = true;
   at::LinalgBackend linalg_preferred_backend = at::LinalgBackend::Default;
 #ifdef C10_MOBILE
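
Note that the new allow_bf16_reduction_cublas member defaults to false, while the existing FP16 flag defaults to true. A minimal check of those defaults from Python, in a fresh process with a CUDA-enabled build containing this change:

    import torch

    # BF16 reduced-precision reductions are off by default; FP16 ones are on.
    print(torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction)  # False
    print(torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction)  # True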

aten/src/ATen/cuda/CUDABlas.cpp

Lines changed: 7 additions & 0 deletions
@@ -538,6 +538,12 @@ void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
   float fbeta = beta;
   _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
   GEMM_CHECK_ARGVALUES(at::BFloat16);
+  cublasMath_t cublas_flags = CUBLAS_DEFAULT_MATH;
+  if (!at::globalContext().allowBF16ReductionCuBLAS()) {
+    cublas_flags = static_cast<cublasMath_t>(cublas_flags | CUBLAS_MATH_DISALLOW_REDUCED_PRECISION_REDUCTION);
+  }
+  // Disallow fp16 reductions that could lead to unexpected overflow issues.
+  TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, cublas_flags));
   TORCH_CUDABLAS_CHECK(cublasGemmEx(
       handle,
       opa,
@@ -558,6 +564,7 @@ void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
       ldc,
       CUDA_R_32F,
       CUBLAS_GEMM_DFALT_TENSOR_OP));
+  TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
 }
 #endif // defined(CUDA_VERSION) && CUDA_VERSION >= 11000
 
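
The math mode is set on the cuBLAS handle only for this GEMM and restored to CUBLAS_DEFAULT_MATH immediately after, so the flag does not leak into other cuBLAS calls. A rough Python-level probe of the flag (a sketch only; the two results can be bitwise identical on hardware or cuBLAS versions that never take the reduced-precision path):

    import torch

    a = torch.randn(1024, 1024, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(1024, 1024, device="cuda", dtype=torch.bfloat16)

    torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
    full = a @ b       # reductions kept at full precision

    torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True
    reduced = a @ b    # cuBLAS may truncate intermediate accumulations to bf16

    # Any nonzero difference comes from the reduction precision alone.
    print((full.float() - reduced.float()).abs().max().item())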

docs/source/backends.rst

Lines changed: 4 additions & 0 deletions
@@ -34,6 +34,10 @@ torch.backends.cuda
 
     A :class:`bool` that controls whether reduced precision reductions (e.g., with fp16 accumulation type) are allowed with fp16 GEMMs.
 
+.. attribute::  torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction
+
+    A :class:`bool` that controls whether reduced precision reductions are allowed with bf16 GEMMs.
+
 .. attribute::  torch.backends.cuda.cufft_plan_cache
 
     ``cufft_plan_cache`` caches the cuFFT plans
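
A short usage sketch of the attribute documented above (assuming a CUDA-enabled build containing this change):

    import torch

    # Read the current setting (False by default).
    print(torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction)

    # Opt in to reduced-precision reductions for bf16 GEMMs.
    torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True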

docs/source/notes/cuda.rst

Lines changed: 22 additions & 1 deletion
@@ -169,12 +169,33 @@ If full precision reductions are needed, users can disable reduced precision red
 
     torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
 
-To toggle the reduced precision reduction flags in C++, you can do
+To toggle the reduced precision reduction flags in C++, one can do
 
 .. code:: C++
 
     at::globalContext().setAllowFP16ReductionCuBLAS(false);
 
+.. _bf16reducedprecision:
+
+Reduced Precision Reduction in BF16 GEMMs
+-----------------------------------------
+
+A similar flag (as above) exists for BFloat16 GEMMs. Note that this switch is
+set to ``False`` by default for BF16, as we have observed numerical instability in
+PyTorch CI tests (e.g., test/test_matmul_cuda.py).
+
+If reduced precision reductions are desired, users can enable them in bf16 GEMMs with:
+
+.. code:: python
+
+    torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True
+
+To toggle the reduced precision reduction flags in C++, one can do
+
+.. code:: C++
+
+    at::globalContext().setAllowBF16ReductionCuBLAS(true);
+
 Asynchronous execution
 ----------------------
 
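
Because the flag is global process state, a common pattern is to enable it only around the matmuls where the extra throughput is wanted and then restore the previous value. A minimal sketch (plain save-and-restore around the flag added by this commit, not a dedicated API):

    import torch

    a = torch.randn(512, 512, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(512, 512, device="cuda", dtype=torch.bfloat16)

    matmul_flags = torch.backends.cuda.matmul
    prev = matmul_flags.allow_bf16_reduced_precision_reduction
    matmul_flags.allow_bf16_reduced_precision_reduction = True
    try:
        out = a @ b  # bf16 GEMM with reduced-precision reductions allowed
    finally:
        # Restore the previous (default: False) setting for the rest of the process.
        matmul_flags.allow_bf16_reduced_precision_reduction = prev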

docs/source/notes/numerical_accuracy.rst

Lines changed: 7 additions & 3 deletions
@@ -98,13 +98,17 @@ If your network needs full float32 precision for both matrix multiplications and
 
 For more information see :ref:`TensorFloat32<tf32_on_ampere>`.
 
-Reduced Precision Reduction for FP16 GEMMs
-------------------------------------------
+Reduced Precision Reduction for FP16 and BF16 GEMMs
+----------------------------------------------------
 Half-precision GEMM operations are typically done with intermediate accumulations (reduction) in single-precision for numerical accuracy and improved resilience to overflow. For performance, certain GPU architectures, especially more recent ones, allow a few truncations of the intermediate accumulation results to the reduced precision (e.g., half-precision). This change is often benign from the perspective of model convergence, though it may lead to unexpected results (e.g., ``inf`` values when the final result should be representable in half-precision).
 If reduced-precision reductions are problematic, they can be turned off with
 ``torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False``
 
-For more information see :ref:`allow_fp16_reduced_precision_reduction<fp16reducedprecision>`
+A similar flag exists for BF16 GEMM operations and is turned off by default. If
+reduced-precision reductions are desired for BF16, they can be turned on with
+``torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True``
+
+For more information see :ref:`allow_fp16_reduced_precision_reduction<fp16reducedprecision>` and :ref:`allow_bf16_reduced_precision_reduction<bf16reducedprecision>`
 
 .. _fp16_on_mi200:
 
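
To make the accumulation-precision point concrete, here is a small CPU-only illustration (not the cuBLAS kernels): a running sum kept in bfloat16 stalls once the increments fall below half of its rounding step, while a float32 running sum does not.

    import torch

    xs = torch.full((10_000,), 0.01)               # float32 increments

    acc32 = torch.zeros(())                        # float32 running sum
    acc16 = torch.zeros((), dtype=torch.bfloat16)  # bfloat16 running sum
    for x in xs:
        acc32 = acc32 + x
        acc16 = acc16 + x.to(torch.bfloat16)       # each partial sum rounded to bf16

    # acc32 ends up close to 100.0; acc16 stalls at a few units, because 0.01
    # falls below half a bf16 ulp once the sum exceeds 4.
    print(acc32.item(), acc16.item())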

test/test_cuda.py

Lines changed: 8 additions & 0 deletions
@@ -635,6 +635,14 @@ def test_cublas_allow_fp16_reduced_precision_reduction_get_set(self):
         self.assertEqual(torch._C._get_cublas_allow_fp16_reduced_precision_reduction(), not orig)
         torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = orig
 
+    def test_cublas_allow_bf16_reduced_precision_reduction_get_set(self):
+        orig = torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction
+        self.assertEqual(torch._C._get_cublas_allow_bf16_reduced_precision_reduction(), orig)
+        torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = not orig
+        self.assertEqual(torch._C._get_cublas_allow_bf16_reduced_precision_reduction(), not orig)
+        torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = orig
+
     def test_cudnn_allow_tf32_get_set(self):
         with torch.backends.cudnn.flags(enabled=None, benchmark=None, deterministic=None, allow_tf32=False):
             self.assertFalse(torch.backends.cudnn.allow_tf32)
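
If desired, just this round-trip test can be exercised with, e.g., pytest test/test_cuda.py -k bf16_reduced_precision (assuming a CUDA-enabled build and pytest as the runner).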

torch/_C/__init__.pyi.in

Lines changed: 2 additions & 0 deletions
@@ -844,6 +844,8 @@ def _get_float32_matmul_precision() -> str: ... #THPModule_float32MatmulPrecisio
 def _set_float32_matmul_precision(arg: str) -> None: ... #THPModule_setFloat32MatmulPrecision
 def _get_cublas_allow_fp16_reduced_precision_reduction() -> _bool: ... #THPModule_allowFP16ReductionCuBLAS
 def _set_cublas_allow_fp16_reduced_precision_reduction(arg: _bool) -> None: ... #THPModule_setAllowFP16ReductionCuBLAS
+def _get_cublas_allow_bf16_reduced_precision_reduction() -> _bool: ... #THPModule_allowBF16ReductionCuBLAS
+def _set_cublas_allow_bf16_reduced_precision_reduction(arg: _bool) -> None: ... #THPModule_setAllowBF16ReductionCuBLAS
 def _set_conj(x: Tensor, conj: _bool) -> None: ...
 def _set_neg(x: Tensor, neg: _bool) -> None: ...
 def _set_meta_in_tls_dispatch_include(meta_in_tls: _bool) -> None: ...

torch/backends/cuda/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -96,13 +96,17 @@ def __getattr__(self, name):
             return torch._C._get_cublas_allow_tf32()
         elif name == "allow_fp16_reduced_precision_reduction":
             return torch._C._get_cublas_allow_fp16_reduced_precision_reduction()
+        elif name == "allow_bf16_reduced_precision_reduction":
+            return torch._C._get_cublas_allow_bf16_reduced_precision_reduction()
         raise AssertionError("Unknown attribute " + name)
 
     def __setattr__(self, name, value):
         if name == "allow_tf32":
             return torch._C._set_cublas_allow_tf32(value)
         elif name == "allow_fp16_reduced_precision_reduction":
             return torch._C._set_cublas_allow_fp16_reduced_precision_reduction(value)
+        elif name == "allow_bf16_reduced_precision_reduction":
+            return torch._C._set_cublas_allow_bf16_reduced_precision_reduction(value)
         raise AssertionError("Unknown attribute " + name)
 
 
 _LinalgBackends = {
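
The matmul attributes are forwarded through __getattr__/__setattr__ to the private torch._C bindings, so the public attribute and the binding always agree. A minimal round-trip sketch mirroring the new test in test/test_cuda.py:

    import torch

    matmul = torch.backends.cuda.matmul
    orig = matmul.allow_bf16_reduced_precision_reduction

    matmul.allow_bf16_reduced_precision_reduction = not orig
    assert torch._C._get_cublas_allow_bf16_reduced_precision_reduction() == (not orig)

    matmul.allow_bf16_reduced_precision_reduction = orig  # restore the original setting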

torch/csrc/Module.cpp

Lines changed: 29 additions & 0 deletions
@@ -739,6 +739,27 @@ PyObject* THPModule_allowFP16ReductionCuBLAS(
   Py_RETURN_FALSE;
 }
 
+PyObject* THPModule_setAllowBF16ReductionCuBLAS(
+    PyObject* _unused,
+    PyObject* arg) {
+  THPUtils_assert(
+      PyBool_Check(arg),
+      "set_allow_bf16_reduction_cublas expects a bool, "
+      "but got %s",
+      THPUtils_typename(arg));
+  at::globalContext().setAllowBF16ReductionCuBLAS(arg == Py_True);
+  Py_RETURN_NONE;
+}
+
+PyObject* THPModule_allowBF16ReductionCuBLAS(
+    PyObject* _unused,
+    PyObject* noargs) {
+  if (at::globalContext().allowBF16ReductionCuBLAS()) {
+    Py_RETURN_TRUE;
+  }
+  Py_RETURN_FALSE;
+}
+
 PyObject* THPModule_setFlushDenormal(PyObject* _unused, PyObject* arg) {
   THPUtils_assert(
       PyBool_Check(arg),
@@ -1052,6 +1073,14 @@ static PyMethodDef TorchMethods[] = {
      THPModule_setAllowFP16ReductionCuBLAS,
      METH_O,
      nullptr},
+    {"_get_cublas_allow_bf16_reduced_precision_reduction",
+     THPModule_allowBF16ReductionCuBLAS,
+     METH_NOARGS,
+     nullptr},
+    {"_set_cublas_allow_bf16_reduced_precision_reduction",
+     THPModule_setAllowBF16ReductionCuBLAS,
+     METH_O,
+     nullptr},
     {"_vmapmode_increment_nesting",
      THPModule_vmapmode_increment_nesting,
      METH_NOARGS,
