
Commit fae6c67

tunz authored and soumith committed
Configurable flushing denormal numbers on CPU (#5294)
* Configurable flushing denormal numbers on CPU
* Formatting
* Update docs
* Minor doc changes
1 parent 6279367 commit fae6c67

File tree

7 files changed: +87 lines, −0 lines

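Before the file-by-file diff, a minimal usage sketch of the API this commit introduces, based on the docstring and test added below (a hypothetical session; torch.set_flush_denormal returns False on CPUs without SSE3, so actual output varies by machine):

    import torch

    t = torch.DoubleTensor([1e-323])     # denormal in double precision
    print(t[0])                          # ~9.88e-324 under default IEEE handling

    if torch.set_flush_denormal(True):   # False (and a no-op) without SSE3
        print(t[0])                      # 0.0 while flushing is enabled
        torch.set_flush_denormal(False)  # restore default denormal handling
        print(t[0])                      # ~9.88e-324 again; stored bits unchanged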

aten/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -261,6 +261,7 @@ ENDIF()
 IF(C_SSE3_FOUND)
   MESSAGE(STATUS "SSE3 Found")
   SET(CMAKE_C_FLAGS "${C_SSE3_FLAGS} -DUSE_SSE3 ${CMAKE_C_FLAGS}")
+  SET(CMAKE_CXX_FLAGS "${C_SSE3_FLAGS} -DUSE_SSE3 ${CMAKE_CXX_FLAGS}")
 ENDIF(C_SSE3_FOUND)

 # we don't set -mavx and -mavx2 flags globally, but only for specific files

aten/src/ATen/Context.cpp

Lines changed: 19 additions & 0 deletions
@@ -15,6 +15,10 @@
 #endif
 #include "ATen/CPUGenerator.h"

+#ifdef USE_SSE3
+#include <pmmintrin.h>
+#endif
+
 namespace at {

 static inline void errorHandler(const char * msg, void * data) {

@@ -118,4 +122,19 @@ int64_t Context::current_device() const {
   return -1;
 }

+bool Context::setFlushDenormal(bool on) {
+#ifdef USE_SSE3
+  // Setting flush-to-zero (FTZ) flag
+  _MM_SET_FLUSH_ZERO_MODE(on ? _MM_FLUSH_ZERO_ON
+                             : _MM_FLUSH_ZERO_OFF);
+
+  // Setting denormals-are-zero (DAZ) flag
+  _MM_SET_DENORMALS_ZERO_MODE(on ? _MM_DENORMALS_ZERO_ON
+                                 : _MM_DENORMALS_ZERO_OFF);
+  return true;
+#else
+  return false;
+#endif
+}
+
 }
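The two intrinsics above set distinct MXCSR flags: FTZ (flush-to-zero) makes operations that would produce a denormal result return zero instead, while DAZ (denormals-are-zero) makes denormal inputs read as zero without modifying the stored bits. A short sketch of the difference as observed through the Python API this commit adds (assumes an SSE3-capable x86 CPU):

    import torch

    t = torch.FloatTensor([1e-42])  # 1e-42 is already denormal in single precision
    x = torch.FloatTensor([1e-20])  # normal, but x * x (~1e-40) is denormal

    torch.set_flush_denormal(True)
    print(t[0])                     # 0.0 -- DAZ zeroes the denormal input on read
    print((x * x)[0])               # 0.0 -- FTZ zeroes the denormal product
    torch.set_flush_denormal(False)
    print(t[0])                     # ~1e-42 again; memory was never modified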

aten/src/ATen/Context.h

Lines changed: 2 additions & 0 deletions
@@ -56,6 +56,8 @@ class AT_API Context {
   cudaStream_t getCurrentCUDAStream() const;
   cudaDeviceProp* getCurrentDeviceProperties() const;

+  bool setFlushDenormal(bool on);
+
   // NB: This method is *purely* whether or not a user requested
   // that CuDNN was enabled, it doesn't actually say anything about
   // whether or not CuDNN is actually usable. Use cudnn_is_acceptable

docs/source/torch.rst

Lines changed: 1 addition & 0 deletions
@@ -9,6 +9,7 @@ Tensors
 .. autofunction:: set_default_tensor_type
 .. autofunction:: numel
 .. autofunction:: set_printoptions
+.. autofunction:: set_flush_denormal


 Creation Ops

test/test_torch.py

Lines changed: 23 additions & 0 deletions
@@ -5471,6 +5471,29 @@ def test_offset_scalar_cast(self):
         y = x[2:]
         self.assertEqual(int(y), 3)

+    @unittest.skipIf(not torch.set_flush_denormal(False),
+                     "flush_denormal not supported")
+    def test_set_flush_denormal(self):
+        tiny_float = 1e-42
+        tiny_double = 1e-320
+        float_tensor = torch.FloatTensor([1.0, tiny_float])
+        double_tensor = torch.DoubleTensor([1.0, tiny_float, tiny_double])
+
+        self.assertEqual(float_tensor[0], 1.0, prec=0.0)
+        self.assertEqual(float_tensor[1], tiny_float, prec=tiny_float / 16)
+        self.assertEqual(double_tensor[0], 1.0, prec=0.0)
+        self.assertEqual(double_tensor[1], tiny_float, prec=0.0)
+        self.assertEqual(double_tensor[2], tiny_double, prec=0.0)
+
+        torch.set_flush_denormal(True)
+        self.assertEqual(float_tensor[0], 1.0, prec=0.0)
+        self.assertEqual(float_tensor[1], 0.0, prec=0.0)  # tiny_float to zero
+        self.assertEqual(double_tensor[0], 1.0, prec=0.0)
+        # tiny_float is not converted to zero in double type
+        self.assertEqual(double_tensor[1], tiny_float, prec=0.0)
+        self.assertEqual(double_tensor[2], 0.0, prec=0.0)  # tiny_double to zero
+        torch.set_flush_denormal(False)
+
 # Functions to test negative dimension wrapping
 METHOD = 1
 INPLACE_METHOD = 2

torch/_torch_docs.py

Lines changed: 31 additions & 0 deletions
@@ -3932,6 +3932,37 @@

 """)

+add_docstr(torch._C.set_flush_denormal,
+           r"""
+set_flush_denormal(mode) -> bool
+
+Disables denormal floating numbers on CPU.
+
+Returns ``True`` if your system supports flushing denormal numbers and it
+successfully configures flush denormal mode. :meth:`~torch.set_flush_denormal`
+is only supported on x86 architectures supporting SSE3.
+
+Args:
+    mode (bool): Controls whether to enable flush denormal mode or not
+
+Example::
+
+    >>> torch.set_flush_denormal(True)
+    True
+    >>> torch.DoubleTensor([1e-323])
+
+     0
+    [torch.DoubleTensor of size 1]
+
+    >>> torch.set_flush_denormal(False)
+    True
+    >>> torch.DoubleTensor([1e-323])
+
+     9.88131e-324 *
+       1.0000
+    [torch.DoubleTensor of size 1]
+""")
+
 add_docstr(torch._C.set_num_threads,
            r"""
 set_num_threads(int)

torch/csrc/Module.cpp

Lines changed: 10 additions & 0 deletions
@@ -586,6 +586,15 @@ PyObject *THPModule_benchmarkCuDNN(PyObject *_unused)
   else Py_RETURN_FALSE;
 }

+PyObject *THPModule_setFlushDenormal(PyObject *_unused, PyObject *arg) {
+  THPUtils_assert(PyBool_Check(arg), "flush_denormal expects a bool, "
+                  "but got %s", THPUtils_typename(arg));
+  if (!at::globalContext().setFlushDenormal(arg == Py_True)) {
+    Py_RETURN_FALSE;
+  }
+  Py_RETURN_TRUE;
+}
+
 #ifdef WITH_CUDA
 extern PyObject * THCSPModule_initExtension(PyObject *self);
 #endif

@@ -619,6 +628,7 @@ static PyMethodDef TorchMethods[] = {
   {"from_numpy", (PyCFunction)THPModule_fromNumpy, METH_O, NULL},
   {"_to_dlpack", (PyCFunction)THPModule_toDLPack, METH_O, NULL},
   {"_from_dlpack", (PyCFunction)THPModule_fromDLPack, METH_O, NULL},
+  {"set_flush_denormal", (PyCFunction)THPModule_setFlushDenormal, METH_O, NULL},

   {"sigmoid", (PyCFunction)THPModule_sigmoid, METH_VARARGS | METH_KEYWORDS, NULL},
   {"log", (PyCFunction)THPModule_log, METH_VARARGS | METH_KEYWORDS, NULL},
