Skip to content

Commit 227a764

Browse files
MlWoo authored and soumith committed
Accelerate bernoulli number generation on CPU (#7171)
* opt bernoulli rng with vsl and openmp * detect cpu vendor for bernoulli * retrigger test platform * check the vendor more strictly * use cpuinfo to check vendor
1 parent ee0b75a commit 227a764

File tree

11 files changed

+228
-9
lines changed

11 files changed

+228
-9
lines changed

aten/src/ATen/Declarations.cwrap

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3568,6 +3568,19 @@
35683568
kwarg_only: True
35693569
- double p
35703570
]]
3571+
[[
3572+
name: _cpu_bernoulli_
3573+
backends:
3574+
- CPU
3575+
cname: bernoulli
3576+
return: self
3577+
arguments:
3578+
- THTensor* self
3579+
- arg: THGenerator* generator
3580+
default: nullptr
3581+
kwarg_only: True
3582+
- double p
3583+
]]
35713584
[[
35723585
name: _th_bernoulli
35733586
types:

aten/src/ATen/native/Distributions.cpp

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -146,13 +146,7 @@ Tensor& bernoulli_(Tensor& self, const Tensor& p_, Generator* gen) {
146146

147147
Tensor& bernoulli_(Tensor& self, double p, Generator* gen) {
148148
if (!self.is_cuda()) {
149-
AT_DISPATCH_ALL_TYPES(self.type(), "bernoulli_", [&] {
150-
THGenerator* generator = get_generator(gen);
151-
std::lock_guard<std::mutex> lock(generator->mutex);
152-
CPU_tensor_apply1<scalar_t>(self, [generator, p](scalar_t& ret_val) {
153-
ret_val = (scalar_t)THRandom_bernoulli(generator, p);
154-
});
155-
});
149+
self._cpu_bernoulli_(p, gen);
156150
return self;
157151
}
158152
Tensor probs = self.type().toScalarType(kDouble).tensor({}).fill_(p);

aten/src/TH/THGeneral.h.in

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@
1717
#include <stddef.h>
1818
#include <inttypes.h>
1919

20+
#ifdef TH_BLAS_MKL
21+
#include <mkl_vsl.h>
22+
#endif
23+
2024
#cmakedefine USE_BLAS
2125
#cmakedefine USE_LAPACK
2226
#cmakedefine BLAS_F2C

aten/src/TH/THTensorApply.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -369,7 +369,7 @@
369369
TYPE1 *rp = TENSOR1->storage->data<TYPE1>()+TENSOR1->storageOffset; \
370370
TYPE2 *tp = TENSOR2->storage->data<TYPE2>()+TENSOR2->storageOffset; \
371371
ptrdiff_t iter = 0; \
372-
if(tp != rp) { \
372+
if(tp != (TYPE2*)rp) { \
373373
PRAGMA(ivdep) \
374374
PRAGMA( omp parallel for if (SIZE > OMP_THRESHOLD * 10) firstprivate(rp, tp)) \
375375
for (iter = 0; iter < SIZE; iter++) { \
@@ -449,7 +449,7 @@
449449
TYPE2 *tp = TENSOR2->storage->data<TYPE2>()+TENSOR2->storageOffset; \
450450
TYPE3 *srcp = TENSOR3->storage->data<TYPE3>()+TENSOR3->storageOffset; \
451451
ptrdiff_t iter = 0;\
452-
if (rp != tp) { \
452+
if(tp != (TYPE2*)rp) { \
453453
PRAGMA(ivdep) \
454454
PRAGMA( omp parallel for if (SIZE > OMP_THRESHOLD * 10) ) \
455455
for (iter = 0; iter < SIZE; iter++) {\

aten/src/TH/generic/THTensorRandom.cpp

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,12 @@
22
#define TH_GENERIC_FILE "generic/THTensorRandom.cpp"
33
#else
44

5+
#ifdef _OPENMP
6+
#include <omp.h>
7+
#endif
8+
9+
#include <cpuinfo.h>
10+
511
#include "THGenerator.hpp"
612

713
void THTensor_(random)(THTensor *self, THGenerator *_generator)
@@ -51,10 +57,93 @@ void THTensor_(geometric)(THTensor *self, THGenerator *_generator, double p)
5157
TH_TENSOR_APPLY(real, self, *self_data = (real)THRandom_geometric(_generator, p););
5258
}
5359

60+
#ifdef TH_BLAS_MKL
61+
#define BERNOULLI_OMP 800
62+
#define TH_OMP_OVERHEAD_THRESHOLD_COPY 20000
63+
64+
// Fills `self` with Bernoulli(p) samples drawn through MKL VSL (viRngBernoulli).
//
// VSL writes `int` results, so samples are first generated into an int buffer:
//   - contiguous int tensor      : directly into self's own storage
//   - contiguous non-int tensor  : into a scratch THAlloc buffer, then converted
//                                  segment-by-segment with THVector_(cvtFromInt)
//   - non-contiguous tensor      : into a temporary THIntTensor of the same
//                                  shape, then copied element-wise into self
//
// A single seed is drawn from `_generator`; every worker creates its own MCG31
// stream from that seed and skips ahead by its segment offset before sampling.
// NOTE(review): the visible caller takes `_generator->mutex` before calling
// this function; this function itself does no locking.
void iBernoulli_generate_copy(THTensor *self, THGenerator *_generator, const double p)
{
  int64_t seed = THRandom_random(_generator);
  int64_t n = THTensor_(nElement)(self);
  int contig = THTensor_(isContiguous)(self);
  int *tmp = NULL;
  THIntTensor* intTensor = NULL;

  if (contig) {
#ifdef TH_REAL_IS_INT
    tmp = THIntTensor_data(self);
#else
    tmp = (int*)THAlloc(n*sizeof(int));
#endif
  } else {
    intTensor = THIntTensor_new();
    THIntTensor_resizeNd(intTensor, self->nDimension, self->size, NULL);
    tmp = THIntTensor_data(intTensor);
  }

#ifdef _OPENMP
  // omp_get_num_threads() reports the size of the *current* team, which is
  // always 1 outside a parallel region, so it cannot be used here to size the
  // team. Ask for the maximum available instead.
  const int max_threads =
      (!omp_in_parallel() && n >= BERNOULLI_OMP) ? omp_get_max_threads() : 1;
  #pragma omp parallel num_threads(max_threads)
  {
    // Partition by the team size actually granted: the runtime may provide
    // fewer threads than requested, and every segment must be covered.
    const int64_t nthr = omp_get_num_threads();
    const int64_t tid = omp_get_thread_num();
    int64_t seg_len_tmp = n / nthr;
    int64_t line_index_offset = tid * seg_len_tmp;
    // The last thread also takes the remainder of the division.
    int64_t line_seg_len = (tid == nthr - 1)? (n-line_index_offset) : seg_len_tmp;
#else
  {
    int64_t line_index_offset = 0;
    int64_t line_seg_len = n;
#endif

    if (line_seg_len > 0) {
      VSLStreamStatePtr stream;
      vslNewStream(&stream, VSL_BRNG_MCG31, seed);
      // Position this stream at the start of the segment it is responsible for.
      vslSkipAheadStream(stream, line_index_offset);
      viRngBernoulli(VSL_RNG_METHOD_BERNOULLI_ICDF, stream, line_seg_len,
        tmp + line_index_offset, p);
      vslDeleteStream(&stream);

#ifndef TH_REAL_IS_INT
      if (contig) {
        // Convert this segment from the int scratch buffer into self's storage.
        real* self_seg = THTensor_(data)(self) + line_index_offset;
        int* tmp_seg = tmp + line_index_offset;
        THVector_(cvtFromInt)(self_seg, tmp_seg, line_seg_len);
      }
#endif
    }
  }

  if(contig) {
#ifndef TH_REAL_IS_INT
    // Scratch buffer was only allocated for contiguous non-int tensors.
    THFree(tmp);
#endif
  } else {
    // Non-contiguous: copy the generated ints into self through the strided
    // apply macros, then release the temporary tensor.
#ifdef _OPENMP
    TH_TENSOR_APPLY2_OMP(n, 1, 0, int, intTensor, real, self, *self_data = *intTensor_data;, TH_OMP_OVERHEAD_THRESHOLD_COPY)
#else
    TH_TENSOR_APPLY2(int, intTensor, real, self, *self_data = *intTensor_data;)
#endif
    THIntTensor_free(intTensor);
  }

}
130+
131+
#endif
132+
54133
// Fills `self` with Bernoulli(p) samples drawn from `_generator`.
// When built with TH_BLAS_MKL and cpuinfo reports an Intel vendor for
// processor 0, sampling goes through iBernoulli_generate_copy; otherwise each
// element is drawn with THRandom_bernoulli. The generator mutex is held for
// the whole fill in either case.
void THTensor_(bernoulli)(THTensor *self, THGenerator *_generator, double p)
{
#ifdef TH_BLAS_MKL
  // The vendor check runs before the lock is taken.
  const bool on_intel_cpu =
      cpuinfo_initialize() &&
      cpuinfo_vendor_intel == cpuinfo_get_processor(0)->core->vendor;
  std::lock_guard<std::mutex> lock(_generator->mutex);
  if (on_intel_cpu) {
    iBernoulli_generate_copy(self, _generator, p);
    return;
  }
#else
  std::lock_guard<std::mutex> lock(_generator->mutex);
#endif
  // Generic path: one THRandom_bernoulli draw per element.
  TH_TENSOR_APPLY(real, self, *self_data = (real)THRandom_bernoulli(_generator, p););
}
59148

60149
void THTensor_(bernoulli_FloatTensor)(THTensor *self, THGenerator *_generator, THFloatTensor *p)

aten/src/TH/generic/THVector.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ TH_API void THVector_(normal_fill)(real *data,
1919
struct THGenerator *generator,
2020
const real mean,
2121
const real stddev);
22+
#ifndef TH_REAL_IS_INT
23+
TH_API void THVector_(cvtFromInt)(real *y, const int *x, const ptrdiff_t n);
24+
#endif
2225

2326
#if defined(TH_REAL_IS_SHORT) || defined(TH_REAL_IS_INT) || defined(TH_REAL_IS_LONG)
2427
TH_API void THVector_(abs)(real *y, const real *x, const ptrdiff_t n);

aten/src/TH/generic/THVectorDefault.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,24 @@ void THVector_(divs_DEFAULT)(real *y, const real *x, const real c, const ptrdiff
130130
y[i] = x[i] / c;
131131
}
132132

133+
#ifndef TH_REAL_IS_INT
134+
// Scalar fallback: writes (real)x[k] into y[k] for k in [0, n).
void THVector_(cvtFromInt_DEFAULT)(real *y, const int *x, const ptrdiff_t n)
{
  ptrdiff_t idx = 0;

  // Four elements per iteration. The strict bound means the final 1..4
  // elements are always left to the tail loop below.
  while (idx + 4 < n) {
    y[idx]     = (real)x[idx];
    y[idx + 1] = (real)x[idx + 1];
    y[idx + 2] = (real)x[idx + 2];
    y[idx + 3] = (real)x[idx + 3];
    idx += 4;
  }

  // Remaining elements, one at a time.
  while (idx < n) {
    y[idx] = (real)x[idx];
    ++idx;
  }
}
149+
#endif
150+
133151
// Fills 16 normally distributed samples into data, interleaved with a
134152
// stride of 8, i.e. in order of ([0], [8]), ([1], [9]), ...
135153
static void THVector_(interleaved_normal_fill_16)(real *data,

aten/src/TH/generic/THVectorDispatch.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,29 @@ void THVector_(copy)(real *y, const real *x, const ptrdiff_t n) {
239239
THVector_(copy_DISPATCHPTR)(y, x, n);
240240
}
241241

242+
#ifndef TH_REAL_IS_INT
243+
static void (*THVector_(cvtFromInt_DISPATCHPTR))(real *, const int *, const ptrdiff_t) = &THVector_(cvtFromInt_DEFAULT);
244+
static FunctionDescription THVector_(cvtFromInt_DISPATCHTABLE)[] = {
245+
#if defined(USE_AVX)
246+
#if defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT)
247+
FUNCTION_IMPL(THVector_(cvtFromInt_AVX), SIMDExtension_AVX),
248+
#endif
249+
#endif
250+
#if defined(USE_SSE2) || defined(USE_SSE3) || defined(USE_SSSE3) \
251+
|| defined(USE_SSE4_1) || defined(USE_SSE4_2)
252+
#if defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT)
253+
FUNCTION_IMPL(THVector_(cvtFromInt_SSE), SIMDExtension_SSE),
254+
#endif
255+
#endif
256+
257+
258+
FUNCTION_IMPL(THVector_(cvtFromInt_DEFAULT), SIMDExtension_DEFAULT)
259+
};
260+
void THVector_(cvtFromInt)(real *y, const int *x, const ptrdiff_t n) {
261+
THVector_(cvtFromInt_DISPATCHPTR)(y, x, n);
262+
}
263+
#endif
264+
242265
static void (*THVector_(normal_fill_DISPATCHPTR))(real *, const int64_t, THGenerator *, const real, const real) = &THVector_(normal_fill_DEFAULT);
243266
static FunctionDescription THVector_(normal_fill_DISPATCHTABLE)[] = {
244267
#if defined(TH_REAL_IS_FLOAT) && defined(USE_AVX2)
@@ -290,6 +313,10 @@ struct THVector_(startup) {
290313
INIT_DISPATCH_PTR(copy);
291314
INIT_DISPATCH_PTR(normal_fill);
292315

316+
#ifndef TH_REAL_IS_INT
317+
INIT_DISPATCH_PTR(cvtFromInt);
318+
#endif
319+
293320
#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE)
294321
INIT_DISPATCH_PTR(sigmoid);
295322
#endif

aten/src/TH/vector/AVX.cpp

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,4 +271,38 @@ void THFloatVector_adds_AVX(float *y, const float *x, const float c, const ptrdi
271271
}
272272
}
273273

274+
void THFloatVector_cvtFromInt_AVX(float *y, const int *x, const ptrdiff_t n) {
275+
ptrdiff_t i;
276+
__m256i YMM0, YMM1;
277+
__m256 YMM2, YMM3;
278+
for (i=0; i<=((n)-16); i+=16) {
279+
YMM0 = _mm256_loadu_si256((__m256i const*)(x+i));
280+
YMM1 = _mm256_loadu_si256((__m256i const*)(x+i+8));
281+
YMM2 = _mm256_cvtepi32_ps(YMM0);
282+
YMM3 = _mm256_cvtepi32_ps(YMM1);
283+
_mm256_storeu_ps(y+i, YMM2);
284+
_mm256_storeu_ps(y+i+8, YMM3);
285+
}
286+
for (; i<(n); i++) {
287+
y[i] = (float)x[i];
288+
}
289+
}
290+
291+
void THDoubleVector_cvtFromInt_AVX(double *y, const int *x, const ptrdiff_t n) {
292+
ptrdiff_t i;
293+
__m128i YMM0, YMM1;
294+
__m256d YMM2, YMM3;
295+
for (i=0; i<=((n)- 8); i+=8) {
296+
YMM0 = _mm_loadu_si128((__m128i const*)(x+i));
297+
YMM1 = _mm_loadu_si128((__m128i const*)(x+i+4));
298+
YMM2 = _mm256_cvtepi32_pd(YMM0);
299+
YMM3 = _mm256_cvtepi32_pd(YMM1);
300+
_mm256_storeu_pd(y+i, YMM2);
301+
_mm256_storeu_pd(y+i+4, YMM3);
302+
}
303+
for (; i<(n); i++) {
304+
y[i] = (double)x[i];
305+
}
306+
}
307+
274308
#endif // defined(__AVX__)

aten/src/TH/vector/AVX.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ TH_API void THDoubleVector_cmul_AVX(double *z, const double *x, const double *y,
1212
TH_API void THDoubleVector_muls_AVX(double *y, const double *x, const double c, const ptrdiff_t n);
1313
TH_API void THDoubleVector_cadd_AVX(double *z, const double *x, const double *y, const double c, const ptrdiff_t n);
1414
TH_API void THDoubleVector_adds_AVX(double *y, const double *x, const double c, const ptrdiff_t n);
15+
TH_API void THDoubleVector_cvtFromInt_AVX(double *y, const int *x, const ptrdiff_t n);
1516
TH_API void THFloatVector_copy_AVX(float *y, const float *x, const ptrdiff_t n);
1617
TH_API void THFloatVector_fill_AVX(float *x, const float c, const ptrdiff_t n);
1718
TH_API void THFloatVector_cdiv_AVX(float *z, const float *x, const float *y, const ptrdiff_t n);
@@ -20,4 +21,5 @@ TH_API void THFloatVector_cmul_AVX(float *z, const float *x, const float *y, con
2021
TH_API void THFloatVector_muls_AVX(float *y, const float *x, const float c, const ptrdiff_t n);
2122
TH_API void THFloatVector_cadd_AVX(float *z, const float *x, const float *y, const float c, const ptrdiff_t n);
2223
TH_API void THFloatVector_adds_AVX(float *y, const float *x, const float c, const ptrdiff_t n);
24+
TH_API void THFloatVector_cvtFromInt_AVX(float *y, const int *x, const ptrdiff_t n);
2325
#endif

0 commit comments

Comments
 (0)