88// This header implements various unary operations using a MKL VML style
99// interface.
1010
11+ // It implements various functions with a simple interface
12+ // For example it enables the user to call vsin(float* out, const float* in,
13+ // size) This functions takes a pointer to a contious output array of floats and
14+ // a constant input array. It will then apply sin to each value in in the input
15+ // array and write the result into the output array. out and in may point to the
16+ // same memory, i.e. this fully supports in-place operations. These functions
17+ // also implement their own parallelization, so take precautions when calling
18+ // these from threaded functions.
19+
20+ // When MKL is available it will call into MKL's VML library similar to NumPy
21+ // If MKL is not available it will use SLEEF.
22+
23+ // This file might be compiled under AVX or AVX2 when called from e.g.
24+ // UnaryOpsKernel.cpp
25+
1126#include < algorithm>
1227#include < cstddef>
1328#include < cstdint>
1631
1732#if AT_MKL_ENABLED() && !defined(__APPLE__)
1833#include < mkl.h>
19- #include < mkl_vml.h>
34+ #endif
35+
36+ // [Note SSE-AVX transitions]
37+ // There is a bug in Glibc2.23
38+ // https://bugs.launchpad.net/ubuntu/+source/glibc/+bug/1663280. Calling zeroall
39+ // when using AVX/AVX2 code resolves this.
40+ #if defined(__AVX__) && defined(__GLIBC__) && __GLIBC_MINOR__ == 23
41+ #define DL_RUNTIME_BUG (op, type ) \
42+ volatile type x = (type)(1 ); \
43+ x = std::op(x); \
44+ _mm256_zeroall ();
45+ #else
46+ #define DL_RUNTIME_BUG (op, type )
2047#endif
2148
2249namespace at {
@@ -40,9 +67,16 @@ inline void vrsqrt(scalar_t* out, scalar_t* in, int64_t size) {
4067
4168// NB: We ignore numerical errors by convention and leave them to the user
4269
43- #define IMPLEMENT_VML (op ) \
70+ // We unfortunately need to duplicate code here to deal with the SSE-AVX
71+ // transition bug (see [Note SSE-AVX transitions]). As soon as we can expect
72+ // users to use a version of glibc newer than 2.23 we will be able to ditch
73+ // this. This duplication is also necessary since not all functions (e.g. rsqrt)
74+ // might be part of cmath.
75+
76+ #define IMPLEMENT_VML_BUG (op ) \
4477 template <typename scalar_t > \
45- inline void v##op(scalar_t * out, scalar_t * in, int64_t size) { \
78+ inline void v##op(scalar_t * out, const scalar_t * in, int64_t size) { \
79+ DL_RUNTIME_BUG (op, scalar_t ) \
4680 parallel_for (0 , size, 2048 , [out, in](int64_t begin, int64_t end) { \
4781 map ([](const Vec256<scalar_t >& x) { return x.op (); }, \
4882 out + begin, \
@@ -51,70 +85,82 @@ inline void vrsqrt(scalar_t* out, scalar_t* in, int64_t size) {
5185 }); \
5286 }
5387
54- #define IMPLEMENT_FLOAT_MKL_VML (op, mklop ) \
55- template <typename scalar_t > \
56- inline void v##op(scalar_t * out, scalar_t * in, int64_t size); \
88+ #define IMPLEMENT_VML (op ) \
89+ template <typename scalar_t > \
90+ inline void v##op(scalar_t * out, const scalar_t * in, int64_t size) { \
91+ parallel_for (0 , size, 2048 , [out, in](int64_t begin, int64_t end) { \
92+ map ([](const Vec256<scalar_t >& x) { return x.op (); }, \
93+ out + begin, \
94+ in + begin, \
95+ end - begin); \
96+ }); \
97+ }
98+
99+ IMPLEMENT_VML_BUG (abs)
100+ IMPLEMENT_VML_BUG (acos)
101+ IMPLEMENT_VML_BUG (asin)
102+ IMPLEMENT_VML_BUG (atan)
103+ IMPLEMENT_VML_BUG (ceil)
104+ IMPLEMENT_VML_BUG (cos)
105+ // IMPLEMENT_VML_BUG(cosh)
106+ IMPLEMENT_VML_BUG (erf)
107+ IMPLEMENT_VML_BUG (exp)
108+ IMPLEMENT_VML_BUG (expm1)
109+ IMPLEMENT_VML_BUG (floor)
110+ IMPLEMENT_VML (reciprocal)
111+ IMPLEMENT_VML_BUG (log)
112+ IMPLEMENT_VML_BUG (log10)
113+ IMPLEMENT_VML_BUG (log1p)
114+ IMPLEMENT_VML_BUG (log2)
115+ IMPLEMENT_VML (neg)
116+ IMPLEMENT_VML_BUG (sin)
117+ // IMPLEMENT_VML_BUG(sinh)
118+ IMPLEMENT_VML_BUG (sqrt)
119+ IMPLEMENT_VML_BUG (round)
120+ IMPLEMENT_VML (rsqrt)
121+ IMPLEMENT_VML_BUG (tan)
122+ IMPLEMENT_VML_BUG (tanh)
123+ IMPLEMENT_VML_BUG (trunc)
124+
125+ #if AT_MKL_ENABLED() && !defined(__APPLE__)
126+
127+ #define IMPLEMENT_VML_MKL (op, mklop ) \
57128 template <> \
58- inline void v##op(float * out, float * in, int64_t size) { \
129+ inline void v##op(float * out, const float * in, int64_t size) { \
59130 vms##mklop (size, in, out, VML_HA | VML_FTZDAZ_OFF | VML_ERRMODE_IGNORE); \
60131 } \
61132 template <> \
62- inline void v##op(double * out, double * in, int64_t size) { \
133+ inline void v##op(double * out, const double * in, int64_t size) { \
63134 vmd##mklop (size, in, out, VML_HA | VML_FTZDAZ_OFF | VML_ERRMODE_IGNORE); \
64135 }
65136
66137// NB: abs, cosh and sinh were temporarily disabled due to issues with Apple clang
67138
68- # if AT_MKL_ENABLED() && !defined(__APPLE__ )
69- IMPLEMENT_FLOAT_MKL_VML (acos, Acos)
70- IMPLEMENT_FLOAT_MKL_VML (asin, Asin)
71- IMPLEMENT_FLOAT_MKL_VML (atan, Atan)
72- IMPLEMENT_FLOAT_MKL_VML (cos, Cos)
73- // IMPLEMENT_FLOAT_MKL_VML (cosh, Cosh)
74- IMPLEMENT_FLOAT_MKL_VML (erf, Erf)
75- IMPLEMENT_FLOAT_MKL_VML (exp, Exp)
76- IMPLEMENT_FLOAT_MKL_VML (expm1, Expm1)
77- IMPLEMENT_FLOAT_MKL_VML (log, Ln)
78- IMPLEMENT_FLOAT_MKL_VML (log10, Log10)
79- IMPLEMENT_FLOAT_MKL_VML (log1p, Log1p)
80- IMPLEMENT_FLOAT_MKL_VML (sin, Sin)
81- // IMPLEMENT_FLOAT_MKL_VML (sinh, Sinh)
82- IMPLEMENT_FLOAT_MKL_VML (sqrt, Sqrt)
83- IMPLEMENT_FLOAT_MKL_VML (tan, Tan)
84- IMPLEMENT_FLOAT_MKL_VML (tanh, Tanh)
85- IMPLEMENT_FLOAT_MKL_VML (trunc, Trunc)
139+ IMPLEMENT_VML_MKL (abs, Abs )
140+ IMPLEMENT_VML_MKL (acos, Acos)
141+ IMPLEMENT_VML_MKL (asin, Asin)
142+ IMPLEMENT_VML_MKL (atan, Atan)
143+ IMPLEMENT_VML_MKL (cos, Cos)
144+ // IMPLEMENT_VML_MKL (cosh, Cosh)
145+ IMPLEMENT_VML_MKL (erf, Erf)
146+ IMPLEMENT_VML_MKL (exp, Exp)
147+ IMPLEMENT_VML_MKL (expm1, Expm1)
148+ IMPLEMENT_VML_MKL (log, Ln)
149+ IMPLEMENT_VML_MKL (log10, Log10)
150+ IMPLEMENT_VML_MKL (log1p, Log1p)
151+ IMPLEMENT_VML_MKL (sin, Sin)
152+ // IMPLEMENT_VML_MKL (sinh, Sinh)
153+ IMPLEMENT_VML_MKL (sqrt, Sqrt)
154+ IMPLEMENT_VML_MKL (tan, Tan)
155+ IMPLEMENT_VML_MKL (tanh, Tanh)
156+ IMPLEMENT_VML_MKL (trunc, Trunc)
86157
87158#if INTEL_MKL_VERSION >= 20180406
88- IMPLEMENT_FLOAT_MKL_VML (log2, Log2)
89- #else
90- IMPLEMENT_VML (log2)
159+ IMPLEMENT_VML_MKL (log2, Log2)
91160#endif
92161
93- #else
94- IMPLEMENT_VML (acos)
95- IMPLEMENT_VML(asin)
96- IMPLEMENT_VML(atan)
97- IMPLEMENT_VML(cos)
98- // IMPLEMENT_VML(cosh)
99- IMPLEMENT_VML(erf)
100- IMPLEMENT_VML(exp)
101- IMPLEMENT_VML(expm1)
102- IMPLEMENT_VML(log)
103- IMPLEMENT_VML(log10)
104- IMPLEMENT_VML(log1p)
105- IMPLEMENT_VML(log2)
106- IMPLEMENT_VML(sin)
107- // IMPLEMENT_VML(sinh)
108- IMPLEMENT_VML(sqrt)
109- IMPLEMENT_VML(tan)
110- IMPLEMENT_VML(tanh)
111162#endif
112163
113- IMPLEMENT_VML (ceil)
114- IMPLEMENT_VML (floor)
115- IMPLEMENT_VML (round)
116- IMPLEMENT_VML (trunc)
117-
118164} // namespace
119165} // namespace vml
120166} // namespace at
0 commit comments