Skip to content

Commit d341bcb

Browse files
syed-ahmed and facebook-github-bot
authored and committed
Move THCTensor_(exponential) to ATen (#21297)
Summary: Pull Request resolved: #21297 ghimport-source-id: 5f45154 Reviewed By: jerryzh168 Differential Revision: D15632931 Pulled By: ezyang fbshipit-source-id: 0367eec0a9ef6812b1b3ab7597817ee40a011bb8
1 parent 92b76df commit d341bcb

File tree

6 files changed

+51
-24
lines changed

6 files changed

+51
-24
lines changed

aten/src/ATen/Declarations.cwrap

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2720,7 +2720,6 @@
27202720
- floating_point
27212721
backends:
27222722
- CPU
2723-
- CUDA
27242723
cname: exponential
27252724
variants: function
27262725
return: self

aten/src/ATen/native/cuda/Distributions.cu

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -546,6 +546,50 @@ void cauchy_kernel_cuda(TensorIterator& iter, double median_, double sigma_, Gen
546546
});
547547
}
548548

549+
// Fills the output driven by `iter` with exponentially distributed samples of
// rate `lambda_`, computed as -log(u) / lambda from a uniform draw u.
//
// iter    - TensorIterator whose dtype() selects the floating-point
//           specialization (double, float or Half).
// lambda_ - rate parameter; it is cast to the accumulate type before use.
// gen_    - generator passed to check_generator; the default CUDA generator
//           of the global context is supplied as the fallback argument.
//
// The uniform draws come from curand_uniform2_double (double) or
// curand_uniform4 (float/Half) and are handed, together with the transform
// lambda, to distribution_nullary_kernel.
void exponential_kernel_cuda(TensorIterator& iter, double lambda_, Generator* gen_) {
  auto gen = check_generator<CUDAGenerator>(gen_, &globalContext().defaultGenerator(kCUDA));
  // Note that HIP doesn't support std::nextafter in device code, so the
  // "just below 1" constants are computed here on the host and captured by
  // value in the device lambdas.
  auto nextafter_1_0_float = std::nextafter(1.0f, 0.0f);
  auto nextafter_1_0_double = std::nextafter(1.0, 0.0);
  // Accurate (host-side) log of the largest float below 1. It is strictly
  // negative and serves as the upper bound for the fast device log below.
  auto log_nextafter_1_0_float = std::log(nextafter_1_0_float);
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "exponential_cuda", [&] {
    using accscalar_t = at::acc_type<scalar_t, true>;
    auto lambda = static_cast<accscalar_t>(lambda_);
    if (std::is_same<scalar_t, double>::value) {
      // define lambda for exponential transformation
      auto exponential_func = [lambda, nextafter_1_0_double] __device__ (accscalar_t rand) {
        accscalar_t sample;
        // curand_uniform has (0,1] bounds. log(1) is 0 and exponential excludes 0.
        // Hence, squash the 1 to just below 1.
        if(rand == static_cast<accscalar_t>(1.0)) {
          sample = ::log(nextafter_1_0_double);
        } else {
          sample = ::log(rand);
        }
        return static_cast<scalar_t>(static_cast<accscalar_t>(-1.0) / lambda * sample);
      };
      // Each curand_uniform2_double call yields two values, hence half the
      // engine calls of the float path.
      distribution_nullary_kernel<scalar_t, accscalar_t, curand4_engine_calls/2>(iter,
        gen,
        [] __device__ (curandStatePhilox4_32_10_t* state) { return curand_uniform2_double(state); },
        exponential_func);
    } else {
      // use __logf fast approximation for peak bandwidth
      auto exponential_func = [lambda, log_nextafter_1_0_float] __device__ (accscalar_t rand) {
        // curand_uniform has (0,1] bounds. log(1) is 0 and exponential excludes 0.
        // __logf only guarantees a bounded *absolute* error near 1, so for
        // rand == 1 or rand just below 1 it may return 0 or a slightly
        // positive value. Clamp to the accurate host-computed bound so the
        // log is always strictly negative and the sample strictly positive
        // before the final cast.
        accscalar_t sample = __logf(rand);
        if(!(sample < static_cast<accscalar_t>(log_nextafter_1_0_float))) {
          sample = static_cast<accscalar_t>(log_nextafter_1_0_float);
        }
        // NOTE(review): for Half outputs the cast below may still round very
        // small samples to 0 when lambda is large - confirm whether a
        // dtype-dependent bound is wanted.
        return static_cast<scalar_t>(static_cast<accscalar_t>(-1.0) / lambda * sample);
      };
      distribution_nullary_kernel<scalar_t, accscalar_t, curand4_engine_calls>(iter,
        gen,
        [] __device__ (curandStatePhilox4_32_10_t* state) { return curand_uniform4(state); },
        exponential_func);
    }
  });
}
592+
549593
Tensor& uniform_cuda_(Tensor& self, double from, double to, Generator* gen) {
550594
auto iter = TensorIterator::nullary_op(self);
551595
uniform_kernel_cuda(*iter, from, to, gen);
@@ -631,4 +675,10 @@ Tensor& cauchy_cuda_(Tensor& self, double median, double sigma, Generator* gen)
631675
return self;
632676
}
633677

