Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 0 additions & 20 deletions aten/src/ATen/Declarations.cwrap
Original file line number Diff line number Diff line change
Expand Up @@ -3813,26 +3813,6 @@
kwarg_only: True
- THTensor* self
]]
[[
name: _standard_gamma
types:
- floating_point
backends:
- CPU
return: argument 0
variants:
- method
- function
options:
- cname: standard_gamma
arguments:
- arg: THTensor* output
output: True
- arg: THGenerator* generator
default: nullptr
kwarg_only: True
- THTensor* self
]]
[[
name: _dirichlet_grad
types:
Expand Down
124 changes: 36 additions & 88 deletions aten/src/ATen/native/Distributions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,12 @@
#include "ATen/CPUGenerator.h"
#include "ATen/CheckGenerator.h"
#include "ATen/Generator.h"
#include "ATen/native/Distributions.h"

#include <functional>

#include "TH/THRandom.h"
#include "TH/THGenerator.h"
#include "TH/THMath.h"

namespace {
Expand Down Expand Up @@ -100,85 +104,6 @@ int64_t sample_poisson(double lambda, THGenerator* generator) {
}
}

// digamma(x) = d/dx log(Gamma(x)). Only float and double are supported; this
// generic template exists solely to raise a clear error for any other type.
template <typename scalar>
static inline scalar digamma_one(scalar x) {
throw std::runtime_error("digamma is only implemented for float, double");
}

// Double-precision specialization: defers to TH's digamma implementation.
template <>
inline double digamma_one<double>(double x) {
return TH_digamma(x);
}

// Single-precision specialization: defers to TH's float digamma.
template <>
inline float digamma_one<float>(float x) {
return TH_digammaf(x);
}

// Computes the reparameterized gradient -(d/dalpha cdf(x;alpha)) / pdf(x;alpha)
// for random number x drawn from a standard Gamma distribution Gamma(alpha).
// Three regimes are used: a Taylor series for small x, a saddle-point
// expansion for large alpha, and a fitted rational approximation otherwise.
template <typename scalar_t>
scalar_t standard_gamma_grad_one(scalar_t alpha, scalar_t x) {
// Use a Taylor series expansion for small x.
if (x < 0.8f) {
// Build the truncated series sum_k (-x)^k / k! * 1/(alpha+k) (series1)
// and its derivative-related companion with 1/(alpha+k)^2 (series2).
scalar_t numer = 1;
scalar_t denom = alpha;
auto series1 = numer / denom;
auto series2 = numer / (denom * denom);
for (int i = 1; i <= 5; ++i) {
numer *= -x / i;
denom += 1;
series1 += numer / denom;
series2 += numer / (denom * denom);
}
// Unnormalized pdf and cdf of Gamma(alpha) at x; the Gamma(alpha)
// normalizer cancels in the final ratio.
const auto pow_x_alpha = std::pow(x, alpha);
const auto gamma_pdf = std::pow(x, alpha - 1) * std::exp(-x);
const auto gamma_cdf = pow_x_alpha * series1;
// d/dalpha of the cdf, using digamma for the d/dalpha of x^alpha term.
const auto gamma_cdf_alpha = (std::log(x) - digamma_one(alpha)) * gamma_cdf
- pow_x_alpha * series2;
const auto result = -gamma_cdf_alpha / gamma_pdf;
// Guard against 0/0 at the boundary: report a zero gradient instead of NaN.
return std::isnan(result) ? 0 : result;
}

// Use a Rice saddle point expansion for large alpha.
if (alpha > 8.0f) {
// Near the mean (x within 10% of alpha) the general expansion is
// ill-conditioned, so switch to a dedicated polynomial ratio.
if (0.9f * alpha <= x && x <= 1.1f * alpha) {
const auto numer_1 = 1 + 24 * alpha * (1 + 12 * alpha);
const auto numer_2 = 1440 * (alpha * alpha) + 6 * x * (53 - 120 * x)
- 65 * x * x / alpha + alpha * (107 + 3600 * x);
const auto denom = 1244160 * (alpha * alpha) * (alpha * alpha);
return numer_1 * numer_2 / denom;
}
const auto denom = std::sqrt(8 * alpha);
const auto term2 = denom / (alpha - x);
const auto term3 = std::pow(x - alpha - alpha * std::log(x / alpha), -1.5f);
// term3 changes sign across x = alpha; fold that into the combination.
const auto term23 = (x < alpha) ? term2 - term3 : term2 + term3;
const auto term1 = std::log(x / alpha) * term23
- std::sqrt(2 / alpha) * (alpha + x) / ((alpha - x) * (alpha - x));
// Leading terms of the Stirling series correction.
const auto stirling = 1 + 1 / (12 * alpha) * (1 + 1 / (24 * alpha));
const auto numer = x * term1;
return -stirling * numer / denom;
}

// Use a bivariate rational approximation to the reparameterized gradient.
// u, v are log-transformed inputs; coef_uv holds fitted coefficients for a
// ratio p(u,v)/q(u,v) of cubics in v with quadratic-in-u coefficients.
// NOTE(review): the coefficient values are presumably from an offline fit —
// do not alter them.
const auto u = std::log(x / alpha);
const auto v = std::log(alpha);
static const scalar_t coef_uv[3][8] = {
{0.16009398, -0.094634809, 0.025146376, -0.0030648343,
1, 0.32668115, 0.10406089, 0.0014179084},
{0.53487893, 0.1298071, 0.065735949, -0.0015649758,
0.16639465, 0.020070113, -0.0035938915, -0.00058392623},
{0.040121004, -0.0065914022, -0.0026286047, -0.0013441777,
0.017050642, -0.0021309326, 0.00085092367, -1.5247877e-07},
};
// Collapse the u-dimension first (Horner in u), then evaluate p and q as
// cubics in v (Horner in v).
scalar_t coef_v[8];
for (int i = 0; i < 8; ++ i) {
coef_v[i] = coef_uv[0][i] + u * (coef_uv[1][i] + u * coef_uv[2][i]);
}
const auto p = coef_v[0] + v * (coef_v[1] + v * (coef_v[2] + v * coef_v[3]));
const auto q = coef_v[4] + v * (coef_v[5] + v * (coef_v[6] + v * coef_v[7]));
return std::exp(p / q);
}
} // namespace

namespace at {
Expand All @@ -198,28 +123,51 @@ Tensor _standard_gamma_grad_cpu(const Tensor& self, const Tensor& output) {
AT_DISPATCH_FLOATING_TYPES(self.type(), "_standard_gamma_grad", [&] {
CPU_tensor_apply3<scalar_t, scalar_t, scalar_t>(ret, self, output,
[](scalar_t& ret_val, const scalar_t& self_val, const scalar_t &output_val) {
ret_val = standard_gamma_grad_one(self_val, output_val);
ret_val = standard_gamma_grad_one<scalar_t, double>(self_val, output_val);
}
);
});
return ret;
}

// CUDA stub: the reparameterized standard-gamma gradient is only implemented
// on CPU here; any call with CUDA tensors fails with an explicit error.
Tensor _standard_gamma_grad_cuda(const Tensor& self, const Tensor& output) {
AT_ERROR("_standard_gamma_grad is not implemented for CUDA types");
}
/*
* This section is a counterpart to Distributions.cu
*/

/*
 * Draws one Poisson(lambda) sample per element of `lambda` on CPU.
 *
 * @param lambda  per-element rates; must be a floating-point tensor.
 * @param gen     optional RNG; falls back to the default generator when null
 *                (resolved via get_generator).
 * @return        a tensor shaped like `lambda` holding integer-valued samples
 *                stored in the same floating-point dtype.
 *
 * NOTE(review): the scraped diff interleaved the old (double-converted) and
 * new (scalar_t, mutex-guarded) bodies; this is the coherent merged version.
 */
Tensor _s_poisson_cpu(const Tensor& lambda, Generator *gen) {
  Tensor ret = at::zeros(lambda.type(), lambda.sizes());
  AT_DISPATCH_FLOATING_TYPES(ret.type(), "poisson", [&] {
    THGenerator* generator = get_generator(gen);
    // Hold the generator lock for the entire apply so concurrent callers
    // cannot interleave draws from the shared RNG state.
    std::lock_guard<std::mutex> lock(generator->mutex);
    CPU_tensor_apply2<scalar_t, scalar_t>(ret, lambda,
      [generator](scalar_t& ret_val, const scalar_t& lambda) {
        // sample_poisson works in double; cast per element in and out.
        ret_val = static_cast<scalar_t>(
            sample_poisson(static_cast<double>(lambda), generator));
      }
    );
  });
  return ret;
}

/*
 * Draws one standard Gamma(alpha) sample per element of `alpha` on CPU.
 *
 * @param alpha  per-element shape parameters; floating-point tensor.
 * @param gen    optional RNG; resolved to a THGenerator via get_generator.
 * @return       a tensor shaped like `alpha` with the drawn samples, clamped
 *               away from exact zero.
 */
Tensor _s_gamma_cpu(const Tensor& alpha, Generator *gen) {
  Tensor ret = alpha.type().zeros(alpha.sizes());
  AT_DISPATCH_FLOATING_TYPES(ret.type(), "gamma", [&] {
    THGenerator* generator = get_generator(gen);
    // Serialize access to the shared RNG state for the whole apply.
    std::lock_guard<std::mutex> rng_guard(generator->mutex);
    CPU_tensor_apply2<scalar_t, scalar_t>(ret, alpha,
        [generator](scalar_t& ret_val, const scalar_t& alpha_val) {
          // Adapters exposing TH's RNG as double-valued base samplers.
          BaseSampler<double> uniform01([generator]() {
            return THRandom_standard_uniform(generator);
          });
          BaseSampler<double> normal01([generator]() {
            return THRandom_normal(generator, 0.0, 1.0);
          });
          const auto drawn =
              sample_gamma<scalar_t, double>(alpha_val, uniform01, normal01);
          // Keep the sample strictly positive: clamp to the smallest
          // positive normalized value of scalar_t.
          ret_val = std::max(std::numeric_limits<scalar_t>::min(),
                             static_cast<scalar_t>(drawn));
        });
  });

  return ret;
}

}} // namespace at::native
Loading