Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions aten/src/ATen/Declarations.cwrap
Original file line number Diff line number Diff line change
Expand Up @@ -3568,6 +3568,19 @@
kwarg_only: True
- double p
]]
[[
name: _cpu_bernoulli_
backends:
- CPU
cname: bernoulli
return: self
arguments:
- THTensor* self
- arg: THGenerator* generator
default: nullptr
kwarg_only: True
- double p
]]
[[
name: _th_bernoulli
types:
Expand Down
8 changes: 1 addition & 7 deletions aten/src/ATen/native/Distributions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,13 +146,7 @@ Tensor& bernoulli_(Tensor& self, const Tensor& p_, Generator* gen) {

Tensor& bernoulli_(Tensor& self, double p, Generator* gen) {
if (!self.is_cuda()) {
AT_DISPATCH_ALL_TYPES(self.type(), "bernoulli_", [&] {
THGenerator* generator = get_generator(gen);
std::lock_guard<std::mutex> lock(generator->mutex);
CPU_tensor_apply1<scalar_t>(self, [generator, p](scalar_t& ret_val) {
ret_val = (scalar_t)THRandom_bernoulli(generator, p);
});
});
self._cpu_bernoulli_(p, gen);
return self;
}
Tensor probs = self.type().toScalarType(kDouble).tensor({}).fill_(p);
Expand Down
4 changes: 4 additions & 0 deletions aten/src/TH/THGeneral.h.in
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@
#include <stddef.h>
#include <inttypes.h>

#ifdef TH_BLAS_MKL
#include <mkl_vsl.h>
#endif

#cmakedefine USE_BLAS
#cmakedefine USE_LAPACK
#cmakedefine BLAS_F2C
Expand Down
4 changes: 2 additions & 2 deletions aten/src/TH/THTensorApply.h
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@
TYPE1 *rp = TENSOR1->storage->data+TENSOR1->storageOffset; \
TYPE2 *tp = TENSOR2->storage->data+TENSOR2->storageOffset; \
ptrdiff_t iter = 0; \
if(tp != rp) { \
if(tp != (TYPE2*)rp) { \
PRAGMA(ivdep) \
PRAGMA( omp parallel for if (SIZE > OMP_THRESHOLD * 10) firstprivate(rp, tp)) \
for (iter = 0; iter < SIZE; iter++) { \
Expand Down Expand Up @@ -449,7 +449,7 @@
TYPE2 *tp = TENSOR2->storage->data+TENSOR2->storageOffset; \
TYPE3 *srcp = TENSOR3->storage->data+TENSOR3->storageOffset; \
ptrdiff_t iter = 0;\
if (rp != tp) { \
if(tp != (TYPE2*)rp) { \
PRAGMA(ivdep) \
PRAGMA( omp parallel for if (SIZE > OMP_THRESHOLD * 10) ) \
for (iter = 0; iter < SIZE; iter++) {\
Expand Down
89 changes: 89 additions & 0 deletions aten/src/TH/generic/THTensorRandom.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,12 @@
#define TH_GENERIC_FILE "generic/THTensorRandom.cpp"
#else

#ifdef _OPENMP
#include <omp.h>
#endif

#include <cpuinfo.h>

#include "THGenerator.hpp"

void THTensor_(random)(THTensor *self, THGenerator *_generator)
Expand Down Expand Up @@ -51,10 +57,93 @@ void THTensor_(geometric)(THTensor *self, THGenerator *_generator, double p)
TH_TENSOR_APPLY(real, self, *self_data = (real)THRandom_geometric(_generator, p););
}

#ifdef TH_BLAS_MKL
#define BERNOULLI_OMP 800
#define TH_OMP_OVERHEAD_THRESHOLD_COPY 20000

void iBernoulli_generate_copy(THTensor *self, THGenerator *_generator, const double p)
{
int64_t seed = THRandom_random(_generator);
int64_t n = THTensor_(nElement)(self);
int contig = THTensor_(isContiguous)(self);
int *tmp = NULL;
THIntTensor* intTensor = NULL;

if (contig) {
#ifdef TH_REAL_IS_INT
tmp = THIntTensor_data(self);
#else
tmp = (int*)THAlloc(n*sizeof(int));
#endif
} else {
intTensor = THIntTensor_new();
THIntTensor_resizeNd(intTensor, self->nDimension, self->size, NULL);
tmp = THIntTensor_data(intTensor);
}

#ifdef _OPENMP
size_t nthr = !omp_in_parallel() && n >= BERNOULLI_OMP ? omp_get_num_threads() : 1;
#pragma omp parallel num_threads(nthr) firstprivate(nthr)
{
size_t tid = omp_get_thread_num();
int64_t seg_len_tmp = n / nthr;
int64_t line_index_offset = tid * seg_len_tmp;
int64_t line_seg_len = (tid == nthr - 1)? (n-line_index_offset) : seg_len_tmp;
#else
{
int64_t line_index_offset = 0;
int64_t line_seg_len = n;
#endif

if (line_seg_len > 0) {
VSLStreamStatePtr stream;
vslNewStream(&stream, VSL_BRNG_MCG31, seed);
vslSkipAheadStream(stream, line_index_offset);
viRngBernoulli(VSL_RNG_METHOD_BERNOULLI_ICDF, stream, line_seg_len,
tmp + line_index_offset, p);
vslDeleteStream(&stream);

#ifndef TH_REAL_IS_INT
if (contig) {
real* self_seg = THTensor_(data)(self) + line_index_offset;
int* tmp_seg = tmp + line_index_offset;
THVector_(cvtFromInt)(self_seg, tmp_seg, line_seg_len);
}
#endif
}
}

if(contig) {
#ifndef TH_REAL_IS_INT
THFree(tmp);
#endif
} else {
#ifdef _OPENMP
TH_TENSOR_APPLY2_OMP(n, 1, 0, int, intTensor, real, self, *self_data = *intTensor_data;, TH_OMP_OVERHEAD_THRESHOLD_COPY)
#else
TH_TENSOR_APPLY2(int, intTensor, real, self, *self_data = *intTensor_data;)
#endif
THIntTensor_free(intTensor);
}

}

#endif

void THTensor_(bernoulli)(THTensor *self, THGenerator *_generator, double p)
{
#ifdef TH_BLAS_MKL
if(cpuinfo_initialize() && cpuinfo_vendor_intel == cpuinfo_get_processor(0)->core->vendor) {
std::lock_guard<std::mutex> lock(_generator->mutex);
iBernoulli_generate_copy(self, _generator, p);
} else {
std::lock_guard<std::mutex> lock(_generator->mutex);
TH_TENSOR_APPLY(real, self, *self_data = (real)THRandom_bernoulli(_generator, p););
}
#else
std::lock_guard<std::mutex> lock(_generator->mutex);
TH_TENSOR_APPLY(real, self, *self_data = (real)THRandom_bernoulli(_generator, p););
#endif
}

void THTensor_(bernoulli_FloatTensor)(THTensor *self, THGenerator *_generator, THFloatTensor *p)
Expand Down
3 changes: 3 additions & 0 deletions aten/src/TH/generic/THVector.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ TH_API void THVector_(normal_fill)(real *data,
struct THGenerator *generator,
const real mean,
const real stddev);
#ifndef TH_REAL_IS_INT
TH_API void THVector_(cvtFromInt)(real *y, const int *x, const ptrdiff_t n);
#endif

#if defined(TH_REAL_IS_SHORT) || defined(TH_REAL_IS_INT) || defined(TH_REAL_IS_LONG)
TH_API void THVector_(abs)(real *y, const real *x, const ptrdiff_t n);
Expand Down
18 changes: 18 additions & 0 deletions aten/src/TH/generic/THVectorDefault.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,24 @@ void THVector_(divs_DEFAULT)(real *y, const real *x, const real c, const ptrdiff
y[i] = x[i] / c;
}

#ifndef TH_REAL_IS_INT
void THVector_(cvtFromInt_DEFAULT)(real *y, const int *x, const ptrdiff_t n)
{
ptrdiff_t i = 0;

for(; i<n-4; i+=4)
{
y[i] = (real)x[i];
y[i+1] = (real)x[i+1];
y[i+2] = (real)x[i+2];
y[i+3] = (real)x[i+3];
}

for(; i < n; i++)
y[i] = (real)x[i];
}
#endif

// Fills 16 normally distributed samples into data, interleaved with a
// stride of 8, i.e. in order of ([0], [8]), ([1], [9]), ...
static void THVector_(interleaved_normal_fill_16)(real *data,
Expand Down
27 changes: 27 additions & 0 deletions aten/src/TH/generic/THVectorDispatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,29 @@ void THVector_(copy)(real *y, const real *x, const ptrdiff_t n) {
THVector_(copy_DISPATCHPTR)(y, x, n);
}

#ifndef TH_REAL_IS_INT
static void (*THVector_(cvtFromInt_DISPATCHPTR))(real *, const int *, const ptrdiff_t) = &THVector_(cvtFromInt_DEFAULT);
static FunctionDescription THVector_(cvtFromInt_DISPATCHTABLE)[] = {
#if defined(USE_AVX)
#if defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT)
FUNCTION_IMPL(THVector_(cvtFromInt_AVX), SIMDExtension_AVX),
#endif
#endif
#if defined(USE_SSE2) || defined(USE_SSE3) || defined(USE_SSSE3) \
|| defined(USE_SSE4_1) || defined(USE_SSE4_2)
#if defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT)
FUNCTION_IMPL(THVector_(cvtFromInt_SSE), SIMDExtension_SSE),
#endif
#endif


FUNCTION_IMPL(THVector_(cvtFromInt_DEFAULT), SIMDExtension_DEFAULT)
};
void THVector_(cvtFromInt)(real *y, const int *x, const ptrdiff_t n) {
THVector_(cvtFromInt_DISPATCHPTR)(y, x, n);
}
#endif

static void (*THVector_(normal_fill_DISPATCHPTR))(real *, const int64_t, THGenerator *, const real, const real) = &THVector_(normal_fill_DEFAULT);
static FunctionDescription THVector_(normal_fill_DISPATCHTABLE)[] = {
#if defined(TH_REAL_IS_FLOAT) && defined(USE_AVX2)
Expand Down Expand Up @@ -290,6 +313,10 @@ struct THVector_(startup) {
INIT_DISPATCH_PTR(copy);
INIT_DISPATCH_PTR(normal_fill);

#ifndef TH_REAL_IS_INT
INIT_DISPATCH_PTR(cvtFromInt);
#endif

#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE)
INIT_DISPATCH_PTR(sigmoid);
#endif
Expand Down
34 changes: 34 additions & 0 deletions aten/src/TH/vector/AVX.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -271,4 +271,38 @@ void THFloatVector_adds_AVX(float *y, const float *x, const float c, const ptrdi
}
}

void THFloatVector_cvtFromInt_AVX(float *y, const int *x, const ptrdiff_t n) {
ptrdiff_t i;
__m256i YMM0, YMM1;
__m256 YMM2, YMM3;
for (i=0; i<=((n)-16); i+=16) {
YMM0 = _mm256_loadu_si256((__m256i const*)(x+i));
YMM1 = _mm256_loadu_si256((__m256i const*)(x+i+8));
YMM2 = _mm256_cvtepi32_ps(YMM0);
YMM3 = _mm256_cvtepi32_ps(YMM1);
_mm256_storeu_ps(y+i, YMM2);
_mm256_storeu_ps(y+i+8, YMM3);
}
for (; i<(n); i++) {
y[i] = (float)x[i];
}
}

void THDoubleVector_cvtFromInt_AVX(double *y, const int *x, const ptrdiff_t n) {
ptrdiff_t i;
__m128i YMM0, YMM1;
__m256d YMM2, YMM3;
for (i=0; i<=((n)- 8); i+=8) {
YMM0 = _mm_loadu_si128((__m128i const*)(x+i));
YMM1 = _mm_loadu_si128((__m128i const*)(x+i+4));
YMM2 = _mm256_cvtepi32_pd(YMM0);
YMM3 = _mm256_cvtepi32_pd(YMM1);
_mm256_storeu_pd(y+i, YMM2);
_mm256_storeu_pd(y+i+4, YMM3);
}
for (; i<(n); i++) {
y[i] = (double)x[i];
}
}

#endif // defined(__AVX__)
2 changes: 2 additions & 0 deletions aten/src/TH/vector/AVX.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ TH_API void THDoubleVector_cmul_AVX(double *z, const double *x, const double *y,
TH_API void THDoubleVector_muls_AVX(double *y, const double *x, const double c, const ptrdiff_t n);
TH_API void THDoubleVector_cadd_AVX(double *z, const double *x, const double *y, const double c, const ptrdiff_t n);
TH_API void THDoubleVector_adds_AVX(double *y, const double *x, const double c, const ptrdiff_t n);
TH_API void THDoubleVector_cvtFromInt_AVX(double *y, const int *x, const ptrdiff_t n);
TH_API void THFloatVector_copy_AVX(float *y, const float *x, const ptrdiff_t n);
TH_API void THFloatVector_fill_AVX(float *x, const float c, const ptrdiff_t n);
TH_API void THFloatVector_cdiv_AVX(float *z, const float *x, const float *y, const ptrdiff_t n);
Expand All @@ -20,4 +21,5 @@ TH_API void THFloatVector_cmul_AVX(float *z, const float *x, const float *y, con
TH_API void THFloatVector_muls_AVX(float *y, const float *x, const float c, const ptrdiff_t n);
TH_API void THFloatVector_cadd_AVX(float *z, const float *x, const float *y, const float c, const ptrdiff_t n);
TH_API void THFloatVector_adds_AVX(float *y, const float *x, const float c, const ptrdiff_t n);
TH_API void THFloatVector_cvtFromInt_AVX(float *y, const int *x, const ptrdiff_t n);
#endif
35 changes: 35 additions & 0 deletions aten/src/TH/vector/SSE.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -266,3 +266,38 @@ static void THFloatVector_divs_SSE(float *y, const float *x, const float c, cons
y[i] = x[i] / c;
}
}

static void THFloatVector_cvtFromInt_SSE(float *y, const int *x, const ptrdiff_t n) {
ptrdiff_t i;
__m128i YMM0, YMM1;
__m128 YMM2, YMM3;
for (i=0; i<=((n)-8); i+=8) {
YMM0 = _mm_loadu_si128((__m128i const*)(x+i));
YMM1 = _mm_loadu_si128((__m128i const*)(x+i+4));
YMM2 = _mm_cvtepi32_ps(YMM0);
YMM3 = _mm_cvtepi32_ps(YMM1);
_mm_storeu_ps(y+i, YMM2);
_mm_storeu_ps(y+i+4, YMM3);
}
for (; i<(n); i++) {
y[i] = (float)x[i];
}
}

static void THDoubleVector_cvtFromInt_SSE(double *y, const int *x, const ptrdiff_t n) {
ptrdiff_t i;
__m128i YMM0, YMM1;
__m128d YMM2, YMM3;
for (i=0; i<=((n)- 4); i+=4) {
YMM0 = _mm_loadu_si128((__m128i const*)(x+i));
YMM2 = _mm_cvtepi32_pd(YMM0);
YMM1 = _mm_srli_si128(YMM0, 8);
YMM3 = _mm_cvtepi32_pd(YMM1);
_mm_storeu_pd(y+i, YMM2);
_mm_storeu_pd(y+i+2, YMM3);
}
for (; i<(n); i++) {
y[i] = (double)x[i];
}
}