1 change: 0 additions & 1 deletion aten/src/ATen/Declarations.cwrap
@@ -2720,7 +2720,6 @@
    - floating_point
  backends:
    - CPU
-    - CUDA
  cname: exponential
  variants: function
  return: self
50 changes: 50 additions & 0 deletions aten/src/ATen/native/cuda/Distributions.cu
@@ -546,6 +546,50 @@ void cauchy_kernel_cuda(TensorIterator& iter, double median_, double sigma_, Gen
  });
}

+void exponential_kernel_cuda(TensorIterator& iter, double lambda_, Generator* gen_) {
+  auto gen = check_generator<CUDAGenerator>(gen_, &globalContext().defaultGenerator(kCUDA));
+  // Note that HIP doesn't support std::nextafter in device code.
+  auto nextafter_1_0_float = std::nextafter(1.0f, 0.0f);
+  auto nextafter_1_0_double = std::nextafter(1.0, 0.0);
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "exponential_cuda", [&] {
+    using accscalar_t = at::acc_type<scalar_t, true>;
+    auto lambda = static_cast<accscalar_t>(lambda_);
+    if (std::is_same<scalar_t, double>::value) {
+      // define lambda for exponential transformation
+      auto exponential_func = [lambda, nextafter_1_0_double] __device__ (accscalar_t rand) {
+        accscalar_t sample;
+        // curand_uniform has (0,1] bounds. log(1) is 0 and exponential excludes 0.
+        // Hence, squash the 1 to just below 1.
+        if(rand == static_cast<accscalar_t>(1.0)) {
+          sample = ::log(nextafter_1_0_double);
+        } else {
+          sample = ::log(rand);
+        }
+        return static_cast<scalar_t>(static_cast<accscalar_t>(-1.0) / lambda * sample);
+      };
+      distribution_nullary_kernel<scalar_t, accscalar_t, curand4_engine_calls/2>(iter,
+        gen,
+        [] __device__ (curandStatePhilox4_32_10_t* state) { return curand_uniform2_double(state); },
+        exponential_func);
+    } else {
+      // use __logf fast approximation for peak bandwidth
+      auto exponential_func = [lambda, nextafter_1_0_float] __device__ (accscalar_t rand) {
+        accscalar_t sample;
+        if(rand == static_cast<accscalar_t>(1.0)) {
+          sample = __logf(nextafter_1_0_float);
+        } else {
+          sample = __logf(rand);
+        }
+        return static_cast<scalar_t>(static_cast<accscalar_t>(-1.0) / lambda * sample);
+      };
+      distribution_nullary_kernel<scalar_t, accscalar_t, curand4_engine_calls>(iter,
+        gen,
+        [] __device__ (curandStatePhilox4_32_10_t* state) { return curand_uniform4(state); },
+        exponential_func);
+    }
+  });
+}
+
Tensor& uniform_cuda_(Tensor& self, double from, double to, Generator* gen) {
  auto iter = TensorIterator::nullary_op(self);
  uniform_kernel_cuda(*iter, from, to, gen);
@@ -631,4 +675,10 @@ Tensor& cauchy_cuda_(Tensor& self, double median, double sigma, Generator* gen)
  return self;
}

+Tensor& exponential_cuda_(Tensor& self, double lambda, Generator* gen) {
+  auto iter = TensorIterator::nullary_op(self);
+  exponential_kernel_cuda(*iter, lambda, gen);
+  return self;
+}
+
}} // namespace at::native
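Both branches of `exponential_kernel_cuda` implement standard inverse-CDF sampling; as a quick check of the `-1.0 / lambda * sample` transform above:

$$
U \sim \mathrm{Uniform}(0,1] \quad\Longrightarrow\quad X = -\frac{\ln U}{\lambda}, \qquad P(X \le x) = P\!\left(U \ge e^{-\lambda x}\right) = 1 - e^{-\lambda x},
$$

which is exactly the CDF of $\mathrm{Exp}(\lambda)$. Because curand's uniform generators return values in $(0,1]$, $U = 1$ would give $\ln U = 0$ and a sample of exactly $0$, which the exponential distribution excludes; hence the kernel squashes $1$ down to `nextafter(1, 0)` before taking the log.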
2 changes: 1 addition & 1 deletion aten/src/ATen/native/native_functions.yaml
@@ -3227,7 +3227,7 @@
  variants: method
  dispatch:
    CPU: legacy::cpu::_th_exponential_
-    CUDA: legacy::cuda::_th_exponential_
+    CUDA: exponential_cuda_

- func: geometric_(Tensor(a!) self, float p, *, Generator? generator=None) -> Tensor(a!)
  variants: method
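With the dispatch entry switched from `legacy::cuda::_th_exponential_` to `exponential_cuda_`, in-place `exponential_()` on CUDA tensors takes the new ATen path. A minimal smoke test through the ATen C++ API, sketched under the assumption of a CUDA-enabled build (the tensor size and rate parameter here are arbitrary):

```cpp
#include <ATen/ATen.h>
#include <iostream>

int main() {
  // Fill a CUDA tensor in-place with Exp(lambda = 2.0) samples; this now
  // dispatches to exponential_cuda_ rather than the THC binding.
  at::Tensor t = at::empty({1 << 20}, at::device(at::kCUDA).dtype(at::kFloat));
  t.exponential_(2.0);

  // The mean of Exp(lambda) is 1 / lambda, so this should print roughly 0.5.
  std::cout << "mean: " << t.mean().item<float>() << std::endl;
  return 0;
}
```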
5 changes: 0 additions & 5 deletions aten/src/THC/THCTensorRandom.cu
@@ -129,11 +129,6 @@ __global__ void NAME(curandStateMtgp32 *state, int size, T *result, ARG1, ARG2)
  } \
}

-GENERATE_KERNEL1(generate_exponential, float, double lambda, float, curand_uniform, (float)(-1. / lambda * log(x)))
-GENERATE_KERNEL1(generate_exponential, double, double lambda, double, curand_uniform_double, (double)(-1. / lambda * log(x)))
-
-GENERATE_KERNEL1(generate_exponential, at::Half, double lambda, float, curand_uniform, (ScalarConvert<float, at::Half>::to((float)(-1. / lambda * log(x)))))
-
#include <THC/generic/THCTensorRandom.cu>
#include <THC/THCGenerateAllTypes.h>

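The MTGP32-based kernels removed here applied the same `-1. / lambda * log(x)` transform that the new Philox path uses, but without the `nextafter` squash for `x == 1`. A host-side sketch of the new transform, handy for sanity-checking either path; the function name and structure are illustrative, not part of the codebase:

```cpp
#include <cmath>

// Hypothetical CPU reference for the transform used by the CUDA kernels:
// maps a uniform sample u in (0, 1] to an Exp(lambda) sample.
double exponential_reference(double u, double lambda) {
  // curand_uniform* returns values in (0, 1]; log(1) == 0 would yield a
  // sample of exactly 0, which the exponential distribution excludes,
  // so squash u == 1 to the next representable double below 1.
  if (u == 1.0) {
    u = std::nextafter(1.0, 0.0);
  }
  return -1.0 / lambda * std::log(u);
}
```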
16 changes: 0 additions & 16 deletions aten/src/THC/generic/THCTensorRandom.cu
@@ -26,22 +26,6 @@ void THCTensor_(logNormal)(THCState* state, THCTensor *self_, double mean, doubl
  THCTensor_(freeCopyTo)(state, self, self_);
};

-void THCTensor_(exponential)(THCState* state, THCTensor *self_, double lambda)
-{
-  THCAssertSameGPU(THCTensor_(checkGPU)(state, 1, self_));
-  ptrdiff_t size = THCTensor_(nElement)(state, self_);
-  if (size == 0) return;
-  THCGenerator* gen = THCRandom_getGenerator(state);
-
-  THCTensor *self = THCTensor_(newContiguous)(state, self_);
-  scalar_t *data = THCTensor_(data)(state, self);
-
-  generate_exponential<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
-      gen->state.gen_states, size, data, lambda);
-
-  THCTensor_(freeCopyTo)(state, self, self_);
-};
-
void THCTensor_(renormRows)(struct THCState* state,
                            THCTensor* t) {
  THAssert(THCTensor_(nDimensionLegacyAll)(state, t) == 2);
1 change: 0 additions & 1 deletion aten/src/THC/generic/THCTensorRandom.h
@@ -5,7 +5,6 @@
#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_HALF)

THC_API void THCTensor_(logNormal)(struct THCState *state, THCTensor *self, double mean, double stdv);
-THC_API void THCTensor_(exponential)(struct THCState *state, THCTensor *self, double lambda);
THC_API void THCTensor_(multinomial)(struct THCState *state, THCudaLongTensor *self, THCTensor *prob_dist, int n_sample, int with_replacement);
THC_API void THCTensor_(multinomialAliasSetup)(struct THCState *state, THCTensor *probs, THCudaLongTensor *J, THCTensor *q);
THC_API void THCTensor_(multinomialAliasDraw)(THCState *state, THCudaLongTensor *self, THCTensor *_q, THCudaLongTensor *_J, int n_sample);