|
2 | 2 | #define TH_GENERIC_FILE "generic/THTensorRandom.cpp" |
3 | 3 | #else |
4 | 4 |
|
| 5 | +#ifdef _OPENMP |
| 6 | +#include <omp.h> |
| 7 | +#endif |
| 8 | + |
| 9 | +#include <cpuinfo.h> |
| 10 | + |
5 | 11 | #include "THGenerator.hpp" |
6 | 12 |
|
7 | 13 | void THTensor_(random)(THTensor *self, THGenerator *_generator) |
@@ -51,10 +57,93 @@ void THTensor_(geometric)(THTensor *self, THGenerator *_generator, double p) |
51 | 57 | TH_TENSOR_APPLY(real, self, *self_data = (real)THRandom_geometric(_generator, p);); |
52 | 58 | } |
53 | 59 |
|
#ifdef TH_BLAS_MKL
/* Minimum element count at which the MKL Bernoulli sampler requests more than one OpenMP thread. */
#define BERNOULLI_OMP 800
/* Last argument handed to TH_TENSOR_APPLY2_OMP for the int -> real copy-back below
 * (presumably the element count under which that macro stays serial -- confirm against its definition). */
#define TH_OMP_OVERHEAD_THRESHOLD_COPY 20000

/*
 * Fills `self` with Bernoulli(p) samples drawn through MKL VSL (viRngBernoulli).
 *
 * self        tensor to overwrite; may be non-contiguous.
 * _generator  used only to draw one seed via THRandom_random; this function takes
 *             no lock itself (the visible caller, THTensor_(bernoulli), holds
 *             _generator->mutex around the call).
 * p           success probability forwarded to viRngBernoulli.
 *
 * MKL writes ints, so samples land in an int buffer first:
 *   - contiguous int tensor : MKL writes straight into the tensor's storage;
 *   - contiguous, other type: a scratch int array is converted segment by segment;
 *   - non-contiguous        : a temporary THIntTensor is filled, then copied element-wise.
 *
 * Under OpenMP each thread opens its own stream from the same seed and skips ahead by its
 * segment offset before generating its slice of the buffer.
 */
void iBernoulli_generate_copy(THTensor *self, THGenerator *_generator, const double p)
{
  int64_t seed = THRandom_random(_generator);
  int64_t n = THTensor_(nElement)(self);
  int contig = THTensor_(isContiguous)(self);
  int *tmp = NULL;
  THIntTensor* intTensor = NULL;

  if (contig) {
#ifdef TH_REAL_IS_INT
    tmp = THIntTensor_data(self);
#else
    tmp = (int*)THAlloc(n*sizeof(int));
#endif
  } else {
    intTensor = THIntTensor_new();
    THIntTensor_resizeNd(intTensor, self->nDimension, self->size, NULL);
    tmp = THIntTensor_data(intTensor);
  }

#ifdef _OPENMP
  /* omp_get_num_threads() is always 1 outside a parallel region, so the upper
   * bound has to come from omp_get_max_threads() for the fan-out to happen. */
  const int requested_threads =
      (!omp_in_parallel() && n >= BERNOULLI_OMP) ? omp_get_max_threads() : 1;
#pragma omp parallel num_threads(requested_threads)
  {
    /* Partition by the team size actually granted: the runtime may hand out
     * fewer threads than requested, and every element must still be covered. */
    const int64_t nthr = omp_get_num_threads();
    const int64_t tid = omp_get_thread_num();
    const int64_t seg_len_tmp = n / nthr;
    const int64_t line_index_offset = tid * seg_len_tmp;
    /* The last thread also takes the remainder of the integer division. */
    const int64_t line_seg_len = (tid == nthr - 1) ? (n - line_index_offset) : seg_len_tmp;
#else
  {
    const int64_t line_index_offset = 0;
    const int64_t line_seg_len = n;
#endif

    if (line_seg_len > 0) {
      VSLStreamStatePtr stream;
      vslNewStream(&stream, VSL_BRNG_MCG31, seed);
      /* Same seed on every thread; skipping ahead by the segment offset positions
       * this stream at the start of the thread's slice. */
      vslSkipAheadStream(stream, line_index_offset);
      viRngBernoulli(VSL_RNG_METHOD_BERNOULLI_ICDF, stream, line_seg_len,
        tmp + line_index_offset, p);
      vslDeleteStream(&stream);

#ifndef TH_REAL_IS_INT
      if (contig) {
        /* Convert this thread's slice of the scratch buffer into the tensor. */
        real* self_seg = THTensor_(data)(self) + line_index_offset;
        int* tmp_seg = tmp + line_index_offset;
        THVector_(cvtFromInt)(self_seg, tmp_seg, line_seg_len);
      }
#endif
    }
  }

  if(contig) {
#ifndef TH_REAL_IS_INT
    /* Scratch buffer only exists for contiguous non-int tensors. */
    THFree(tmp);
#endif
  } else {
    /* Non-contiguous destination: copy element-wise from the temporary int tensor. */
#ifdef _OPENMP
    TH_TENSOR_APPLY2_OMP(n, 1, 0, int, intTensor, real, self, *self_data = *intTensor_data;, TH_OMP_OVERHEAD_THRESHOLD_COPY)
#else
    TH_TENSOR_APPLY2(int, intTensor, real, self, *self_data = *intTensor_data;)
#endif
    THIntTensor_free(intTensor);
  }

}

#endif
| 132 | + |
54 | 133 | void THTensor_(bernoulli)(THTensor *self, THGenerator *_generator, double p) |
55 | 134 | { |
| 135 | +#ifdef TH_BLAS_MKL |
| 136 | + if(cpuinfo_initialize() && cpuinfo_vendor_intel == cpuinfo_get_processor(0)->core->vendor) { |
| 137 | + std::lock_guard<std::mutex> lock(_generator->mutex); |
| 138 | + iBernoulli_generate_copy(self, _generator, p); |
| 139 | + } else { |
| 140 | + std::lock_guard<std::mutex> lock(_generator->mutex); |
| 141 | + TH_TENSOR_APPLY(real, self, *self_data = (real)THRandom_bernoulli(_generator, p);); |
| 142 | + } |
| 143 | +#else |
56 | 144 | std::lock_guard<std::mutex> lock(_generator->mutex); |
57 | 145 | TH_TENSOR_APPLY(real, self, *self_data = (real)THRandom_bernoulli(_generator, p);); |
| 146 | +#endif |
58 | 147 | } |
59 | 148 |
|
60 | 149 | void THTensor_(bernoulli_FloatTensor)(THTensor *self, THGenerator *_generator, THFloatTensor *p) |
|
0 commit comments