
Commit c10da63

t-viezyang authored and committed
implement gamma cuda (#6855)
* Refactor standard_gamma and implement CUDA gamma sampling
* Attempt fixes for AT_CUDA_ENABLED changes
* Gamma cuda and cpu forward as ATen native
* implement standard_gamma_grad_cuda
* update native_test.cpp, try to fix windows and various cuda version compiles
* searching a windows fix via CI... use std:: for math
* casting some constants in the calculation, compute at float for half precision
* whitespace fixes
* add acctype to do half->float computation, include HALF in generation, cast locally rather than tensors
* fix cuda8 half compilation
* always use scalar_cast with CUDACC, lock CPU generator, CPU acctype = double

Thank you for your review comments!
1 parent 7cbef70 commit c10da63
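Several bullets in the commit message concern precision handling ("compute at float for half precision", "add acctype to do half->float computation", "CPU acctype = double"). The snippet below is a minimal host-side sketch of that general pattern, with illustrative names only; AccType and exp_via_acctype are not the helpers this commit adds to ATen.

// Sketch of the "acctype" idea: pick a wider accumulation type per scalar
// type, do the math there, and narrow back only when storing the result.
#include <cmath>
#include <iostream>

template <typename T> struct AccType { using type = T; };
// "CPU acctype = double": float inputs accumulate in double on the CPU path;
// the CUDA path similarly promotes half to float.
template <> struct AccType<float> { using type = double; };

template <typename scalar_t>
scalar_t exp_via_acctype(scalar_t x) {
  using acc_t = typename AccType<scalar_t>::type;
  acc_t v = static_cast<acc_t>(x);            // widen once on entry
  return static_cast<scalar_t>(std::exp(v));  // narrow once, locally
}

int main() {
  std::cout << exp_via_acctype<float>(1.0f) << "\n";  // prints ~2.71828
  return 0;
}

Doing the widening and narrowing locally inside the per-element lambda, rather than converting whole tensors up front, is what "cast locally rather than tensors" refers to; the poisson change further down shows exactly that switch.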

File tree

12 files changed, +391 -189 lines changed


aten/src/ATen/Declarations.cwrap

Lines changed: 0 additions & 20 deletions
@@ -3843,26 +3843,6 @@
           kwarg_only: True
         - THTensor* self
 ]]
-[[
-  name: _standard_gamma
-  types:
-    - floating_point
-  backends:
-    - CPU
-  return: argument 0
-  variants:
-    - method
-    - function
-  options:
-    - cname: standard_gamma
-      arguments:
-        - arg: THTensor* output
-          output: True
-        - arg: THGenerator* generator
-          default: nullptr
-          kwarg_only: True
-        - THTensor* self
-]]
 [[
   name: _dirichlet_grad
   types:
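The entry removed above is the old TH cwrap binding, which was CPU-only; per the commit message, gamma sampling now goes through ATen native functions so the CUDA backend can provide it as well. Below is a hedged usage sketch, assuming the _standard_gamma name and its function variant from the removed entry carry over to the new native registration (that registration lives in another file of this commit and is not shown here).

// Hypothetical call site against ATen of this era; at::_standard_gamma and
// its dispatch to _s_gamma_cpu are assumptions based on the removed cwrap
// declaration and the new native implementation below, not code in this diff.
#include <ATen/ATen.h>
#include <iostream>

int main() {
  // alpha holds the shape parameters; all ones means Gamma(1, 1), i.e.
  // Exponential(1) samples.
  at::Tensor alpha = at::CPU(at::kFloat).ones({2, 3});
  at::Tensor sample = at::_standard_gamma(alpha);  // assumed to reach _s_gamma_cpu
  std::cout << sample << std::endl;
  // With the new CUDA kernels, the same call is expected to work for CUDA
  // tensors and to reach the CUDA counterpart instead.
  return 0;
}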

aten/src/ATen/native/Distributions.cpp

Lines changed: 36 additions & 88 deletions
@@ -8,8 +8,12 @@
 #include "ATen/CPUGenerator.h"
 #include "ATen/CheckGenerator.h"
 #include "ATen/Generator.h"
+#include "ATen/native/Distributions.h"
+
+#include <functional>

 #include "TH/THRandom.h"
+#include "TH/THGenerator.h"
 #include "TH/THMath.h"

 namespace {
@@ -100,85 +104,6 @@ int64_t sample_poisson(double lambda, THGenerator* generator) {
   }
 }

-template <typename scalar>
-static inline scalar digamma_one(scalar x) {
-  throw std::runtime_error("digamma is only implemented for float, double");
-}
-
-template <>
-inline double digamma_one<double>(double x) {
-  return TH_digamma(x);
-}
-
-template <>
-inline float digamma_one<float>(float x) {
-  return TH_digammaf(x);
-}
-
-// Computes the reparameterized gradient -(d/dalpha cdf(x;alpha)) / pdf(x;alpha)
-// for random number x drawn from a standard Gamma distribution Gamma(alpha).
-template <typename scalar_t>
-scalar_t standard_gamma_grad_one(scalar_t alpha, scalar_t x) {
-  // Use a Taylor series expansion for small x.
-  if (x < 0.8f) {
-    scalar_t numer = 1;
-    scalar_t denom = alpha;
-    auto series1 = numer / denom;
-    auto series2 = numer / (denom * denom);
-    for (int i = 1; i <= 5; ++i) {
-      numer *= -x / i;
-      denom += 1;
-      series1 += numer / denom;
-      series2 += numer / (denom * denom);
-    }
-    const auto pow_x_alpha = std::pow(x, alpha);
-    const auto gamma_pdf = std::pow(x, alpha - 1) * std::exp(-x);
-    const auto gamma_cdf = pow_x_alpha * series1;
-    const auto gamma_cdf_alpha = (std::log(x) - digamma_one(alpha)) * gamma_cdf
-        - pow_x_alpha * series2;
-    const auto result = -gamma_cdf_alpha / gamma_pdf;
-    return std::isnan(result) ? 0 : result;
-  }
-
-  // Use a Rice saddle point expansion for large alpha.
-  if (alpha > 8.0f) {
-    if (0.9f * alpha <= x && x <= 1.1f * alpha) {
-      const auto numer_1 = 1 + 24 * alpha * (1 + 12 * alpha);
-      const auto numer_2 = 1440 * (alpha * alpha) + 6 * x * (53 - 120 * x)
-          - 65 * x * x / alpha + alpha * (107 + 3600 * x);
-      const auto denom = 1244160 * (alpha * alpha) * (alpha * alpha);
-      return numer_1 * numer_2 / denom;
-    }
-    const auto denom = std::sqrt(8 * alpha);
-    const auto term2 = denom / (alpha - x);
-    const auto term3 = std::pow(x - alpha - alpha * std::log(x / alpha), -1.5f);
-    const auto term23 = (x < alpha) ? term2 - term3 : term2 + term3;
-    const auto term1 = std::log(x / alpha) * term23
-        - std::sqrt(2 / alpha) * (alpha + x) / ((alpha - x) * (alpha - x));
-    const auto stirling = 1 + 1 / (12 * alpha) * (1 + 1 / (24 * alpha));
-    const auto numer = x * term1;
-    return -stirling * numer / denom;
-  }
-
-  // Use a bivariate rational approximation to the reparameterized gradient.
-  const auto u = std::log(x / alpha);
-  const auto v = std::log(alpha);
-  static const scalar_t coef_uv[3][8] = {
-    {0.16009398, -0.094634809, 0.025146376, -0.0030648343,
-     1, 0.32668115, 0.10406089, 0.0014179084},
-    {0.53487893, 0.1298071, 0.065735949, -0.0015649758,
-     0.16639465, 0.020070113, -0.0035938915, -0.00058392623},
-    {0.040121004, -0.0065914022, -0.0026286047, -0.0013441777,
-     0.017050642, -0.0021309326, 0.00085092367, -1.5247877e-07},
-  };
-  scalar_t coef_v[8];
-  for (int i = 0; i < 8; ++ i) {
-    coef_v[i] = coef_uv[0][i] + u * (coef_uv[1][i] + u * coef_uv[2][i]);
-  }
-  const auto p = coef_v[0] + v * (coef_v[1] + v * (coef_v[2] + v * coef_v[3]));
-  const auto q = coef_v[4] + v * (coef_v[5] + v * (coef_v[6] + v * coef_v[7]));
-  return std::exp(p / q);
-}
 } // namespace

 namespace at {
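For reference, the quantity computed by the removed standard_gamma_grad_one (and by the templated standard_gamma_grad_one<scalar_t, double> that the next hunk calls, expected to come from the newly included ATen/native/Distributions.h) is the pathwise derivative stated in the removed comment: for x drawn from Gamma(alpha) with CDF F and density f,

\frac{dx}{d\alpha} \;=\; -\,\frac{\partial F(x;\alpha)/\partial\alpha}{f(x;\alpha)},
\qquad f(x;\alpha) = \frac{x^{\alpha-1} e^{-x}}{\Gamma(\alpha)} .

The removed body evaluates this with a Taylor series for x < 0.8, a Rice saddle-point expansion for alpha > 8, and a bivariate rational approximation in (log(x/alpha), log(alpha)) otherwise.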
@@ -198,28 +123,51 @@ Tensor _standard_gamma_grad_cpu(const Tensor& self, const Tensor& output) {
   AT_DISPATCH_FLOATING_TYPES(self.type(), "_standard_gamma_grad", [&] {
     CPU_tensor_apply3<scalar_t, scalar_t, scalar_t>(ret, self, output,
       [](scalar_t& ret_val, const scalar_t& self_val, const scalar_t &output_val) {
-        ret_val = standard_gamma_grad_one(self_val, output_val);
+        ret_val = standard_gamma_grad_one<scalar_t, double>(self_val, output_val);
       }
     );
   });
   return ret;
 }

-Tensor _standard_gamma_grad_cuda(const Tensor& self, const Tensor& output) {
-  AT_ERROR("_standard_gamma_grad is not implemented for CUDA types");
-}
+/*
+ * This section is a counterpart to Distributions.cu
+ */

 Tensor _s_poisson_cpu(const Tensor& lambda, Generator *gen) {
   Tensor ret = at::zeros(lambda.type(), lambda.sizes());
-  auto lambda_ = lambda.toType(ScalarType::Double);
   AT_DISPATCH_FLOATING_TYPES(ret.type(), "poisson", [&] {
     THGenerator* generator = get_generator(gen);
-    CPU_tensor_apply2<scalar_t, double>(ret, lambda_,
-      [generator](scalar_t& ret_val, const double& lambda){
-        ret_val = sample_poisson(lambda, generator);
+    std::lock_guard<std::mutex> lock(generator->mutex);
+    CPU_tensor_apply2<scalar_t, scalar_t>(ret, lambda,
+      [generator](scalar_t& ret_val, const scalar_t& lambda){
+        ret_val = static_cast<scalar_t>(sample_poisson(static_cast<double>(lambda), generator));
       }
     );
-  });
+  });
+  return ret;
+}
+
+Tensor _s_gamma_cpu(const Tensor& alpha, Generator *gen) {
+  Tensor ret = alpha.type().zeros(alpha.sizes());
+  AT_DISPATCH_FLOATING_TYPES(ret.type(), "gamma", [&] {
+    THGenerator* generator = get_generator(gen);
+    std::lock_guard<std::mutex> lock(generator->mutex);
+    CPU_tensor_apply2<scalar_t, scalar_t>(ret, alpha,
+      [generator](scalar_t& ret_val, const scalar_t& alpha){
+        BaseSampler<double> standard_uniform([generator] () {
+          return THRandom_standard_uniform(generator);
+        });
+        BaseSampler<double> standard_normal([generator] () {
+          return THRandom_normal(generator, 0.0, 1.0);
+        });
+        auto sample = sample_gamma<scalar_t, double>(alpha, standard_uniform, standard_normal);
+        ret_val = std::max(std::numeric_limits<scalar_t>::min(), (scalar_t) sample);
+      }
+    );
+  });
+
   return ret;
 }
+
 }} // namespace at::native
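The new _s_gamma_cpu delegates the actual draw to sample_gamma<scalar_t, double>, fed by the standard_uniform and standard_normal BaseSamplers; that template comes from the newly included ATen/native/Distributions.h and is not part of this hunk. As background only, the sketch below shows the Marsaglia-Tsang squeeze method, a standard way to build a Gamma(alpha, 1) sampler from exactly those two primitives. It is a self-contained illustration of the technique, not the code this commit adds.

// Standalone sketch of Marsaglia-Tsang Gamma(alpha, 1) sampling built from a
// uniform and a normal source, analogous in spirit to the BaseSamplers above.
#include <cmath>
#include <iostream>
#include <limits>
#include <random>

double sample_gamma_sketch(double alpha, std::mt19937& rng) {
  std::uniform_real_distribution<double> uniform(0.0, 1.0);
  std::normal_distribution<double> normal(0.0, 1.0);

  // For alpha < 1, sample Gamma(alpha + 1) and scale by U^(1/alpha).
  double boost = 1.0;
  if (alpha < 1.0) {
    boost = std::pow(uniform(rng), 1.0 / alpha);
    alpha += 1.0;
  }

  const double d = alpha - 1.0 / 3.0;
  const double c = 1.0 / std::sqrt(9.0 * d);
  while (true) {
    double x, v;
    do {
      x = normal(rng);
      v = 1.0 + c * x;
    } while (v <= 0.0);
    v = v * v * v;
    const double u = uniform(rng);
    const double xx = x * x;
    // Cheap squeeze test first, exact log-based acceptance test second.
    if (u < 1.0 - 0.0331 * xx * xx ||
        std::log(u) < 0.5 * xx + d * (1.0 - v + std::log(v))) {
      // Clamp away exact zeros, mirroring the std::max in _s_gamma_cpu.
      return std::max(boost * d * v, std::numeric_limits<double>::min());
    }
  }
}

int main() {
  std::mt19937 rng(0);
  double sum = 0.0;
  for (int i = 0; i < 10000; ++i) sum += sample_gamma_sketch(2.5, rng);
  std::cout << "empirical mean (expected about 2.5): " << sum / 10000 << "\n";
  return 0;
}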
