18 changes: 0 additions & 18 deletions aten/src/ATen/Declarations.cwrap
@@ -2752,24 +2752,6 @@
kwarg_only: True
- double p
]]
[[
name: _th_dirichlet_grad
types:
- floating_point
backends:
- CPU
return: argument 0
variants:
- function
options:
- cname: dirichlet_grad
arguments:
- arg: THTensor* output
output: True
- THTensor* x
- THTensor* alpha
- THTensor* total
]]

# In theory, this could be a part of the above declaration. But in
# practice this leads to all sorts of problems with ambiguous overloads.
12 changes: 12 additions & 0 deletions aten/src/ATen/native/Distributions.cpp
@@ -184,6 +184,18 @@ Tensor _standard_gamma_grad_cpu(const Tensor& self, const Tensor& output) {
return ret;
}

Tensor _dirichlet_grad_cpu(const Tensor& x, const Tensor& alpha, const Tensor& total) {
Tensor ret = at::empty(x.sizes(), x.options());
AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "_dirichlet_grad_cpu", [&] {
CPU_tensor_apply4<scalar_t, scalar_t, scalar_t, scalar_t>(ret, x, alpha, total,
[](scalar_t& ret_val, const scalar_t& x_val, const scalar_t& alpha_val, const scalar_t& total_val) {
ret_val = dirichlet_grad_one<scalar_t, double>(x_val, alpha_val, total_val);
}
);
});
return ret;
}

/*
* This section is a counterpart to Distributions.cu
*/
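
As a minimal usage sketch (assumptions: the `_dirichlet_grad` schema registered in native_functions.yaml below is in effect, and the helper name here is hypothetical), the new native op can be exercised from C++ like this:

#include <ATen/ATen.h>

// Hypothetical smoke test: draw a Dirichlet sample and evaluate the
// reparameterized-gradient surrogate on CPU.
void dirichlet_grad_smoke_test() {
  at::Tensor alpha = at::rand({5}, at::kDouble) + 0.5;      // concentrations > 0
  at::Tensor x = at::_sample_dirichlet(alpha);              // a point on the simplex
  at::Tensor total = alpha.sum(-1, true).expand_as(alpha);  // total = alpha + beta per coordinate
  at::Tensor grad = at::_dirichlet_grad(x, alpha, total);   // dx/dalpha surrogate, same shape as x
}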
142 changes: 142 additions & 0 deletions aten/src/ATen/native/Distributions.h
@@ -253,4 +253,146 @@ C10_DEVICE scalar_t standard_gamma_grad_one(scalar_t alpha_, scalar_t x_) {
return static_cast<scalar_t>(compat_exp(p / q));
}

// Approximate reparameterized gradient of Beta(x,alpha,beta) wrt alpha.
// Assumes x is close to zero and uses a Taylor expansion.
template <typename scalar_t, typename accscalar_t>
C10_DEVICE static inline scalar_t _beta_grad_alpha_small(scalar_t x, scalar_t alpha, scalar_t beta) {
const scalar_t factor = digamma_one<scalar_t, accscalar_t>(alpha)
- digamma_one<scalar_t, accscalar_t>(alpha + beta) - compat_log(x);
scalar_t numer = 1;
scalar_t series = numer / alpha * (factor + 1 / alpha);
for (int i = 1; i <= 10; ++i) {
scalar_t casted_i = static_cast<scalar_t>(i);
numer *= (casted_i - beta) * x / casted_i;
const scalar_t denom = alpha + casted_i;
series += numer / denom * (factor + 1 / denom);
}
const scalar_t result = x * compat_pow(1 - x, -beta) * series;
return isnan(result) ? static_cast<scalar_t>( 0.f ) : result;
}
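
// Note on _beta_grad_alpha_small above: the loop sums the standard small-x
// expansion of the lower incomplete beta function,
//   B(x; alpha, beta) = x^alpha * sum_{i>=0} ((1 - beta)_i / i!) * x^i / (alpha + i),
// differentiated w.r.t. alpha; `factor` equals -(d/dalpha) log(x^alpha / B(alpha, beta)).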

// Approximate reparameterized gradient of Beta(x,alpha,beta) wrt beta.
// Assumes x is close to zero and uses a Taylor expansion.
template <typename scalar_t, typename accscalar_t>
C10_DEVICE static inline scalar_t _beta_grad_beta_small(scalar_t x, scalar_t alpha, scalar_t beta) {
const scalar_t factor = digamma_one<scalar_t, accscalar_t>(alpha + beta) - digamma_one<scalar_t, accscalar_t>(beta);
scalar_t numer = 1, betas = 1, dbetas = 0, series = factor / alpha;
for (int i = 1; i <= 8; ++i) {
scalar_t casted_i = static_cast<scalar_t>(i);
numer *= -x / casted_i;
dbetas = dbetas * (beta - casted_i) + betas;
betas = betas * (beta - casted_i);
series += numer / (alpha + casted_i) * (dbetas + factor * betas);
}
const scalar_t result = -compat_pow(1 - x, 1 - beta) * series;
return isnan(result) ? static_cast<scalar_t>( 0.f ) : result;
}
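
// Note on _beta_grad_beta_small above: `betas` carries the falling product
// (beta - 1)(beta - 2)...(beta - i) and `dbetas` its derivative w.r.t. beta
// (updated via the product rule), so `series` accumulates the beta-derivative
// of the same small-x incomplete-beta expansion term by term.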

// Approximate reparameterized gradient of Beta(x,alpha,beta) wrt alpha.
// Assumes alpha and beta are both large and uses a Rice saddle point expansion.
// To ensure numerical stability, this computation is performed at higher precision.
template<typename scalar_t, typename accscalar_t>
C10_DEVICE static inline scalar_t _beta_grad_alpha_mid(accscalar_t x, accscalar_t alpha, accscalar_t beta) {
const accscalar_t total = alpha + beta;
const accscalar_t mean = alpha / total;
const accscalar_t std = compat_sqrt(alpha * beta / (total + 1)) / total;
if (mean - 0.1 * std <= x && x <= mean + 0.1 * std) {
// Avoid the singularity at x = mean.
const accscalar_t poly = 47 * x * (beta * beta) * (beta * beta) + alpha * (
(43 + 20 * (16 + 27 * beta) * x) * (beta * beta) * beta + alpha * (
3 * (59 + 180 * beta - 90 * x) * (beta * beta) + alpha * (
(453 + 1620 * beta * (1 - x) - 455 * x) * beta + alpha * (
8 * (1 - x) * (135 * beta - 11)))));
const accscalar_t prefactor_num = (1 + 12 * alpha) * (1 + 12 * beta) / (total * total);
const accscalar_t prefactor_den = 12960 * alpha * alpha * alpha * beta * beta * (1 + 12 * total);
return prefactor_num / (1 - x) * poly / prefactor_den;
}
const accscalar_t prefactor = -x / compat_sqrt(2 * alpha * beta / total);
const accscalar_t stirling = (1 + 1 / (12 * alpha) + 1 / (288 * alpha * alpha))
* (1 + 1 / (12 * beta) + 1 / (288 * beta * beta))
/ (1 + 1 / (12 * total) + 1 / (288 * total * total));
const accscalar_t term1_num = 2 * (alpha * alpha) * (x - 1) + alpha * beta * (x - 1) - x * (beta * beta);
const accscalar_t axbx = alpha * (x - 1) + beta * x;
const accscalar_t term1_den = compat_sqrt(2 * alpha / beta) * compat_pow(total, static_cast<accscalar_t>(1.5f)) * axbx * axbx;
const accscalar_t term1 = term1_num / term1_den;
const accscalar_t term2 = 0.5f * compat_log(alpha / (total * x));
const accscalar_t term3_num = compat_sqrt(8 * alpha * beta / total);
const accscalar_t term3_den = beta * x + alpha * (x - 1);
const accscalar_t term3 = term3_num / term3_den;
const accscalar_t term4_base = beta * compat_log(beta / (total * (1 - x))) +
alpha * compat_log(alpha / (total * x));
const accscalar_t term4 = compat_pow(term4_base, static_cast<accscalar_t>(-1.5f));
const accscalar_t term1234 = term1 + term2 * (term3 + (x < mean ? term4 : -term4));
return static_cast<scalar_t>(stirling * prefactor * term1234);
}
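
// Note on _beta_grad_alpha_mid above: `stirling` is the ratio of truncated
// Stirling-series corrections,
//   Gamma(z) ~ sqrt(2*pi/z) * (z/e)^z * (1 + 1/(12*z) + 1/(288*z*z) + ...),
// for alpha, beta, and total, refining the leading-order saddle-point estimate
// of Beta(alpha, beta) = Gamma(alpha) * Gamma(beta) / Gamma(alpha + beta).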

// Computes a scaled reparameterized gradient
// -(d/dalpha cdf(x;alpha,beta)) / pdf(x;alpha,beta) / (1-x)
// for random number x drawn from a Beta distribution Beta(alpha,beta).
// This function inputs total=alpha+beta to make it easy to implement
// Dirichlet reparameterized gradients in terms of Betas.
template<typename scalar_t, typename accscalar_t>
C10_DEVICE static inline scalar_t dirichlet_grad_one(scalar_t x, scalar_t alpha, scalar_t total) {
accscalar_t x_ = static_cast<accscalar_t>(x);
accscalar_t alpha_ = static_cast<accscalar_t>(alpha);
accscalar_t total_ = static_cast<accscalar_t>(total);

const scalar_t beta = total - alpha;
const accscalar_t beta_ = static_cast<accscalar_t>(beta);
const scalar_t boundary = total * x * (1 - x);

// Use an asymptotic approximation for x close to 0.
if (x <= 0.5f && boundary < 2.5f) {
return _beta_grad_alpha_small<scalar_t, accscalar_t>(x, alpha, beta);
}

// Use an asymptotic approximation for x close to 1.
if (x >= 0.5f && boundary < 0.75f) {
return -_beta_grad_beta_small<scalar_t, accscalar_t>(1 - x, beta, alpha);
}

// Use an asymptotic approximation when alpha and (total - alpha) are both large.
if (alpha > 6 && beta > 6) {
return _beta_grad_alpha_mid<scalar_t, accscalar_t>(x_, alpha_, beta_);
}

// Use a rational correction to an analytic approximation.
static const accscalar_t c[2][3][3][4] = {
{{{1.003668233, -0.01061107488, -0.0657888334, 0.01201642863},
{0.6336835991, -0.3557432599, 0.05486251648, -0.001465281033},
{-0.03276231906, 0.004474107445, 0.002429354597, -0.0001557569013}},
{{0.221950385, -0.3187676331, 0.01799915743, 0.01074823814},
{-0.2951249643, 0.06219954479, 0.01535556598, 0.001550077057},
{0.02155310298, 0.004170831599, 0.001292462449, 6.976601077e-05}},
{{-0.05980841433, 0.008441916499, 0.01085618172, 0.002319392565},
{0.02911413504, 0.01400243777, -0.002721828457, 0.000751041181},
{0.005900514878, -0.001936558688, -9.495446725e-06, 5.385558597e-05}}},
{{{1, -0.02924021934, -0.04438342661, 0.007285809825},
{0.6357567472, -0.3473456711, 0.05454656494, -0.002407477521},
{-0.03301322327, 0.004845219414, 0.00231480583, -0.0002307248149}},
{{0.5925320577, -0.1757678135, 0.01505928619, 0.000564515273},
{0.1014815858, -0.06589186703, 0.01272886114, -0.0007316646956},
{-0.007258481865, 0.001096195486, 0.0003934994223, -4.12701925e-05}},
{{0.06469649321, -0.0236701437, 0.002902096474, -5.896963079e-05},
{0.001925008108, -0.002869809258, 0.0008000589141, -6.063713228e-05},
{-0.0003477407336, 6.959756487e-05, 1.097287507e-05, -1.650964693e-06}}},
};
const accscalar_t u = compat_log(x_);
const accscalar_t a = compat_log(alpha_) - u;
const accscalar_t b = compat_log(total_) - a;
const accscalar_t pow_u[3] = {1, u, u * u};
const accscalar_t pow_a[3] = {1, a, a * a};
accscalar_t p = 0.0;
accscalar_t q = 0.0;
for (int i = 0; i < 3; ++i) {
for (int j = 0; j < 3; ++j) {
const accscalar_t ua = pow_u[i] * pow_a[j];
p += ua * (c[0][i][j][0] + b * (c[0][i][j][1] + b * (c[0][i][j][2] + b * c[0][i][j][3])));
q += ua * (c[1][i][j][0] + b * (c[1][i][j][1] + b * (c[1][i][j][2] + b * c[1][i][j][3])));
}
}
const accscalar_t approx = x_ * (digamma_one<accscalar_t, accscalar_t>(total_)
                               - digamma_one<accscalar_t, accscalar_t>(alpha_)) / beta_;
return static_cast<scalar_t>(p / q * approx);
}

} // namespace
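
For reference, the identity behind the comment on dirichlet_grad_one is the standard implicit-reparameterization relation: for $x \sim \mathrm{Beta}(\alpha, \beta)$ with CDF $F$ and density $f$, differentiating $u = F(x; \alpha, \beta)$ at fixed $u$ gives

$$\frac{dx}{d\alpha} = -\,\frac{\partial F(x; \alpha, \beta) / \partial \alpha}{f(x; \alpha, \beta)},$$

and the function returns this value divided by $1 - x$. Taking total $= \alpha + \beta$ as an input matches the Dirichlet marginals $x_i \sim \mathrm{Beta}(\alpha_i, \mathrm{total} - \alpha_i)$, so the Dirichlet reparameterized gradient is assembled coordinate-wise from Beta gradients.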
22 changes: 22 additions & 0 deletions aten/src/ATen/native/cuda/Distributions.cu
@@ -246,6 +246,20 @@ void gamma_grad_cuda_kernel(
});
}

template <typename scalar_t>
void dirichlet_grad_cuda_kernel(
at::Tensor& ret,
const at::Tensor& x,
const at::Tensor& alpha,
const at::Tensor& total) {
using accscalar_t = at::acc_type<scalar_t, true>;
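// Note: for at::Half, acc_type<scalar_t, true> is float, so the series sums
// inside dirichlet_grad_one run at float precision even for half tensors.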
at::cuda::CUDA_tensor_apply4<scalar_t, scalar_t, scalar_t, scalar_t>(
ret, x, alpha, total,
[] __device__ (scalar_t& ret_val, const scalar_t& x_val, const scalar_t& alpha_val, const scalar_t& total_val) {
ret_val = dirichlet_grad_one<scalar_t, accscalar_t>(x_val, alpha_val, total_val);
});
}

template<typename scalar_t, typename prob_t>
void bernoulli_tensor_cuda_kernel(
at::Tensor& ret, const at::Tensor& p,
@@ -381,6 +395,14 @@ Tensor _standard_gamma_grad_cuda(const Tensor& self, const Tensor& output) {
return ret;
}

Tensor _dirichlet_grad_cuda(const Tensor& x, const Tensor& alpha, const Tensor& total) {
Tensor ret = at::empty(x.sizes(), x.options());
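// Unlike the CPU counterpart (AT_DISPATCH_FLOATING_TYPES), Half is covered here.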
AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "_dirichlet_grad_cuda", [&] {
dirichlet_grad_cuda_kernel<scalar_t>(ret, x, alpha, total);
});
return ret;
}

Tensor& bernoulli_tensor_cuda_(Tensor &self, const Tensor& p_, Generator* gen) {
auto p = std::get<0>(expand_inplace(self, p_.to(kCUDA)));
AT_DISPATCH_ALL_TYPES_AND(
13 changes: 5 additions & 8 deletions aten/src/ATen/native/native_functions.yaml
@@ -2084,6 +2084,11 @@
CPU: _s_gamma_cpu
CUDA: _s_gamma_cuda

- func: _dirichlet_grad(Tensor x, Tensor alpha, Tensor total) -> Tensor
dispatch:
CPU: _dirichlet_grad_cpu
CUDA: _dirichlet_grad_cuda

- func: _sample_dirichlet(Tensor self, Generator? generator=None) -> Tensor
variants: function
dispatch:
@@ -3959,14 +3964,6 @@
CPU: legacy::cpu::_th_alias
CUDA: legacy::cuda::_th_alias

- func: _dirichlet_grad(Tensor x, Tensor alpha, Tensor total, *, Tensor(a!) out) -> Tensor(a!)
dispatch:
CPU: legacy::cpu::_th_dirichlet_grad_out

- func: _dirichlet_grad(Tensor x, Tensor alpha, Tensor total) -> Tensor
dispatch:
CPU: legacy::cpu::_th_dirichlet_grad

- func: _addr(Tensor self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
dispatch:
CPU: legacy::cpu::_th_addr
2 changes: 0 additions & 2 deletions aten/src/TH/generic/THTensorMath.h
@@ -183,8 +183,6 @@ TH_API accreal THTensor_(meanall)(THTensor *self);
TH_API accreal THTensor_(varall)(THTensor *self, int biased);
TH_API accreal THTensor_(stdall)(THTensor *self, int biased);
TH_API accreal THTensor_(normall)(THTensor *t, scalar_t value);

TH_API void THTensor_(dirichlet_grad)(THTensor *self, THTensor *x, THTensor *alpha, THTensor *total);
#endif

#endif