678+
// In-place CUDA entry point for Tensor::exponential_.
// Builds a nullary TensorIterator over `self`, runs exponential_kernel_cuda
// on it with the given rate `lambda` and generator `gen`, and returns `self`
// so calls can be chained.
Tensor& exponential_cuda_(Tensor& self, double lambda, Generator* gen) {
  // The iterator handle is only needed for the duration of the kernel call,
  // so it lives as a temporary of this full expression.
  exponential_kernel_cuda(*TensorIterator::nullary_op(self), lambda, gen);
  return self;
}
683+
634684
}} // namespace at::native

aten/src/ATen/native/native_functions.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3237,7 +3237,7 @@
32373237
variants: method
32383238
dispatch:
32393239
CPU: legacy::cpu::_th_exponential_
3240-
CUDA: legacy::cuda::_th_exponential_
3240+
CUDA: exponential_cuda_
32413241

32423242
- func: geometric_(Tensor(a!) self, float p, *, Generator? generator=None) -> Tensor(a!)
32433243
variants: method

aten/src/THC/THCTensorRandom.cu

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -129,11 +129,6 @@ __global__ void NAME(curandStateMtgp32 *state, int size, T *result, ARG1, ARG2)
129129
} \
130130
}
131131

132-
GENERATE_KERNEL1(generate_exponential, float, double lambda, float, curand_uniform, (float)(-1. / lambda * log(x)))
133-
GENERATE_KERNEL1(generate_exponential, double, double lambda, double, curand_uniform_double, (double)(-1. / lambda * log(x)))
134-
135-
GENERATE_KERNEL1(generate_exponential, at::Half, double lambda, float, curand_uniform, (ScalarConvert<float, at::Half>::to((float)(-1. / lambda * log(x)))))
136-
137132
#include <THC/generic/THCTensorRandom.cu>
138133
#include <THC/THCGenerateAllTypes.h>
139134

aten/src/THC/generic/THCTensorRandom.cu

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -26,22 +26,6 @@ void THCTensor_(logNormal)(THCState* state, THCTensor *self_, double mean, doubl
2626
THCTensor_(freeCopyTo)(state, self, self_);
2727
};
2828

29-
void THCTensor_(exponential)(THCState* state, THCTensor *self_, double lambda)
30-
{
31-
THCAssertSameGPU(THCTensor_(checkGPU)(state, 1, self_));
32-
ptrdiff_t size = THCTensor_(nElement)(state, self_);
33-
if (size == 0) return;
34-
THCGenerator* gen = THCRandom_getGenerator(state);
35-
36-
THCTensor *self = THCTensor_(newContiguous)(state, self_);
37-
scalar_t *data = THCTensor_(data)(state, self);
38-
39-
generate_exponential<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
40-
gen->state.gen_states, size, data, lambda);
41-
42-
THCTensor_(freeCopyTo)(state, self, self_);
43-
};
44-
4529
void THCTensor_(renormRows)(struct THCState* state,
4630
THCTensor* t) {
4731
THAssert(THCTensor_(nDimensionLegacyAll)(state, t) == 2);

aten/src/THC/generic/THCTensorRandom.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_HALF)
66

77
THC_API void THCTensor_(logNormal)(struct THCState *state, THCTensor *self, double mean, double stdv);
8-
THC_API void THCTensor_(exponential)(struct THCState *state, THCTensor *self, double lambda);
98
THC_API void THCTensor_(multinomial)(struct THCState *state, THCudaLongTensor *self, THCTensor *prob_dist, int n_sample, int with_replacement);
109
THC_API void THCTensor_(multinomialAliasSetup)(struct THCState *state, THCTensor *probs, THCudaLongTensor *J, THCTensor *q);
1110
THC_API void THCTensor_(multinomialAliasDraw)(THCState *state, THCudaLongTensor *self, THCTensor *_q, THCudaLongTensor *_J, int n_sample);

0 commit comments

Comments
 (0)