20 changes: 0 additions & 20 deletions aten/src/ATen/Declarations.cwrap
@@ -3968,26 +3968,6 @@
kwarg_only: True
- THTensor* self
]]
[[
name: _standard_gamma
types:
- floating_point
backends:
- CPU
return: argument 0
variants:
- method
- function
options:
- cname: standard_gamma
arguments:
- arg: THTensor* output
output: True
- arg: THGenerator* generator
default: nullptr
kwarg_only: True
- THTensor* self
]]
[[
name: tensor
return: THTensor*
27 changes: 27 additions & 0 deletions aten/src/ATen/native/Distributions.cpp
@@ -8,6 +8,8 @@
#include "ATen/CheckGenerator.h"
#include "ATen/Generator.h"

#include <ATen/native/Distributions.cuh>

#include "TH/THRandom.h"

namespace at {
@@ -155,6 +157,24 @@ namespace dist {
return gen_->generator;
}

template <typename scalar>
struct GammaOp {
static void apply(Tensor& ret, const Tensor& alpha, THGenerator *generator) {
CPU_tensor_apply2<scalar, double>(ret, alpha,
[generator](scalar& ret_val, const double& alpha){
dist::baseSampler<float> standard_uniform([generator] () {
return THRandom_standard_uniform(generator);
});
dist::baseSampler<float> standard_normal([generator] () {
return THRandom_normal(generator, 0.0, 1.0);
});
auto sample = dist::sample_gamma<float>(alpha, standard_uniform, standard_normal);
ret_val = std::max(std::numeric_limits<scalar>::min(), (scalar) sample);
}
);
}
};

template <typename scalar>
struct PoissonOp {
static int64_t sample_poisson(double lambda, THGenerator *generator) {
@@ -227,5 +247,12 @@ Tensor _s_poisson_cpu(const Tensor& lambda, Generator *gen) {
return ret;
}

Tensor _s_gamma_cpu(const Tensor& alpha, Generator *gen) {
Tensor ret = alpha.type().zeros(alpha.sizes());
auto alpha_ = alpha.toType(ScalarType::Double);
dispatch_floating_types<void, dist::GammaOp>(ret.type(), "gamma", ret, alpha_, dist::get_generator(gen));
return ret;
}

} // at::native
} // at
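A note on the clamp at the end of `GammaOp::apply`: a tiny gamma draw can round to exactly zero in low precision, and a zero sample has `log_prob == -inf` downstream, so the sample is clamped to the smallest positive value (the TH path deleted below did the same with `TH_REAL_MIN`). A minimal NumPy sketch of that guard:

```python
# Minimal sketch of the clamp in GammaOp::apply. A draw that underflows to
# 0.0 would give log_prob == -inf, so it is clamped to the smallest positive
# normal float (std::numeric_limits<float>::min() on the C++ side).
import numpy as np

sample = np.float32(0.0)          # a draw that underflowed to zero
tiny = np.finfo(np.float32).tiny  # ~1.1754944e-38, the float32 FLT_MIN
print(np.maximum(tiny, sample))   # 1.1754944e-38 instead of 0.0
```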
63 changes: 63 additions & 0 deletions aten/src/ATen/native/Distributions.cuh
@@ -0,0 +1,63 @@
#include "ATen/Config.h"
#include <functional>
#if AT_CUDA_ENABLED()
#include <nvfunctional>
#endif

namespace at {
namespace native {
namespace dist {

// this wraps sampling primitives to expose a common interface
template<typename precision_t>
struct baseSampler {
#if AT_CUDA_ENABLED()
nvstd::function<precision_t(void)> sampler;
__device__ baseSampler(nvstd::function<precision_t(void)> sampler): sampler(sampler) {}
__device__ precision_t sample() {
return sampler();
}
#else
std::function<precision_t(void)> sampler;
baseSampler(std::function<precision_t(void)> sampler): sampler(sampler) {}
precision_t sample() {
return sampler();
}
#endif
};

template<typename precision_t>
#if AT_CUDA_ENABLED()
__host__ __device__
#endif
precision_t sample_gamma(precision_t alpha, baseSampler<precision_t>& standard_uniform, baseSampler<precision_t>& standard_normal) {
precision_t scale = 1.0;

// Boost alpha for higher acceptance probability.
if (alpha < 1.0) {
scale *= ::pow(1 - standard_uniform.sample(), 1.0 / alpha);
alpha += 1.0;
}

// This implements the acceptance-rejection method of Marsaglia and Tsang (2000)
// doi:10.1145/358407.358414
const precision_t d = alpha - 1.0 / 3.0;
const precision_t c = 1.0 / ::sqrt(9.0 * d);
for (;;) {
precision_t x, y;
do {
x = standard_normal.sample();
y = 1.0 + c * x;
} while (y <= 0);
const precision_t v = y * y * y;
const precision_t u = 1 - standard_uniform.sample();
const precision_t xx = x * x;
if (u < 1.0 - 0.0331 * xx * xx)
return scale * d * v;
if (::log(u) < 0.5 * xx + d * (1.0 - v + ::log(v)))
return scale * d * v;
}
}
} // dist
} // native
} // at
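For reference, a stand-alone NumPy sketch of the Marsaglia and Tsang (2000) rejection sampler above; `sample_gamma_ref` is a hypothetical helper for sanity-checking the kernel against `scipy.stats.gamma`, not part of this change:

```python
# NumPy sketch mirroring sample_gamma above; sample_gamma_ref is hypothetical.
import numpy as np

def sample_gamma_ref(alpha, rng=np.random):
    scale = 1.0
    # Shape augmentation: draw from Gamma(alpha + 1) and correct with
    # U^(1/alpha) so that alpha < 1 keeps a high acceptance probability.
    if alpha < 1.0:
        scale *= (1.0 - rng.uniform()) ** (1.0 / alpha)
        alpha += 1.0
    d = alpha - 1.0 / 3.0
    c = 1.0 / np.sqrt(9.0 * d)
    while True:
        x = rng.normal()
        y = 1.0 + c * x
        if y <= 0:
            continue                      # redraw until y > 0
        v = y * y * y
        u = 1.0 - rng.uniform()
        if u < 1.0 - 0.0331 * x ** 4:     # cheap squeeze acceptance test
            return scale * d * v
        if np.log(u) < 0.5 * x * x + d * (1.0 - v + np.log(v)):
            return scale * d * v
```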
31 changes: 31 additions & 0 deletions aten/src/ATen/native/cuda/Distributions.cu
@@ -5,6 +5,10 @@
#include <curand_kernel.h>
#include <curand_philox4x32_x.h>
#include <utility>
#include <functional>
#include <nvfunctional>

#include "ATen/native/Distributions.cuh"

#include <TH/THAtomic.h>

@@ -26,6 +30,26 @@ namespace dist {
return std::make_pair(gen_->initial_seed, offset);
}

template <typename scalar>
struct GammaOpCUDA {
static void apply(Tensor& ret, const Tensor& alpha, std::pair<uint64_t, uint64_t> seeds) {
at::cuda::CUDA_tensor_apply2<scalar, float>(ret, alpha,
[seeds] __device__ (scalar& ret_val, const float& alpha, bool early_exit) {
curandStatePhilox4_32_10_t state;
curand_init(seeds.first, blockIdx.x * blockDim.x + threadIdx.x, seeds.second, &state);
baseSampler<float> standard_uniform([&state] __device__ () {
return curand_uniform(&state);
});
baseSampler<float> standard_normal([&state] __device__ () {
return curand_normal(&state);
});
auto sample = scalar_cast<scalar>(sample_gamma<float>(alpha, standard_uniform, standard_normal));
ret_val = ::max(THCNumerics<scalar>::min(), (scalar) sample);
}
);
}
};

template <typename scalar>
struct PoissonOpCUDA {
static void apply(Tensor& ret, const Tensor& lambda, std::pair<uint64_t, uint64_t> seeds) {
@@ -48,5 +72,12 @@ Tensor _s_poisson_cuda(const Tensor& lambda, Generator* gen) {
return ret;
}

Tensor _s_gamma_cuda(const Tensor& alpha, Generator* gen) {
Tensor ret = alpha.type().tensor(alpha.sizes());
auto alpha_ = alpha.toType(ScalarType::Float);
dispatch_floating_types<void, dist::GammaOpCUDA>(ret.type(), "gamma", ret, alpha_, dist::next_philox_seed(gen));
return ret;
}

} // at::native
} // at
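Each CUDA thread seeds its own Philox state with `(seed, subsequence = global thread index, offset)`, so threads draw from non-overlapping counter-based streams without communicating. A loose NumPy analogue of that idea (the per-thread keying below is illustrative only; cuRAND's subsequence mechanics differ in detail):

```python
# Loose illustration of counter-based per-thread streams, as in
# curand_init(seeds.first, thread_index, seeds.second, &state).
# The keying scheme here is hypothetical, not what cuRAND does bit-for-bit.
from numpy.random import Generator, Philox

seed = 12345  # plays the role of seeds.first
streams = [Generator(Philox(key=(seed << 32) + tid)) for tid in range(4)]
print([g.uniform() for g in streams])  # independent draws per "thread"
```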
6 changes: 6 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -318,3 +318,9 @@
dispatch:
CPU: _s_poisson_cpu
CUDA: _s_poisson_cuda

- func: standard_gamma(Tensor self, Generator* generator=nullptr) -> Tensor
variants: function
dispatch:
CPU: _s_gamma_cpu
CUDA: _s_gamma_cuda
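With this entry the generated code dispatches `standard_gamma` to `_s_gamma_cpu` or `_s_gamma_cuda` based on the tensor's backend. A sketch of reaching it from Python, following the binding path that `torch/distributions/gamma.py` uses below:

```python
# Sketch of exercising the new dispatch entry; the binding path mirrors
# torch/distributions/gamma.py in this PR.
import torch
from torch.autograd import Variable

alpha = Variable(torch.Tensor([0.3, 1.0, 5.0]))
cpu_draws = torch._C._VariableFunctions.standard_gamma(alpha)  # -> _s_gamma_cpu
if torch.cuda.is_available():
    gpu_draws = torch._C._VariableFunctions.standard_gamma(alpha.cuda())  # -> _s_gamma_cuda
```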
29 changes: 0 additions & 29 deletions aten/src/TH/THRandom.c
@@ -281,35 +281,6 @@ double THRandom_exponential(THGenerator *_generator, double lambda)
return(-1. / lambda * log(1-uniform_double(_generator)));
}

double THRandom_standard_gamma(THGenerator *_generator, double alpha) {
double scale = 1.0;

// Boost alpha for higher acceptance probability.
if(alpha < 1.0) {
scale *= pow(1 - uniform_double(_generator), 1.0 / alpha);
alpha += 1.0;
}

// This implements the acceptance-rejection method of Marsaglia and Tsang (2000)
// doi:10.1145/358407.358414
const double d = alpha - 1.0 / 3.0;
const double c = 1.0 / sqrt(9.0 * d);
for(;;) {
double x, y;
do {
x = THRandom_normal(_generator, 0.0, 1.0);
y = 1.0 + c * x;
} while(y <= 0);
const double v = y * y * y;
const double u = 1 - uniform_double(_generator);
const double xx = x * x;
if(u < 1.0 - 0.0331 * xx * xx)
return scale * d * v;
if(log(u) < 0.5 * xx + d * (1.0 - v + log(v)))
return scale * d * v;
}
}

double THRandom_cauchy(THGenerator *_generator, double median, double sigma)
{
return(median + sigma * tan(M_PI*(uniform_double(_generator)-0.5)));
6 changes: 0 additions & 6 deletions aten/src/TH/THRandom.h
@@ -68,12 +68,6 @@ TH_API double THRandom_normal(THGenerator *_generator, double mean, double stdv)
*/
TH_API double THRandom_exponential(THGenerator *_generator, double lambda);

/** Generates a random number from a standard Gamma distribution.
The Gamma density is proportional to $x^{alpha-1} exp(-x)$
The shape parameter alpha (a.k.a. k) is a positive real number.
*/
TH_API double THRandom_standard_gamma(THGenerator *_generator, double alpha);

/** Returns a random number from a Cauchy distribution.
The Cauchy density is $p(x) = sigma/(pi*(sigma^2 + (x-median)^2))$
*/
9 changes: 0 additions & 9 deletions aten/src/TH/generic/THTensorRandom.c
@@ -128,15 +128,6 @@ void THTensor_(exponential)(THTensor *self, THGenerator *_generator, double lamb
TH_TENSOR_APPLY(real, self, *self_data = (real)THRandom_exponential(_generator, lambda););
}

void THTensor_(standard_gamma)(THTensor *self, THGenerator *gen, THTensor *alpha)
{
THTensor_(resizeAs)(self, alpha);
TH_TENSOR_APPLY2(real, self, real, alpha, {
const real sample = THRandom_standard_gamma(gen, *alpha_data);
*self_data = sample > 0 ? sample : TH_REAL_MIN;
});
}

#undef TH_REAL_MIN

void THTensor_(cauchy)(THTensor *self, THGenerator *_generator, double median, double sigma)
1 change: 0 additions & 1 deletion aten/src/TH/generic/THTensorRandom.h
@@ -18,7 +18,6 @@ TH_API void THTensor_(normal_means)(THTensor *self, THGenerator *gen, THTensor *
TH_API void THTensor_(normal_stddevs)(THTensor *self, THGenerator *gen, double mean, THTensor *stddevs);
TH_API void THTensor_(normal_means_stddevs)(THTensor *self, THGenerator *gen, THTensor *means, THTensor *stddevs);
TH_API void THTensor_(exponential)(THTensor *self, THGenerator *_generator, double lambda);
TH_API void THTensor_(standard_gamma)(THTensor *self, THGenerator *_generator, THTensor *alpha);
TH_API void THTensor_(cauchy)(THTensor *self, THGenerator *_generator, double median, double sigma);
TH_API void THTensor_(logNormal)(THTensor *self, THGenerator *_generator, double mean, double stdv);
TH_API void THTensor_(multinomial)(THLongTensor *self, THGenerator *_generator, THTensor *prob_dist, int n_sample, int with_replacement);
13 changes: 12 additions & 1 deletion test/test_distributions.py
@@ -599,7 +599,7 @@ def test_poisson_sample(self):
@unittest.skipIf(not TEST_CUDA, "CUDA not found")
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
def test_poisson_gpu_sample(self):
set_rng_seed(0)
set_rng_seed(1)
for rate in [0.12, 0.9, 4.0]:
self._check_sampler_discrete(Poisson(torch.Tensor([rate]).cuda()),
scipy.stats.poisson(rate),
@@ -832,6 +832,17 @@ def test_gamma_sample(self):
scipy.stats.gamma(alpha, scale=1.0 / beta),
'Gamma(concentration={}, rate={})'.format(alpha, beta))

@unittest.skipIf(not TEST_CUDA, "CUDA not found")
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
def test_gamma_gpu_sample(self):
set_rng_seed(0)
for alpha, beta in product([0.1, 1.0, 5.0], [0.1, 1.0, 10.0]):
a, b = torch.Tensor([alpha]).cuda(), torch.Tensor([beta]).cuda()
self._check_sampler_sampler(Gamma(a, b),
scipy.stats.gamma(alpha, scale=1.0 / beta),
'Gamma(alpha={}, beta={})'.format(alpha, beta),
failure_rate=1e-4)

@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_pareto(self):
scale = Variable(torch.randn(2, 3).abs(), requires_grad=True)
4 changes: 2 additions & 2 deletions tools/autograd/derivatives.yaml
@@ -644,8 +644,8 @@
self: not_implemented("_sparse_mask")
mask: not_implemented("_sparse_mask")

- name: _standard_gamma(Tensor self, Generator generator)
self: grad * self._standard_gamma_grad(output)
- name: standard_gamma(Tensor self, Generator generator)
self: grad * self._standard_gamma_grad(result)

- name: _standard_gamma_grad(Tensor self, Tensor output)
self: not_implemented("_standard_gamma_grad")
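For context on the backward entry above: assuming `_standard_gamma_grad(result)` implements the pathwise derivative of the sample with respect to the concentration, the identity comes from implicitly differentiating the CDF $F(x;\alpha) = u$ at a fixed uniform $u$:

$$\frac{\partial x}{\partial \alpha} = -\,\frac{\partial F(x;\alpha)/\partial \alpha}{p(x;\alpha)},$$

where $p(\cdot\,;\alpha)$ is the $\mathrm{Gamma}(\alpha, 1)$ density, i.e. the denominator is $\partial F/\partial x$.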
2 changes: 0 additions & 2 deletions torch/csrc/Module.cpp
@@ -298,7 +298,6 @@ IMPLEMENT_STATELESS(bmm)
// TODO: this doesn't implement options that return numbers!
IMPLEMENT_STATELESS(multinomial)
IMPLEMENT_STATELESS(normal)
IMPLEMENT_STATELESS(_standard_gamma)
IMPLEMENT_STATELESS(_dirichlet_grad)
IMPLEMENT_STATELESS(bernoulli)
IMPLEMENT_STATELESS(range)
@@ -719,7 +718,6 @@ static PyMethodDef TorchMethods[] = {
{"bmm", (PyCFunction)THPModule_bmm, METH_VARARGS | METH_KEYWORDS, NULL},
{"multinomial", (PyCFunction)THPModule_multinomial, METH_VARARGS | METH_KEYWORDS, NULL},
{"normal", (PyCFunction)THPModule_normal, METH_VARARGS | METH_KEYWORDS, NULL},
{"_standard_gamma", (PyCFunction)THPModule__standard_gamma, METH_VARARGS | METH_KEYWORDS, NULL},
{"_dirichlet_grad", (PyCFunction)THPModule__dirichlet_grad, METH_VARARGS | METH_KEYWORDS, NULL},
{"bernoulli", (PyCFunction)THPModule_bernoulli, METH_VARARGS | METH_KEYWORDS, NULL},
{"rand", (PyCFunction)THPModule_rand, METH_VARARGS | METH_KEYWORDS, NULL},
20 changes: 0 additions & 20 deletions torch/csrc/generic/methods/TensorRandom.cwrap
@@ -210,26 +210,6 @@
default: 1
]]

[[
name: _standard_gamma
types:
- floating_point
backends:
- CPU
return: argument 0
variants:
- function
options:
- cname: standard_gamma
arguments:
- arg: THTensor* output
output: True
- arg: THGenerator* generator
default: THPGenerator_TH_CData(THPDefaultGenerator)
kwarg_only: True
- THTensor* alpha
]]

[[
name: _dirichlet_grad
types:
3 changes: 2 additions & 1 deletion torch/distributions/dirichlet.py
@@ -5,11 +5,12 @@
from torch.autograd.function import once_differentiable
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.gamma import _standard_gamma
from torch.distributions.utils import _finfo, broadcast_all


def _dirichlet_sample_nograd(concentration):
probs = torch._C._standard_gamma(concentration)
probs = _standard_gamma(concentration)
probs /= probs.sum(-1, True)
eps = _finfo(probs).eps
return probs.clamp_(min=eps, max=1 - eps)
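The normalization in `_dirichlet_sample_nograd` is the standard gamma-to-Dirichlet construction: if $X_i \sim \mathrm{Gamma}(\alpha_i, 1)$ independently, then

$$\Bigl(\tfrac{X_1}{\sum_j X_j}, \dots, \tfrac{X_K}{\sum_j X_j}\Bigr) \sim \mathrm{Dirichlet}(\alpha_1, \dots, \alpha_K),$$

which is why a sum-normalized `_standard_gamma` draw over the last dimension is a Dirichlet sample.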
4 changes: 2 additions & 2 deletions torch/distributions/gamma.py
@@ -10,8 +10,8 @@

def _standard_gamma(concentration):
if not isinstance(concentration, Variable):
return torch._C._standard_gamma(concentration)
return concentration._standard_gamma()
return torch._C._VariableFunctions.standard_gamma(Variable(concentration)).data
return torch._C._VariableFunctions.standard_gamma(concentration)


class Gamma(Distribution):
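A closing usage sketch: a Gamma(concentration, rate) draw is a standard-gamma draw scaled by 1/rate, the same parameterization the tests check against `scipy.stats.gamma(alpha, scale=1.0 / beta)`. An illustration under that assumption, not the `Gamma.rsample` code path itself:

```python
# Gamma(concentration, rate) via the helper above: scale a standard-gamma
# draw by 1/rate. Illustrative only; not the Gamma.rsample implementation.
import torch
from torch.distributions.gamma import _standard_gamma

concentration = torch.Tensor([0.5, 2.0])
rate = torch.Tensor([1.0, 4.0])
draws = _standard_gamma(concentration) / rate  # ~ Gamma(concentration, rate)
```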