Commit 7776fbb

Vectorize sigmoid

1 parent: b19b38c

14 files changed: +246 −88 lines

aten/src/ATen/CPUApplyUtils.h

Lines changed: 8 additions & 5 deletions

@@ -253,16 +253,15 @@ apply_op(int64_t numel, int64_t offset, const Op& op, Args... iters) {
     }
   }
 
+
 inline void apply_kernel(){};
 
+// TODO: Deal elegantly with 0-dim tensors. iters.strides_ of 0-dim
+// strided_tensor_iter will be of size 0 for dim 0 and iters.strides_[iters.dim_
+// - 1] will index at -1. C++14 integer_sequence could be of use here.
 template <typename Op, typename... Args>
 inline void
 apply_kernel(int64_t numel, int64_t offset, const Op& op, Args... iters) {
-  // For 0-dim tensors
-  if (numel == 1 && max_dim(iters...) == 0) {
-    op(1, iters.data_..., iters.strides_[iters.dim_ - 1]...);
-    return;
-  }
   if (offset > 0)
     forward(offset, iters...);
   int64_t size = std::min(numel, max_iterate_size(iters...));
@@ -284,6 +283,10 @@ inline void
 CPU_tensor_parallel_kernel_apply2(Tensor tensor1, Tensor tensor2, const Op op) {
   if (!_apply_preamble({tensor1, tensor2}))
     return;
+  if (tensor1.numel() == 1) {
+    op(1, tensor1.data<scalar1>(), tensor2.data<scalar2>(), 0, 0);
+    return;
+  }
   if (tensor1.ndimension() < 8 && tensor2.ndimension() < 8) {
     parallel_for(
         0,
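The new `numel() == 1` branch lets 0-dim and single-element tensors bypass the strided-iterator machinery entirely: the op is invoked exactly once with both strides set to 0. Below is a minimal sketch of an op compatible with that call; the argument order (count, out, in, out_stride, in_stride) is inferred from the call site above, and this is an illustration, not the kernel this commit ships:

```cpp
#include <cmath>
#include <cstdint>

// With n == 1 and zero strides this degenerates to out[0] = f(in[0]).
auto sigmoid_op = [](int64_t n, float* out, const float* in,
                     int64_t out_stride, int64_t in_stride) {
  for (int64_t i = 0; i < n; ++i)
    out[i * out_stride] = 1.0f / (1.0f + std::exp(-in[i * in_stride]));
};
```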

aten/src/ATen/Declarations.cwrap

Lines changed: 1 addition & 15 deletions

@@ -1090,24 +1090,10 @@
     - THTensor* self
 ]]
 [[
-  name: sigmoid_
+  name: _th_sigmoid
   types:
     - floating_point
   backends:
-    - CPU
-    - CUDA
-  cname: sigmoid
-  return: self
-  arguments:
-    - THTensor* self
-    - THTensor* self
-]]
-[[
-  name: sigmoid
-  types:
-    - floating_point
-  backends:
-    - CPU
     - CUDA
   cname: sigmoid
   variants:

aten/src/ATen/cpu/vec256/vec256_base.h

Lines changed: 6 additions & 0 deletions

@@ -125,6 +125,9 @@ struct Vec256 {
   Vec256<T> floor() const {
     return map(std::floor);
   }
+  Vec256<T> neg() const {
+    return map([](T x) { return -x; });
+  }
   Vec256<T> round() const {
     return map(std::round);
   }
@@ -146,6 +149,9 @@ struct Vec256 {
   Vec256<T> sqrt() const {
     return map(std::sqrt);
   }
+  Vec256<T> reciprocal() const {
+    return map([](T x) { return (T)(1) / x; });
+  }
   Vec256<T> rsqrt() const {
     return map([](T x) { return 1 / std::sqrt(x); });
   }
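`neg()` and `reciprocal()` are exactly the primitives a vectorized sigmoid needs, since σ(x) = 1 / (1 + exp(−x)). A minimal sketch of how they compose; the real kernel lives in UnaryOpsKernel.cpp, which is not part of this excerpt, and Vec256's `exp()`, `operator+`, and broadcast constructor are assumed from the rest of the header:

```cpp
// sigma(x) = 1 / (1 + exp(-x)), written with the Vec256 primitives above.
template <typename T>
Vec256<T> vec_sigmoid(const Vec256<T>& x) {
  Vec256<T> one((T)1);                        // broadcast 1 into every lane
  return (one + x.neg().exp()).reciprocal();  // 1 / (1 + e^-x)
}
```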

aten/src/ATen/cpu/vec256/vec256_double.h

Lines changed: 6 additions & 0 deletions

@@ -121,6 +121,9 @@ template <> class Vec256<double> {
   Vec256<double> floor() const {
     return _mm256_floor_pd(values);
   }
+  Vec256<double> neg() const {
+    return _mm256_xor_pd(_mm256_set1_pd(-0.), values);
+  }
   Vec256<double> round() const {
     return _mm256_round_pd(values, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
   }
@@ -136,6 +139,9 @@ template <> class Vec256<double> {
   Vec256<double> sqrt() const {
     return _mm256_sqrt_pd(values);
   }
+  Vec256<double> reciprocal() const {
+    return _mm256_div_pd(_mm256_set1_pd(1), values);
+  }
   Vec256<double> rsqrt() const {
     return _mm256_div_pd(_mm256_set1_pd(1), _mm256_sqrt_pd(values));
   }
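The `neg()` specialization relies on the IEEE-754 bit layout: −0.0 is a value whose only set bit is the sign bit, so XOR-ing with it flips the sign of every lane in a single cheap bitwise instruction, with no arithmetic and no special-casing of NaN or infinity. A scalar illustration of the same trick (an aside, not code from this commit):

```cpp
#include <cstdint>
#include <cstring>

// Scalar version of the _mm256_xor_pd(-0.0, x) idiom above: XOR with the
// sign-bit mask negates, because IEEE-754 -0.0 is exactly the sign bit.
double neg_via_xor(double x) {
  uint64_t bits;
  std::memcpy(&bits, &x, sizeof bits);
  bits ^= UINT64_C(1) << 63;  // flip the sign bit only
  std::memcpy(&x, &bits, sizeof bits);
  return x;
}
```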

aten/src/ATen/cpu/vec256/vec256_float.h

Lines changed: 6 additions & 0 deletions

@@ -126,6 +126,9 @@ template <> class Vec256<float> {
   Vec256<float> floor() const {
     return _mm256_floor_ps(values);
   }
+  Vec256<float> neg() const {
+    return _mm256_xor_ps(_mm256_set1_ps(-0.f), values);
+  }
   Vec256<float> round() const {
     return _mm256_round_ps(values, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
   }
@@ -141,6 +144,9 @@ template <> class Vec256<float> {
   Vec256<float> sqrt() const {
     return _mm256_sqrt_ps(values);
   }
+  Vec256<float> reciprocal() const {
+    return _mm256_div_ps(_mm256_set1_ps(1), values);
+  }
   Vec256<float> rsqrt() const {
     return _mm256_div_ps(_mm256_set1_ps(1), _mm256_sqrt_ps(values));
   }
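A design note on `reciprocal()`: the full divide `_mm256_div_ps(_mm256_set1_ps(1), v)` returns a correctly rounded result. AVX also offers `_mm256_rcp_ps`, which is much faster but only an ~12-bit approximation, so using it safely takes a Newton-Raphson refinement step. A sketch of that alternative (not what this commit does):

```cpp
#include <immintrin.h>

// Approximate reciprocal plus one Newton-Raphson iteration r' = r * (2 - v*r);
// still less accurate than the plain division the commit uses.
__m256 rcp_refined(__m256 v) {
  __m256 r = _mm256_rcp_ps(v);  // ~12-bit reciprocal estimate
  return _mm256_mul_ps(
      r, _mm256_sub_ps(_mm256_set1_ps(2.0f), _mm256_mul_ps(v, r)));
}
```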

aten/src/ATen/cpu/vec256/vec256_int.h

Lines changed: 4 additions & 4 deletions

@@ -29,13 +29,13 @@ struct Vec256<int64_t> : public Vec256i {
     __at_align32__ int64_t tmp_values[size];
     a.store(tmp_values);
     if (mask & 0x01)
-      tmp_values[0] = _mm256_extract_epi16(b.values, 0);
+      tmp_values[0] = _mm256_extract_epi64(b.values, 0);
     if (mask & 0x02)
-      tmp_values[1] = _mm256_extract_epi16(b.values, 1);
+      tmp_values[1] = _mm256_extract_epi64(b.values, 1);
     if (mask & 0x04)
-      tmp_values[2] = _mm256_extract_epi16(b.values, 2);
+      tmp_values[2] = _mm256_extract_epi64(b.values, 2);
     if (mask & 0x08)
-      tmp_values[3] = _mm256_extract_epi16(b.values, 3);
+      tmp_values[3] = _mm256_extract_epi64(b.values, 3);
     return loadu(tmp_values);
   }
   static Vec256<int64_t>
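This hunk fixes a real bug: `_mm256_extract_epi16` reads 16-bit lanes, so blending a `Vec256<int64_t>` silently truncated each 64-bit element to its low 16 bits. A small check illustrating the difference (assumes an AVX2 build, e.g. compiled with -mavx2):

```cpp
#include <immintrin.h>
#include <cassert>

int main() {
  // 64-bit lanes, low to high: 0x100000001, 2, 3, 4
  __m256i v = _mm256_set_epi64x(4, 3, 2, 0x100000001LL);
  assert(_mm256_extract_epi64(v, 0) == 0x100000001LL);  // full 64-bit lane
  assert(_mm256_extract_epi16(v, 0) == 1);              // low 16 bits only
}
```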

aten/src/ATen/cpu/vml.h

Lines changed: 98 additions & 52 deletions

@@ -8,6 +8,21 @@
 // This header implements various unary operations using a MKL VML style
 // interface.
 
+// It implements various functions with a simple interface.
+// For example, it enables the user to call vsin(float* out, const float* in,
+// size). This function takes a pointer to a contiguous output array of floats
+// and a constant input array. It will then apply sin to each value in the
+// input array and write the result into the output array. out and in may
+// point to the same memory, i.e. this fully supports in-place operations.
+// These functions also implement their own parallelization, so take
+// precautions when calling these from threaded functions.
+
+// When MKL is available it will call into MKL's VML library, similar to NumPy.
+// If MKL is not available it will use SLEEF.
+
+// This file might be compiled under AVX or AVX2 when called from e.g.
+// UnaryOpsKernel.cpp
+
 #include <algorithm>
 #include <cstddef>
 #include <cstdint>
@@ -16,7 +31,19 @@
 
 #if AT_MKL_ENABLED() && !defined(__APPLE__)
 #include <mkl.h>
-#include <mkl_vml.h>
+#endif
+
+// [Note SSE-AVX transitions]
+// There is a bug in glibc 2.23
+// https://bugs.launchpad.net/ubuntu/+source/glibc/+bug/1663280. Calling zeroall
+// when using AVX/AVX2 code resolves this.
+#if defined(__AVX__) && defined(__GLIBC__) && __GLIBC_MINOR__ == 23
+#define DL_RUNTIME_BUG(op, type) \
+  volatile type x = (type)(1);   \
+  x = std::op(x);                \
+  _mm256_zeroall();
+#else
+#define DL_RUNTIME_BUG(op, type)
 #endif
 
 namespace at {
@@ -40,9 +67,16 @@ inline void vrsqrt(scalar_t* out, scalar_t* in, int64_t size) {
 
 // NB: We ignore numerical errors by convention and leave them to the user
 
-#define IMPLEMENT_VML(op) \
+// We unfortunately need to duplicate code here to deal with the SSE-AVX
+// transition bug (see [Note SSE-AVX transitions]). As soon as we can expect
+// users to use a version of glibc newer than 2.23 we will be able to ditch
+// this. This duplication is also necessary since not all functions (e.g.
+// rsqrt) might be part of cmath.
+
+#define IMPLEMENT_VML_BUG(op) \
   template <typename scalar_t> \
-  inline void v##op(scalar_t* out, scalar_t* in, int64_t size) { \
+  inline void v##op(scalar_t* out, const scalar_t* in, int64_t size) { \
+    DL_RUNTIME_BUG(op, scalar_t) \
     parallel_for(0, size, 2048, [out, in](int64_t begin, int64_t end) { \
       map([](const Vec256<scalar_t>& x) { return x.op(); }, \
          out + begin, \
@@ -51,70 +85,82 @@ inline void vrsqrt(scalar_t* out, scalar_t* in, int64_t size) {
     }); \
   }
 
-#define IMPLEMENT_FLOAT_MKL_VML(op, mklop) \
-  template <typename scalar_t> \
-  inline void v##op(scalar_t* out, scalar_t* in, int64_t size); \
+#define IMPLEMENT_VML(op) \
+  template <typename scalar_t> \
+  inline void v##op(scalar_t* out, const scalar_t* in, int64_t size) { \
+    parallel_for(0, size, 2048, [out, in](int64_t begin, int64_t end) { \
+      map([](const Vec256<scalar_t>& x) { return x.op(); }, \
+          out + begin, \
+          in + begin, \
+          end - begin); \
+    }); \
+  }
+
+IMPLEMENT_VML_BUG(abs)
+IMPLEMENT_VML_BUG(acos)
+IMPLEMENT_VML_BUG(asin)
+IMPLEMENT_VML_BUG(atan)
+IMPLEMENT_VML_BUG(ceil)
+IMPLEMENT_VML_BUG(cos)
+// IMPLEMENT_VML_BUG(cosh)
+IMPLEMENT_VML_BUG(erf)
+IMPLEMENT_VML_BUG(exp)
+IMPLEMENT_VML_BUG(expm1)
+IMPLEMENT_VML_BUG(floor)
+IMPLEMENT_VML(reciprocal)
+IMPLEMENT_VML_BUG(log)
+IMPLEMENT_VML_BUG(log10)
+IMPLEMENT_VML_BUG(log1p)
+IMPLEMENT_VML_BUG(log2)
+IMPLEMENT_VML(neg)
+IMPLEMENT_VML_BUG(sin)
+// IMPLEMENT_VML_BUG(sinh)
+IMPLEMENT_VML_BUG(sqrt)
+IMPLEMENT_VML_BUG(round)
+IMPLEMENT_VML(rsqrt)
+IMPLEMENT_VML_BUG(tan)
+IMPLEMENT_VML_BUG(tanh)
+IMPLEMENT_VML_BUG(trunc)
+
+#if AT_MKL_ENABLED() && !defined(__APPLE__)
+
+#define IMPLEMENT_VML_MKL(op, mklop) \
   template <> \
-  inline void v##op(float* out, float* in, int64_t size) { \
+  inline void v##op(float* out, const float* in, int64_t size) { \
     vms##mklop(size, in, out, VML_HA | VML_FTZDAZ_OFF | VML_ERRMODE_IGNORE); \
   } \
   template <> \
-  inline void v##op(double* out, double* in, int64_t size) { \
+  inline void v##op(double* out, const double* in, int64_t size) { \
     vmd##mklop(size, in, out, VML_HA | VML_FTZDAZ_OFF | VML_ERRMODE_IGNORE); \
   }
 
 // NB: abs, cosh and sinh were temporarily disabled due to issues with Apple clang
 
-#if AT_MKL_ENABLED() && !defined(__APPLE__)
-IMPLEMENT_FLOAT_MKL_VML(acos, Acos)
-IMPLEMENT_FLOAT_MKL_VML(asin, Asin)
-IMPLEMENT_FLOAT_MKL_VML(atan, Atan)
-IMPLEMENT_FLOAT_MKL_VML(cos, Cos)
-// IMPLEMENT_FLOAT_MKL_VML(cosh, Cosh)
-IMPLEMENT_FLOAT_MKL_VML(erf, Erf)
-IMPLEMENT_FLOAT_MKL_VML(exp, Exp)
-IMPLEMENT_FLOAT_MKL_VML(expm1, Expm1)
-IMPLEMENT_FLOAT_MKL_VML(log, Ln)
-IMPLEMENT_FLOAT_MKL_VML(log10, Log10)
-IMPLEMENT_FLOAT_MKL_VML(log1p, Log1p)
-IMPLEMENT_FLOAT_MKL_VML(sin, Sin)
-// IMPLEMENT_FLOAT_MKL_VML(sinh, Sinh)
-IMPLEMENT_FLOAT_MKL_VML(sqrt, Sqrt)
-IMPLEMENT_FLOAT_MKL_VML(tan, Tan)
-IMPLEMENT_FLOAT_MKL_VML(tanh, Tanh)
-IMPLEMENT_FLOAT_MKL_VML(trunc, Trunc)
+IMPLEMENT_VML_MKL(abs, Abs)
+IMPLEMENT_VML_MKL(acos, Acos)
+IMPLEMENT_VML_MKL(asin, Asin)
+IMPLEMENT_VML_MKL(atan, Atan)
+IMPLEMENT_VML_MKL(cos, Cos)
+// IMPLEMENT_VML_MKL(cosh, Cosh)
+IMPLEMENT_VML_MKL(erf, Erf)
+IMPLEMENT_VML_MKL(exp, Exp)
+IMPLEMENT_VML_MKL(expm1, Expm1)
+IMPLEMENT_VML_MKL(log, Ln)
+IMPLEMENT_VML_MKL(log10, Log10)
+IMPLEMENT_VML_MKL(log1p, Log1p)
+IMPLEMENT_VML_MKL(sin, Sin)
+// IMPLEMENT_VML_MKL(sinh, Sinh)
+IMPLEMENT_VML_MKL(sqrt, Sqrt)
+IMPLEMENT_VML_MKL(tan, Tan)
+IMPLEMENT_VML_MKL(tanh, Tanh)
+IMPLEMENT_VML_MKL(trunc, Trunc)
 
 #if INTEL_MKL_VERSION >= 20180406
-IMPLEMENT_FLOAT_MKL_VML(log2, Log2)
-#else
-IMPLEMENT_VML(log2)
+IMPLEMENT_VML_MKL(log2, Log2)
 #endif
 
-#else
-IMPLEMENT_VML(acos)
-IMPLEMENT_VML(asin)
-IMPLEMENT_VML(atan)
-IMPLEMENT_VML(cos)
-// IMPLEMENT_VML(cosh)
-IMPLEMENT_VML(erf)
-IMPLEMENT_VML(exp)
-IMPLEMENT_VML(expm1)
-IMPLEMENT_VML(log)
-IMPLEMENT_VML(log10)
-IMPLEMENT_VML(log1p)
-IMPLEMENT_VML(log2)
-IMPLEMENT_VML(sin)
-// IMPLEMENT_VML(sinh)
-IMPLEMENT_VML(sqrt)
-IMPLEMENT_VML(tan)
-IMPLEMENT_VML(tanh)
 #endif
 
-IMPLEMENT_VML(ceil)
-IMPLEMENT_VML(floor)
-IMPLEMENT_VML(round)
-IMPLEMENT_VML(trunc)
-
 } // namespace
 } // namespace vml
 } // namespace at
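To make the macro pair concrete, here is approximately what IMPLEMENT_VML_BUG(sin) expands to, hand-reconstructed from the macro body above (IMPLEMENT_VML is identical minus the DL_RUNTIME_BUG line); this is a sketch, not a verbatim preprocessor dump:

```cpp
template <typename scalar_t>
inline void vsin(scalar_t* out, const scalar_t* in, int64_t size) {
  DL_RUNTIME_BUG(sin, scalar_t)  // glibc 2.23 workaround; no-op elsewhere
  parallel_for(0, size, 2048, [out, in](int64_t begin, int64_t end) {
    map([](const Vec256<scalar_t>& x) { return x.sin(); },
        out + begin,
        in + begin,
        end - begin);
  });
}
```

With MKL enabled, IMPLEMENT_VML_MKL(sin, Sin) then specializes the float and double overloads to vmsSin/vmdSin, so the MKL path takes precedence for those types.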

aten/src/ATen/native/UnaryOps.cpp

Lines changed: 2 additions & 1 deletion

@@ -35,7 +35,7 @@ Tensor& fill_(Tensor& self, const Tensor& value) {
 // NB: If you use this macro, you may also need to add a CUDA forwarding
 // stub in CUDAUnaryOps
 
-#define IMPLEMENT_UNARY_OP_VEC(op) \
+#define IMPLEMENT_UNARY_OP_VEC(op)          \
   Tensor op(const Tensor& self) {           \
     Tensor result = self.type().tensor();   \
     return at::op##_out(result, self);      \
@@ -87,6 +87,7 @@ IMPLEMENT_UNARY_OP_VEC(log1p)
 IMPLEMENT_UNARY_OP_VEC(log2)
 IMPLEMENT_UNARY_OP_VEC(round)
 IMPLEMENT_UNARY_OP_VEC(rsqrt)
+IMPLEMENT_UNARY_OP_VEC(sigmoid)
 IMPLEMENT_UNARY_OP_VEC(sin)
 IMPLEMENT_UNARY_OP_TH(sinh)
 IMPLEMENT_UNARY_OP_VEC(sqrt)
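For reference, the new IMPLEMENT_UNARY_OP_VEC(sigmoid) line expands, per the macro body visible in the first hunk, to roughly the following; only the out-of-place variant appears in this excerpt, so the `_` and `_out` variants the macro also generates are not shown:

```cpp
// Reconstructed expansion of IMPLEMENT_UNARY_OP_VEC(sigmoid).
Tensor sigmoid(const Tensor& self) {
  Tensor result = self.type().tensor();
  return at::sigmoid_out(result, self);
}
```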

aten/src/ATen/native/cpu/CapabilityDispatch.h

Lines changed: 2 additions & 1 deletion

@@ -48,7 +48,8 @@ struct DispatchStub {
 #ifndef __powerpc__
   if (cpuinfo_initialize()) {
     int avx2 = static_cast<int>(CPUCapability::AVX2);
-    if (!std::getenv("ATEN_DISABLE_AVX2") && cpuinfo_has_x86_avx2() && table[avx2]) {
+    if (!std::getenv("ATEN_DISABLE_AVX2") && cpuinfo_has_x86_avx2() &&
+        cpuinfo_has_x86_fma3() && table[avx2]) {
       return table[avx2];
     }
     int avx = static_cast<int>(CPUCapability::AVX);
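The extra cpuinfo_has_x86_fma3() check matters because kernels built for the AVX2 slot may be compiled with FMA enabled; on a CPU that reports AVX2 but not FMA3, the dispatcher must fall through to a safer table entry. A sketch of the surrounding dispatch pattern under stated assumptions: `fn_type`, the DEFAULT slot, and the ATEN_DISABLE_AVX variable are placeholders, and only the AVX2 branch is taken verbatim from the hunk above:

```cpp
// Sketch only: fn_type, CPUCapability::DEFAULT, and ATEN_DISABLE_AVX are
// assumptions; the AVX2 branch mirrors the diff.
using fn_type = void (*)(float*, const float*, int64_t);

fn_type choose(fn_type* table) {
  int avx2 = static_cast<int>(CPUCapability::AVX2);
  if (!std::getenv("ATEN_DISABLE_AVX2") && cpuinfo_has_x86_avx2() &&
      cpuinfo_has_x86_fma3() && table[avx2]) {
    return table[avx2];  // AVX2 kernels may emit FMA, so require FMA3 too
  }
  int avx = static_cast<int>(CPUCapability::AVX);
  if (!std::getenv("ATEN_DISABLE_AVX") && cpuinfo_has_x86_avx() && table[avx]) {
    return table[avx];
  }
  return table[static_cast<int>(CPUCapability::DEFAULT)];
}
```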
