55#include < curand_kernel.h>
66#include < curand_philox4x32_x.h>
77#include < utility>
8+ #include < functional>
9+ #include < nvfunctional>
10+
11+ #include " ATen/native/Distributions.cuh"
812
913#include < TH/THAtomic.h>
1014
@@ -26,6 +30,26 @@ namespace dist {
2630 return std::make_pair (gen_->initial_seed , offset);
2731 }
2832
// Element-wise CUDA functor used to fill `ret` with values produced by
// sample_gamma, one per element of `alpha`.
//
// `ret` is visited as `scalar` and `alpha` as `float` through
// at::cuda::CUDA_tensor_apply2. `seeds.first` is passed to curand_init as the
// seed and `seeds.second` as the offset of the Philox generator.
template <typename scalar>
struct GammaOpCUDA {
  static void apply(Tensor& ret, const Tensor& alpha, std::pair<uint64_t, uint64_t> seeds) {
    at::cuda::CUDA_tensor_apply2<scalar, float>(ret, alpha,
      [seeds] __device__ (scalar& out, const float& alpha_val, bool early_exit) {
        // Each thread sets up its own Philox state; the flat thread index is
        // used as the subsequence argument of curand_init.
        const unsigned int thread_index = blockIdx.x * blockDim.x + threadIdx.x;
        curandStatePhilox4_32_10_t rng_state;
        curand_init(seeds.first, thread_index, seeds.second, &rng_state);

        // Both samplers draw from the same per-thread state.
        baseSampler<float> uniform_sampler([&rng_state] __device__ () {
          return curand_uniform(&rng_state);
        });
        baseSampler<float> normal_sampler([&rng_state] __device__ () {
          return curand_normal(&rng_state);
        });

        // The sample is computed in float and then converted to `scalar`.
        const auto drawn = scalar_cast<scalar>(
            sample_gamma<float>(alpha_val, uniform_sampler, normal_sampler));

        // The stored value never falls below THCNumerics<scalar>::min().
        out = ::max(THCNumerics<scalar>::min(), static_cast<scalar>(drawn));
      }
    );
  }
};
52+
2953 template <typename scalar>
3054 struct PoissonOpCUDA {
3155 static void apply (Tensor& ret, const Tensor& lambda, std::pair<uint64_t , uint64_t > seeds) {
@@ -48,5 +72,12 @@ Tensor _s_poisson_cuda(const Tensor& lambda, Generator* gen) {
4872 return ret;
4973}
5074
// Builds a tensor with the type and sizes of `alpha` and fills it by
// dispatching dist::GammaOpCUDA over the floating-point type of the result.
//
// `alpha` is converted to ScalarType::Float before dispatch, and the Philox
// seed pair is obtained from dist::next_philox_seed(gen).
Tensor _s_gamma_cuda(const Tensor& alpha, Generator* gen) {
  // The result mirrors alpha's type and shape.
  Tensor samples = alpha.type().tensor(alpha.sizes());

  // GammaOpCUDA reads alpha as float regardless of the output scalar type.
  auto alpha_as_float = alpha.toType(ScalarType::Float);

  auto philox_seeds = dist::next_philox_seed(gen);
  dispatch_floating_types<void, dist::GammaOpCUDA>(
      samples.type(), " gamma", samples, alpha_as_float, philox_seeds);
  return samples;
}
81+
5182} // at::native
5283} // at
0 commit comments