18 changes: 18 additions & 0 deletions aten/src/ATen/test/cuda_distributions_test.cu
@@ -125,3 +125,21 @@ TEST(DistributionsTest, TestPhiloxIncrementBigUniformTensor) {
// expected uniforms will start from counter offset of 8
assert_with_expected_uniforms(8);
}

TEST(DistributionsTest, TestPhiloxIncrementSmallMultinomialTensor) {
// Test Description:
// Same concept as TestPhiloxIncrementSmallUniformTensor.
// Multinomial increments the Philox offset by 4; this tests whether a subsequent
// uniform draw starts from that offset.

// if cuda not available, return
if (!at::cuda::is_available()) return;

// manual seed to 123
at::manual_seed(123);

// get some multinomial samples
at::empty({10}, at::TensorOptions(at::kCUDA)).multinomial(4);

// expected uniforms will start from counter offset of 4
assert_with_expected_uniforms(4);
}
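
The test above leans on Philox being a counter-based generator: for a fixed seed and subsequence, the value produced at a given counter offset is fully determined, so reserving 4 counters for the multinomial call means the next uniform kernel should pick up the stream at offset 4. A standalone sketch of that property follows (illustrative kernel, not part of this PR; it assumes cuRAND counts the offset in single 32-bit outputs, so an offset of 4 corresponds to one curand_uniform4 block):

#include <curand_kernel.h>
#include <cstdint>

// Draws two uniform4 blocks starting at offset 0, then re-initializes a second
// state directly at offset 4; the second state's first block should match the
// first state's second block.
__global__ void philox_offset_demo(uint64_t seed, float* out) {
  if (blockIdx.x != 0 || threadIdx.x != 0) return;
  curandStatePhilox4_32_10_t a, b;
  curand_init(seed, /*subsequence=*/0, /*offset=*/0, &a);
  float4 first  = curand_uniform4(&a);   // consumes counters [0, 4)
  float4 second = curand_uniform4(&a);   // consumes counters [4, 8)
  curand_init(seed, /*subsequence=*/0, /*offset=*/4, &b);
  float4 jumped = curand_uniform4(&b);   // starts at counter 4
  out[0] = second.x;
  out[1] = jumped.x;                     // expected to equal out[0]
  (void)first;
}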
9 changes: 0 additions & 9 deletions aten/src/THC/THCTensorRandom.cpp
@@ -4,8 +4,6 @@
#include <random>
#include <curand.h>


void initializeGenerator(THCState *state, THCGenerator* gen);
void createGeneratorState(THCGenerator* gen, uint64_t seed);


@@ -80,19 +78,12 @@ THCGenerator* THCRandom_getGenerator(THCState* state)
std::lock_guard<std::mutex> lock(gen->mutex);
if (gen->state.initf == 0)
{
initializeGenerator(state, gen);
createGeneratorState(gen, gen->state.initial_seed);
gen->state.initf = 1;
}
return gen;
}

curandStateMtgp32* THCRandom_generatorStates(THCState* state)
{
THCGenerator* gen = THCRandom_getGenerator(state);
return gen->state.gen_states;
}

/* Random seed */
uint64_t THCRandom_seed(THCState* state)
{
40 changes: 5 additions & 35 deletions aten/src/THC/THCTensorRandom.cu
@@ -18,25 +18,9 @@

THCGenerator* THCRandom_getGenerator(THCState* state);

/* Sets up generator. Allocates but does not create the generator states. Not thread-safe. */
__host__ void initializeGenerator(THCState *state, THCGenerator* gen)
{
gen->state.gen_states = static_cast<curandStateMtgp32*>(THCudaMalloc(state, MAX_NUM_BLOCKS * sizeof(curandStateMtgp32)));
gen->state.kernel_params = static_cast<mtgp32_kernel_params*>(THCudaMalloc(state, sizeof(mtgp32_kernel_params)));
}

/* Creates a new generator state given the seed. Not thread-safe. */
__host__ void createGeneratorState(THCGenerator* gen, uint64_t seed)
{
if (curandMakeMTGP32Constants(mtgp32dc_params_fast_11213, gen->state.kernel_params) != CURAND_STATUS_SUCCESS)
{
THError("Creating MTGP constants failed.");
}
if (curandMakeMTGP32KernelState(gen->state.gen_states, mtgp32dc_params_fast_11213,
gen->state.kernel_params, MAX_NUM_BLOCKS, seed) != CURAND_STATUS_SUCCESS)
{
THError("Creating MTGP kernel state failed.");
}
// seed and offset for philox
gen->state.initial_seed = seed;
gen->state.philox_seed_offset = 0;
@@ -47,35 +31,26 @@ THC_API __host__ void THCRandom_getRNGState(THCState* state, THByteTensor *rng_s
THCGenerator* gen = THCRandom_getGenerator(state);
std::lock_guard<std::mutex> lock(gen->mutex);

// The RNG state comprises the MTPG32 states, the seed, and an offset used for Philox
static const size_t states_size = MAX_NUM_BLOCKS * sizeof(curandStateMtgp32);
// The RNG state comprises the seed and an offset used for Philox
static const size_t states_size = MAX_NUM_BLOCKS * sizeof(curandStateMtgp32); // kept only for backward compatibility of the state size
static const size_t seed_size = sizeof(gen->state.initial_seed);
static const size_t offset_size = sizeof(gen->state.philox_seed_offset);
static const size_t total_size = states_size + seed_size + offset_size;
THByteTensor_resize1d(rng_state, total_size);
THArgCheck(THByteTensor_nElement(rng_state) == total_size, 1, "RNG state is wrong size");
THArgCheck(THByteTensor_isContiguous(rng_state), 1, "RNG state must be contiguous");
THCudaCheck(cudaMemcpy(THByteTensor_data(rng_state), gen->state.gen_states,
states_size, cudaMemcpyDeviceToHost));
// curandStateMtgp32 is no longer used, so fill the legacy gen_states region of the serialized RNG state with a deterministic garbage value of -1
memset(THByteTensor_data(rng_state), -1, states_size);
memcpy(THByteTensor_data(rng_state) + states_size, &gen->state.initial_seed, seed_size);
memcpy(THByteTensor_data(rng_state) + states_size + seed_size, &gen->state.philox_seed_offset, offset_size);
}

__global__ void set_rngstate_kernel(curandStateMtgp32 *state, mtgp32_kernel_params *kernel)
{
#ifndef __HIP_PLATFORM_HCC__
state[threadIdx.x].k = kernel;
#else
state[threadIdx.x].set_params(kernel);
#endif
}

THC_API __host__ void THCRandom_setRNGState(THCState* state, THByteTensor *rng_state)
{
THCGenerator* gen = THCRandom_getGenerator(state);
std::lock_guard<std::mutex> lock(gen->mutex);

static const size_t states_size = MAX_NUM_BLOCKS * sizeof(curandStateMtgp32);
static const size_t states_size = MAX_NUM_BLOCKS * sizeof(curandStateMtgp32); // kept only for backward compatibility of the state size
static const size_t seed_size = sizeof(gen->state.initial_seed);
static const size_t offset_size = sizeof(gen->state.philox_seed_offset);
static const size_t total_size = states_size + seed_size + offset_size;
@@ -87,11 +62,6 @@ THC_API __host__ void THCRandom_setRNGState(THCState* state, THByteTensor *rng_s
THArgCheck(THByteTensor_nElement(rng_state) == total_size, 1, "RNG state is wrong size");
}
THArgCheck(THByteTensor_isContiguous(rng_state), 1, "RNG state must be contiguous");

THCudaCheck(cudaMemcpy(gen->state.gen_states, THByteTensor_data(rng_state),
states_size, cudaMemcpyHostToDevice));
set_rngstate_kernel<<<1, MAX_NUM_BLOCKS, 0, THCState_getCurrentStream(state)>>>(
gen->state.gen_states, gen->state.kernel_params);
memcpy(&gen->state.initial_seed, THByteTensor_data(rng_state) + states_size, seed_size);
if (!no_philox_seed) {
memcpy(&gen->state.philox_seed_offset, THByteTensor_data(rng_state) + states_size + seed_size, offset_size);
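
With the MTGP32 states gone, only the seed and the Philox offset carry information in the serialized CUDA RNG state; the old states region is retained (and filled with -1) purely so existing checkpoints keep the same size. A host-side sketch of the resulting layout and how it could be unpacked (types and helper names are illustrative, not from the PR):

#include <cstdint>
#include <cstring>
#include <vector>

// Serialized layout written by THCRandom_getRNGState:
//   [ MAX_NUM_BLOCKS * sizeof(curandStateMtgp32) bytes ]  legacy region, all 0xFF
//   [ 8 bytes ]                                           initial_seed
//   [ 8 bytes ]                                           philox_seed_offset
struct UnpackedCudaRngState {
  uint64_t seed;
  int64_t philox_offset;
};

UnpackedCudaRngState unpack_rng_state(const std::vector<uint8_t>& blob,
                                      size_t legacy_states_size) {
  UnpackedCudaRngState s;
  std::memcpy(&s.seed, blob.data() + legacy_states_size, sizeof(s.seed));
  std::memcpy(&s.philox_offset,
              blob.data() + legacy_states_size + sizeof(s.seed),
              sizeof(s.philox_offset));
  return s;
}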
26 changes: 16 additions & 10 deletions aten/src/THC/THCTensorRandom.cuh
@@ -273,7 +273,7 @@ sampleMultinomialOnce(int64_t* dest,

template <typename T>
__global__ void
sampleMultinomialWithReplacement(curandStateMtgp32* state,
sampleMultinomialWithReplacement(std::pair<uint64_t, uint64_t> seeds,
int totalSamples,
int64_t* dest,
int64_t distributions,
@@ -282,9 +282,11 @@ sampleMultinomialWithReplacement(curandStateMtgp32* state,
T* normDist) {
// At the moment, each warp computes one sample value in the binary
// search due to divergence. It seems possible to compute multiple
// values and limit divergence though later on. However, no matter
// what, all block threads must participate in the curand_uniform
// call to update the generator state.
// values and limit divergence though later on.

int idx = blockIdx.x * blockDim.x * blockDim.y + threadIdx.x;
curandStatePhilox4_32_10_t state;
curand_init(seeds.first, idx, seeds.second, &state);

// The block determines the distribution for which we generate a point
for (int64_t curDist = blockIdx.x;
@@ -296,7 +298,8 @@ sampleMultinomialWithReplacement(curandStateMtgp32* state,
int sample = sampleBase + threadIdx.y;

// All threads participate in this
T r = ScalarConvert<float, T>::to(curand_uniform(&state[blockIdx.x]));
auto rand = curand_uniform4(&state);
T r = ScalarConvert<float, T>::to(rand.x);

if (threadIdx.x == 0 && sample < totalSamples) {
// Find the bucket that a uniform sample lies in
@@ -315,7 +318,7 @@ sampleMultinomialWithReplacement(curandStateMtgp32* state,

template <typename T>
__global__ void
sampleMultinomialWithoutReplacement(curandStateMtgp32* state,
sampleMultinomialWithoutReplacement(std::pair<uint64_t, uint64_t> seeds,
int totalSamples,
int sample,
int64_t* dest,
@@ -325,9 +328,11 @@ sampleMultinomialWithoutReplacement(curandStateMtgp32* state,
T* normDistPrefixSum) {
// At the moment, each warp computes one sample value in the binary
// search due to divergence. It seems possible to compute multiple
// values and limit divergence though later on. However, no matter
// what, all block threads must participate in the curand_uniform
// call to update the generator state.
// values and limit divergence though later on.

int idx = blockIdx.x * blockDim.x * blockDim.y + threadIdx.x;
curandStatePhilox4_32_10_t state;
curand_init(seeds.first, idx, seeds.second, &state);

// The block and warp determines the distribution for which we
// generate a point
@@ -338,7 +343,8 @@ sampleMultinomialWithoutReplacement(curandStateMtgp32* state,
int64_t curDist = curDistBase + threadIdx.y;

// All threads must participate in this
T r = ScalarConvert<float, T>::to(curand_uniform(&state[blockIdx.x]));
auto rand = curand_uniform4(&state);
T r = ScalarConvert<float, T>::to(rand.x);

if (threadIdx.x == 0 && curDist < distributions) {
// Find the bucket that a uniform sample lies in
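
The kernel-side change in this header follows one pattern: instead of all block threads sharing a per-block curandStateMtgp32 (which required every thread to participate in each curand_uniform call), each thread now builds its own Philox state from the seed, its thread index as the subsequence, and the offset reserved on the host. A minimal sketch of that pattern, assuming the same (seed, offset) pair the host passes in (illustrative kernel, not part of the PR):

#include <curand_kernel.h>
#include <utility>
#include <cstdint>

__global__ void per_thread_philox_demo(std::pair<uint64_t, uint64_t> seeds,
                                       float* out, int n) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  // Each thread gets its own subsequence of the same seed, starting at the
  // reserved counter offset, so threads never share or update a common state.
  curandStatePhilox4_32_10_t state;
  curand_init(seeds.first, idx, seeds.second, &state);
  if (idx < n) {
    float4 r = curand_uniform4(&state);  // 4 uniforms per call; only .x used here
    out[idx] = r.x;
  }
}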
2 changes: 0 additions & 2 deletions aten/src/THC/THCTensorRandom.h
@@ -32,6 +32,4 @@ THC_API uint64_t THCRandom_initialSeed(struct THCState *state);
THC_API void THCRandom_getRNGState(struct THCState *state, THByteTensor *rng_state);
THC_API void THCRandom_setRNGState(struct THCState *state, THByteTensor *rng_state);

THC_API curandStateMtgp32* THCRandom_generatorStates(struct THCState* state);

#endif
11 changes: 9 additions & 2 deletions aten/src/THC/generic/THCTensorRandom.cu
@@ -3,6 +3,7 @@
#else

#include "ATen/cuda/CUDAContext.h"
#include <utility>

#define NUM_BLOCKS min((int)THCCeilDiv(size, (ptrdiff_t) BLOCK_SIZE), MAX_NUM_BLOCKS)

@@ -130,6 +131,12 @@ void THCTensor_(multinomial)(struct THCState *state,
// Prefix sum along rows
THCTensor_(cumsum)(state, prefixSum, normDist, 1);

// each thread will utilize one random value; however, since we have to use
// curand_uniform4 (see Note [Register spilling in curand call for CUDA < 10]),
// the offset is incremented by 4.
uint64_t offset = gen->state.philox_seed_offset.fetch_add(4);
std::pair<uint64_t, uint64_t> next_philox_seed = std::make_pair(gen->state.initial_seed, offset);

if (with_replacement) {
// Sample with replacement

@@ -144,7 +151,7 @@

sampleMultinomialWithReplacement
<<<grid, block, 0, THCState_getCurrentStream(state)>>>(
gen->state.gen_states,
next_philox_seed,
n_sample,
THCudaLongTensor_data(state, self),
numDist, numCategories,
@@ -178,7 +185,7 @@ void THCTensor_(multinomial)(struct THCState *state,
// recalculate our distribution
sampleMultinomialWithoutReplacement
<<<grid, block, 0, THCState_getCurrentStream(state)>>>(
gen->state.gen_states,
next_philox_seed,
n_sample,
sample,
THCudaLongTensor_data(state, self),
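
On the host side, the pattern is to atomically reserve the Philox counter range a launch will consume and hand the (seed, offset) pair to the kernel; here multinomial reserves 4, matching one curand_uniform4 call. A sketch of that reservation step (the free-standing helper is illustrative; the atomic parameter mirrors gen->state.philox_seed_offset):

#include <atomic>
#include <cstdint>
#include <utility>

// Reserve `increment` Philox counters and return the (seed, offset) pair that the
// kernel passes to curand_init. fetch_add makes concurrent launches on the same
// generator claim disjoint counter ranges.
std::pair<uint64_t, uint64_t> reserve_philox(std::atomic<int64_t>& philox_seed_offset,
                                             uint64_t initial_seed,
                                             int64_t increment) {
  uint64_t offset = static_cast<uint64_t>(philox_seed_offset.fetch_add(increment));
  return std::make_pair(initial_seed, offset);
}

// Usage mirroring the diff:
//   auto next_philox_seed =
//       reserve_philox(gen->state.philox_seed_offset, gen->state.initial_seed, 4);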
26 changes: 17 additions & 9 deletions aten/src/THCUNN/RReLU.cu
@@ -5,39 +5,47 @@
#include <THCUNN/common.h>
#include <curand.h>
#include <curand_kernel.h>
#include <curand_philox4x32_x.h>
#include <utility>

// copied from cutorch/lib/THC/THCTensorRandom.cu
#define MAX_NUM_BLOCKS 64
#define BLOCK_SIZE 256
#define NUM_BLOCKS(n) min((int)THCCeilDiv(n, (ptrdiff_t) BLOCK_SIZE), MAX_NUM_BLOCKS)

template<typename T>
inline T __device__ curand_uniform_type(curandStateMtgp32 *state);
inline T __device__ curand_uniform_type(curandStatePhilox4_32_10_t *state);

template <>
inline THHalf __device__ curand_uniform_type<THHalf>(curandStateMtgp32 *state) {
return ScalarConvert<float, THHalf>::to(curand_uniform(state));
inline THHalf __device__ curand_uniform_type<THHalf>(curandStatePhilox4_32_10_t *state) {
auto rand = curand_uniform4(state);
return ScalarConvert<float, THHalf>::to(rand.x);
}

template <>
inline float __device__ curand_uniform_type<float>(curandStateMtgp32 *state) {
return curand_uniform(state);
inline float __device__ curand_uniform_type<float>(curandStatePhilox4_32_10_t *state) {
auto rand = curand_uniform4(state);
return rand.x;
}

template <>
inline double __device__ curand_uniform_type<double>(curandStateMtgp32 *state) {
return curand_uniform_double(state);
inline double __device__ curand_uniform_type<double>(curandStatePhilox4_32_10_t *state) {
auto rand = curand_uniform2_double(state);
return rand.x;
}

template <typename T>
__global__ void rreluUpdateOutputTrain(int n, curandStateMtgp32 *state,
__global__ void rreluUpdateOutputTrain(int n, std::pair<uint64_t, uint64_t> seeds,
T *input, T* noise, T *output, double a, double b)
{
int idx = blockIdx.x * blockDim.x + threadIdx.x;
curandStatePhilox4_32_10_t state;
curand_init(seeds.first, idx, seeds.second, &state);
CUDA_KERNEL_LOOP(i, n)
{
if (input[i] <= 0)
{
T r = curand_uniform_type<T>(&state[blockIdx.x]);
T r = curand_uniform_type<T>(&state);
r = ScalarConvert<double, T>::to(r * (b-a) + a);
output[i] = input[i] * r;
noise[i] = r;
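
For reference, curand_uniform4 returns a float4 and curand_uniform2_double returns a double2; the vectorized calls are preferred here (see Note [Register spilling in curand call for CUDA < 10]), which is why the specializations above take only the .x component. A small sketch of what the dispatch boils down to (illustrative helpers, not from this PR):

#include <curand_kernel.h>

__device__ inline float one_uniform_float(curandStatePhilox4_32_10_t* state) {
  return curand_uniform4(state).x;          // float4 -> take one lane
}

__device__ inline double one_uniform_double(curandStatePhilox4_32_10_t* state) {
  return curand_uniform2_double(state).x;   // double2 -> take one lane
}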
24 changes: 19 additions & 5 deletions aten/src/THCUNN/generic/RReLU.cu
@@ -3,6 +3,12 @@
#else

#include <THCUNN/common.h>
#include <THC/THCGeneral.h>
#include <THC/THCTensorRandom.h>
#include <THC/THCGenerator.hpp>
#include <utility>

THCGenerator* THCRandom_getGenerator(THCState* state);

void THNN_(RReLU_updateOutput)(
THCState *state,
@@ -16,7 +22,7 @@ void THNN_(RReLU_updateOutput)(
void *generator)
{
THCUNN_assertSameGPU(state, 3, input, output, noise);
curandStateMtgp32* gen_states = THCRandom_generatorStates(state);
THCGenerator* gen = THCRandom_getGenerator(state);

if (train)
{
@@ -25,18 +31,26 @@ void THNN_(RReLU_updateOutput)(
scalar_t *input_data = THCTensor_(data)(state, input);
scalar_t *noise_data = THCTensor_(data)(state, noise);
ptrdiff_t n = THCTensor_(nElement)(state, input);

// philox offset calculation for grid-stride loop utilizing curand4
const uint32_t curand4_engine_calls = 4;
dim3 grid = NUM_BLOCKS(n);
uint64_t counter_offset = ((n - 1) / (BLOCK_SIZE * grid.x) + 1) * curand4_engine_calls;
uint64_t offset = gen->state.philox_seed_offset.fetch_add(counter_offset);
std::pair<uint64_t, uint64_t> next_philox_seed = std::make_pair(gen->state.initial_seed, offset);

if (inplace)
{
rreluUpdateOutputTrain<<<NUM_BLOCKS(n), BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
n, gen_states, input_data, noise_data, input_data, lower, upper);
rreluUpdateOutputTrain<<<grid, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
n, next_philox_seed, input_data, noise_data, input_data, lower, upper);
THCTensor_(set)(state, output, input);
}
else
{
THCTensor_(resizeAs)(state, output, input);
scalar_t *output_data = THCTensor_(data)(state, output);
rreluUpdateOutputTrain<<<NUM_BLOCKS(n), BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
n, gen_states, input_data, noise_data, output_data, lower, upper);
rreluUpdateOutputTrain<<<grid, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
n, next_philox_seed, input_data, noise_data, output_data, lower, upper);
}
THCudaCheck(cudaGetLastError());
THCTensor_(free)(state, input);
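
Unlike the multinomial path, this kernel runs a grid-stride loop, so one thread may draw several uniforms; the reservation therefore scales with the number of loop trips per thread. A worked example of the counter_offset formula above, with illustrative sizes:

#include <cstdint>
#include <cstddef>

// counter_offset = (loop trips per thread) * (4 engine calls per curand_uniform4)
uint64_t rrelu_counter_offset(ptrdiff_t n, uint32_t block_size, uint32_t grid_x) {
  const uint64_t curand4_engine_calls = 4;
  return ((static_cast<uint64_t>(n) - 1) / (block_size * grid_x) + 1) * curand4_engine_calls;
}

// e.g. n = 1'000'000 elements, BLOCK_SIZE = 256, grid.x = 64 (MAX_NUM_BLOCKS):
//   elements per thread = ceil(1'000'000 / 16'384) = 62
//   counter_offset      = 62 * 4 = 248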
3 changes: 1 addition & 2 deletions torch/cuda/random.py
@@ -41,7 +41,6 @@ def set_rng_state(new_state, device=device('cuda')):
device (torch.device or int, optional): The device to set the RNG state.
Default: ``torch.device('cuda')`` (i.e., the current CUDA device).
"""
new_state_copy = new_state.clone()

# NB: What if device=-1? You might be afraid that the "current"
# device would change by the time we actually get around to invoking
@@ -51,7 +50,7 @@ def set_rng_state(new_state, device=device('cuda')):
# immediately.
def cb():
with device_ctx_manager(device):
_C._cuda_setRNGState(new_state_copy)
_C._cuda_setRNGState(new_state)

_lazy_call(cb)
