Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions src/backend/cpu/kernel/random_engine.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,14 +99,18 @@ double getDouble01(uint *val, uint index) {

template<>
char transform<char>(uint *val, uint index) {
char v = val[index >> 2] >> (8 << (index & 3));
v = (v & 0x1) ? 1 : 0;
char v = 0;
memcpy(&v, static_cast<char *>(static_cast<void *>(val)) + index,
sizeof(char));
v &= 0x1;
return v;
}

template<>
uchar transform<uchar>(uint *val, uint index) {
uchar v = val[index >> 2] >> (index << 3);
uchar v = 0;
memcpy(&v, static_cast<uchar *>(static_cast<void *>(val)) + index,
sizeof(uchar));
return v;
}

Expand Down Expand Up @@ -210,7 +214,7 @@ void philoxUniform(T *out, size_t elements, const uintl seed, uintl counter) {

// Use the same ctr array for each of the 4 locations,
// but each of the location gets a different ctr value
for (size_t buf_idx = 0; buf_idx < NUM_WRITES; ++buf_idx) {
for (uint buf_idx = 0; buf_idx < NUM_WRITES; ++buf_idx) {
size_t out_idx = iter + buf_idx * WRITE_STRIDE + i + j;
if (out_idx < elements) {
out[out_idx] = transform<T>(ctr, buf_idx);
Expand Down
48 changes: 24 additions & 24 deletions src/backend/cuda/kernel/random_engine.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -315,21 +315,21 @@ __device__ static void writeOut128Bytes(char *out, const uint &index,
const uint &r1, const uint &r2,
const uint &r3, const uint &r4) {
out[index] = (r1)&0x1;
out[index + blockDim.x] = (r1 >> 1) & 0x1;
out[index + 2 * blockDim.x] = (r1 >> 2) & 0x1;
out[index + 3 * blockDim.x] = (r1 >> 3) & 0x1;
out[index + blockDim.x] = (r1 >> 8) & 0x1;
out[index + 2 * blockDim.x] = (r1 >> 16) & 0x1;
out[index + 3 * blockDim.x] = (r1 >> 24) & 0x1;
out[index + 4 * blockDim.x] = (r2)&0x1;
out[index + 5 * blockDim.x] = (r2 >> 1) & 0x1;
out[index + 6 * blockDim.x] = (r2 >> 2) & 0x1;
out[index + 7 * blockDim.x] = (r2 >> 3) & 0x1;
out[index + 5 * blockDim.x] = (r2 >> 8) & 0x1;
out[index + 6 * blockDim.x] = (r2 >> 16) & 0x1;
out[index + 7 * blockDim.x] = (r2 >> 24) & 0x1;
out[index + 8 * blockDim.x] = (r3)&0x1;
out[index + 9 * blockDim.x] = (r3 >> 1) & 0x1;
out[index + 10 * blockDim.x] = (r3 >> 2) & 0x1;
out[index + 11 * blockDim.x] = (r3 >> 3) & 0x1;
out[index + 9 * blockDim.x] = (r3 >> 8) & 0x1;
out[index + 10 * blockDim.x] = (r3 >> 16) & 0x1;
out[index + 11 * blockDim.x] = (r3 >> 24) & 0x1;
out[index + 12 * blockDim.x] = (r4)&0x1;
out[index + 13 * blockDim.x] = (r4 >> 1) & 0x1;
out[index + 14 * blockDim.x] = (r4 >> 2) & 0x1;
out[index + 15 * blockDim.x] = (r4 >> 3) & 0x1;
out[index + 13 * blockDim.x] = (r4 >> 8) & 0x1;
out[index + 14 * blockDim.x] = (r4 >> 16) & 0x1;
out[index + 15 * blockDim.x] = (r4 >> 24) & 0x1;
}

__device__ static void writeOut128Bytes(short *out, const uint &index,
Expand Down Expand Up @@ -540,49 +540,49 @@ __device__ static void partialWriteOut128Bytes(char *out, const uint &index,
const uint &elements) {
if (index < elements) { out[index] = (r1)&0x1; }
if (index + blockDim.x < elements) {
out[index + blockDim.x] = (r1 >> 1) & 0x1;
out[index + blockDim.x] = (r1 >> 8) & 0x1;
}
if (index + 2 * blockDim.x < elements) {
out[index + 2 * blockDim.x] = (r1 >> 2) & 0x1;
out[index + 2 * blockDim.x] = (r1 >> 16) & 0x1;
}
if (index + 3 * blockDim.x < elements) {
out[index + 3 * blockDim.x] = (r1 >> 3) & 0x1;
out[index + 3 * blockDim.x] = (r1 >> 24) & 0x1;
}
if (index + 4 * blockDim.x < elements) {
out[index + 4 * blockDim.x] = (r2)&0x1;
}
if (index + 5 * blockDim.x < elements) {
out[index + 5 * blockDim.x] = (r2 >> 1) & 0x1;
out[index + 5 * blockDim.x] = (r2 >> 8) & 0x1;
}
if (index + 6 * blockDim.x < elements) {
out[index + 6 * blockDim.x] = (r2 >> 2) & 0x1;
out[index + 6 * blockDim.x] = (r2 >> 16) & 0x1;
}
if (index + 7 * blockDim.x < elements) {
out[index + 7 * blockDim.x] = (r2 >> 3) & 0x1;
out[index + 7 * blockDim.x] = (r2 >> 24) & 0x1;
}
if (index + 8 * blockDim.x < elements) {
out[index + 8 * blockDim.x] = (r3)&0x1;
}
if (index + 9 * blockDim.x < elements) {
out[index + 9 * blockDim.x] = (r3 >> 1) & 0x1;
out[index + 9 * blockDim.x] = (r3 >> 8) & 0x1;
}
if (index + 10 * blockDim.x < elements) {
out[index + 10 * blockDim.x] = (r3 >> 2) & 0x1;
out[index + 10 * blockDim.x] = (r3 >> 16) & 0x1;
}
if (index + 11 * blockDim.x < elements) {
out[index + 11 * blockDim.x] = (r3 >> 3) & 0x1;
out[index + 11 * blockDim.x] = (r3 >> 24) & 0x1;
}
if (index + 12 * blockDim.x < elements) {
out[index + 12 * blockDim.x] = (r4)&0x1;
}
if (index + 13 * blockDim.x < elements) {
out[index + 13 * blockDim.x] = (r4 >> 1) & 0x1;
out[index + 13 * blockDim.x] = (r4 >> 8) & 0x1;
}
if (index + 14 * blockDim.x < elements) {
out[index + 14 * blockDim.x] = (r4 >> 2) & 0x1;
out[index + 14 * blockDim.x] = (r4 >> 16) & 0x1;
}
if (index + 15 * blockDim.x < elements) {
out[index + 15 * blockDim.x] = (r4 >> 3) & 0x1;
out[index + 15 * blockDim.x] = (r4 >> 24) & 0x1;
}
}

Expand Down
48 changes: 24 additions & 24 deletions src/backend/oneapi/kernel/random_engine_write.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -310,21 +310,21 @@ static void writeOut128Bytes(char *out, const uint &index, const uint groupSz,
const uint &r1, const uint &r2, const uint &r3,
const uint &r4) {
out[index] = (r1)&0x1;
out[index + groupSz] = (r1 >> 1) & 0x1;
out[index + 2 * groupSz] = (r1 >> 2) & 0x1;
out[index + 3 * groupSz] = (r1 >> 3) & 0x1;
out[index + groupSz] = (r1 >> 8) & 0x1;
out[index + 2 * groupSz] = (r1 >> 16) & 0x1;
out[index + 3 * groupSz] = (r1 >> 24) & 0x1;
out[index + 4 * groupSz] = (r2)&0x1;
out[index + 5 * groupSz] = (r2 >> 1) & 0x1;
out[index + 6 * groupSz] = (r2 >> 2) & 0x1;
out[index + 7 * groupSz] = (r2 >> 3) & 0x1;
out[index + 5 * groupSz] = (r2 >> 8) & 0x1;
out[index + 6 * groupSz] = (r2 >> 16) & 0x1;
out[index + 7 * groupSz] = (r2 >> 24) & 0x1;
out[index + 8 * groupSz] = (r3)&0x1;
out[index + 9 * groupSz] = (r3 >> 1) & 0x1;
out[index + 10 * groupSz] = (r3 >> 2) & 0x1;
out[index + 11 * groupSz] = (r3 >> 3) & 0x1;
out[index + 9 * groupSz] = (r3 >> 8) & 0x1;
out[index + 10 * groupSz] = (r3 >> 16) & 0x1;
out[index + 11 * groupSz] = (r3 >> 24) & 0x1;
out[index + 12 * groupSz] = (r4)&0x1;
out[index + 13 * groupSz] = (r4 >> 1) & 0x1;
out[index + 14 * groupSz] = (r4 >> 2) & 0x1;
out[index + 15 * groupSz] = (r4 >> 3) & 0x1;
out[index + 13 * groupSz] = (r4 >> 8) & 0x1;
out[index + 14 * groupSz] = (r4 >> 16) & 0x1;
out[index + 15 * groupSz] = (r4 >> 24) & 0x1;
}

static void writeOut128Bytes(short *out, const uint &index, const uint groupSz,
Expand Down Expand Up @@ -513,44 +513,44 @@ static void partialWriteOut128Bytes(char *out, const uint &index,
const uint &r2, const uint &r3,
const uint &r4, const uint &elements) {
if (index < elements) { out[index] = (r1)&0x1; }
if (index + groupSz < elements) { out[index + groupSz] = (r1 >> 1) & 0x1; }
if (index + groupSz < elements) { out[index + groupSz] = (r1 >> 8) & 0x1; }
if (index + 2 * groupSz < elements) {
out[index + 2 * groupSz] = (r1 >> 2) & 0x1;
out[index + 2 * groupSz] = (r1 >> 16) & 0x1;
}
if (index + 3 * groupSz < elements) {
out[index + 3 * groupSz] = (r1 >> 3) & 0x1;
out[index + 3 * groupSz] = (r1 >> 24) & 0x1;
}
if (index + 4 * groupSz < elements) { out[index + 4 * groupSz] = (r2)&0x1; }
if (index + 5 * groupSz < elements) {
out[index + 5 * groupSz] = (r2 >> 1) & 0x1;
out[index + 5 * groupSz] = (r2 >> 8) & 0x1;
}
if (index + 6 * groupSz < elements) {
out[index + 6 * groupSz] = (r2 >> 2) & 0x1;
out[index + 6 * groupSz] = (r2 >> 16) & 0x1;
}
if (index + 7 * groupSz < elements) {
out[index + 7 * groupSz] = (r2 >> 3) & 0x1;
out[index + 7 * groupSz] = (r2 >> 24) & 0x1;
}
if (index + 8 * groupSz < elements) { out[index + 8 * groupSz] = (r3)&0x1; }
if (index + 9 * groupSz < elements) {
out[index + 9 * groupSz] = (r3 >> 1) & 0x1;
out[index + 9 * groupSz] = (r3 >> 8) & 0x1;
}
if (index + 10 * groupSz < elements) {
out[index + 10 * groupSz] = (r3 >> 2) & 0x1;
out[index + 10 * groupSz] = (r3 >> 16) & 0x1;
}
if (index + 11 * groupSz < elements) {
out[index + 11 * groupSz] = (r3 >> 3) & 0x1;
out[index + 11 * groupSz] = (r3 >> 24) & 0x1;
}
if (index + 12 * groupSz < elements) {
out[index + 12 * groupSz] = (r4)&0x1;
}
if (index + 13 * groupSz < elements) {
out[index + 13 * groupSz] = (r4 >> 1) & 0x1;
out[index + 13 * groupSz] = (r4 >> 8) & 0x1;
}
if (index + 14 * groupSz < elements) {
out[index + 14 * groupSz] = (r4 >> 2) & 0x1;
out[index + 14 * groupSz] = (r4 >> 16) & 0x1;
}
if (index + 15 * groupSz < elements) {
out[index + 15 * groupSz] = (r4 >> 3) & 0x1;
out[index + 15 * groupSz] = (r4 >> 24) & 0x1;
}
}

Expand Down
48 changes: 24 additions & 24 deletions src/backend/opencl/kernel/random_engine_write.cl
Original file line number Diff line number Diff line change
Expand Up @@ -50,21 +50,21 @@ void writeOut128Bytes_uchar(global uchar *out, uint index, uint r1, uint r2,
void writeOut128Bytes_char(global char *out, uint index, uint r1, uint r2,
uint r3, uint r4) {
out[index] = (r1)&0x1;
out[index + THREADS] = (r1 >> 1) & 0x1;
out[index + 2 * THREADS] = (r1 >> 2) & 0x1;
out[index + 3 * THREADS] = (r1 >> 3) & 0x1;
out[index + THREADS] = (r1 >> 8) & 0x1;
out[index + 2 * THREADS] = (r1 >> 16) & 0x1;
out[index + 3 * THREADS] = (r1 >> 24) & 0x1;
out[index + 4 * THREADS] = (r2)&0x1;
out[index + 5 * THREADS] = (r2 >> 1) & 0x1;
out[index + 6 * THREADS] = (r2 >> 2) & 0x1;
out[index + 7 * THREADS] = (r2 >> 3) & 0x1;
out[index + 5 * THREADS] = (r2 >> 8) & 0x1;
out[index + 6 * THREADS] = (r2 >> 16) & 0x1;
out[index + 7 * THREADS] = (r2 >> 24) & 0x1;
out[index + 8 * THREADS] = (r3)&0x1;
out[index + 9 * THREADS] = (r3 >> 1) & 0x1;
out[index + 10 * THREADS] = (r3 >> 2) & 0x1;
out[index + 11 * THREADS] = (r3 >> 3) & 0x1;
out[index + 9 * THREADS] = (r3 >> 8) & 0x1;
out[index + 10 * THREADS] = (r3 >> 16) & 0x1;
out[index + 11 * THREADS] = (r3 >> 24) & 0x1;
out[index + 12 * THREADS] = (r4)&0x1;
out[index + 13 * THREADS] = (r4 >> 1) & 0x1;
out[index + 14 * THREADS] = (r4 >> 2) & 0x1;
out[index + 15 * THREADS] = (r4 >> 3) & 0x1;
out[index + 13 * THREADS] = (r4 >> 8) & 0x1;
out[index + 14 * THREADS] = (r4 >> 16) & 0x1;
out[index + 15 * THREADS] = (r4 >> 24) & 0x1;
}

void writeOut128Bytes_short(global short *out, uint index, uint r1, uint r2,
Expand Down Expand Up @@ -187,44 +187,44 @@ void partialWriteOut128Bytes_uchar(global uchar *out, uint index, uint r1,
void partialWriteOut128Bytes_char(global char *out, uint index, uint r1,
uint r2, uint r3, uint r4, uint elements) {
if (index < elements) { out[index] = (r1)&0x1; }
if (index + THREADS < elements) { out[index + THREADS] = (r1 >> 1) & 0x1; }
if (index + THREADS < elements) { out[index + THREADS] = (r1 >> 8) & 0x1; }
if (index + 2 * THREADS < elements) {
out[index + 2 * THREADS] = (r1 >> 2) & 0x1;
out[index + 2 * THREADS] = (r1 >> 16) & 0x1;
}
if (index + 3 * THREADS < elements) {
out[index + 3 * THREADS] = (r1 >> 3) & 0x1;
out[index + 3 * THREADS] = (r1 >> 24) & 0x1;
}
if (index + 4 * THREADS < elements) { out[index + 4 * THREADS] = (r2)&0x1; }
if (index + 5 * THREADS < elements) {
out[index + 5 * THREADS] = (r2 >> 1) & 0x1;
out[index + 5 * THREADS] = (r2 >> 8) & 0x1;
}
if (index + 6 * THREADS < elements) {
out[index + 6 * THREADS] = (r2 >> 2) & 0x1;
out[index + 6 * THREADS] = (r2 >> 16) & 0x1;
}
if (index + 7 * THREADS < elements) {
out[index + 7 * THREADS] = (r2 >> 3) & 0x1;
out[index + 7 * THREADS] = (r2 >> 24) & 0x1;
}
if (index + 8 * THREADS < elements) { out[index + 8 * THREADS] = (r3)&0x1; }
if (index + 9 * THREADS < elements) {
out[index + 9 * THREADS] = (r3 >> 1) & 0x1;
out[index + 9 * THREADS] = (r3 >> 8) & 0x1;
}
if (index + 10 * THREADS < elements) {
out[index + 10 * THREADS] = (r3 >> 2) & 0x1;
out[index + 10 * THREADS] = (r3 >> 16) & 0x1;
}
if (index + 11 * THREADS < elements) {
out[index + 11 * THREADS] = (r3 >> 3) & 0x1;
out[index + 11 * THREADS] = (r3 >> 24) & 0x1;
}
if (index + 12 * THREADS < elements) {
out[index + 12 * THREADS] = (r4)&0x1;
}
if (index + 13 * THREADS < elements) {
out[index + 13 * THREADS] = (r4 >> 1) & 0x1;
out[index + 13 * THREADS] = (r4 >> 8) & 0x1;
}
if (index + 14 * THREADS < elements) {
out[index + 14 * THREADS] = (r4 >> 2) & 0x1;
out[index + 14 * THREADS] = (r4 >> 16) & 0x1;
}
if (index + 15 * THREADS < elements) {
out[index + 15 * THREADS] = (r4 >> 3) & 0x1;
out[index + 15 * THREADS] = (r4 >> 24) & 0x1;
}
}

Expand Down
Loading