#include "ATen/CPUGenerator.h"
#include "ATen/CheckGenerator.h"
#include "ATen/Generator.h"
#include "ATen/native/Distributions.h"

#include <functional>
#include <mutex>

#include "TH/THGenerator.h"
#include "TH/THMath.h"
#include "TH/THRandom.h"
1418
1519namespace {
@@ -100,85 +104,6 @@ int64_t sample_poisson(double lambda, THGenerator* generator) {
100104 }
101105}
102106
// Scalar digamma (psi) helper used by standard_gamma_grad_one below.
// The unspecialized template is a compile-time catch-all that fails at
// runtime; only the double/float specializations are meant to be used.
template <typename scalar>
static inline scalar digamma_one(scalar x) {
  throw std::runtime_error("digamma is only implemented for float, double");
}

// Double precision: delegate to TH's digamma.
template <>
inline double digamma_one<double>(double x) {
  return TH_digamma(x);
}

// Single precision: delegate to TH's float digamma.
template <>
inline float digamma_one<float>(float x) {
  return TH_digammaf(x);
}
117-
// Computes the reparameterized gradient -(d/dalpha cdf(x;alpha)) / pdf(x;alpha)
// for random number x drawn from a standard Gamma distribution Gamma(alpha).
// Three regimes are stitched together: a Taylor series for small x, a
// saddle-point expansion for large alpha, and a fitted rational
// approximation everywhere else.
template <typename scalar_t>
scalar_t standard_gamma_grad_one(scalar_t alpha, scalar_t x) {
  // Use a Taylor series expansion for small x.
  if (x < 0.8f) {
    // Accumulate 6 terms of the series: numer walks (-x)^i / i!,
    // denom walks alpha + i.
    scalar_t numer = 1;
    scalar_t denom = alpha;
    auto series1 = numer / denom;                 // builds the CDF series
    auto series2 = numer / (denom * denom);       // its d/dalpha companion
    for (int i = 1; i <= 5; ++i) {
      numer *= -x / i;
      denom += 1;
      series1 += numer / denom;
      series2 += numer / (denom * denom);
    }
    const auto pow_x_alpha = std::pow(x, alpha);
    const auto gamma_pdf = std::pow(x, alpha - 1) * std::exp(-x);
    const auto gamma_cdf = pow_x_alpha * series1;
    const auto gamma_cdf_alpha = (std::log(x) - digamma_one(alpha)) * gamma_cdf
        - pow_x_alpha * series2;
    const auto result = -gamma_cdf_alpha / gamma_pdf;
    // Guard against 0/0 at the extremes; return a zero gradient instead.
    return std::isnan(result) ? 0 : result;
  }

  // Use a Rice saddle point expansion for large alpha.
  if (alpha > 8.0f) {
    // Near the mode (x within 10% of alpha) the general expansion is
    // ill-conditioned; use a dedicated polynomial ratio there.
    if (0.9f * alpha <= x && x <= 1.1f * alpha) {
      const auto numer_1 = 1 + 24 * alpha * (1 + 12 * alpha);
      const auto numer_2 = 1440 * (alpha * alpha) + 6 * x * (53 - 120 * x)
          - 65 * x * x / alpha + alpha * (107 + 3600 * x);
      const auto denom = 1244160 * (alpha * alpha) * (alpha * alpha);
      return numer_1 * numer_2 / denom;
    }
    const auto denom = std::sqrt(8 * alpha);
    const auto term2 = denom / (alpha - x);
    const auto term3 = std::pow(x - alpha - alpha * std::log(x / alpha), -1.5f);
    // Sign of the correction term flips across the mode.
    const auto term23 = (x < alpha) ? term2 - term3 : term2 + term3;
    const auto term1 = std::log(x / alpha) * term23
        - std::sqrt(2 / alpha) * (alpha + x) / ((alpha - x) * (alpha - x));
    const auto stirling = 1 + 1 / (12 * alpha) * (1 + 1 / (24 * alpha));
    const auto numer = x * term1;
    return -stirling * numer / denom;
  }

  // Use a bivariate rational approximation to the reparameterized gradient.
  // Fitted coefficients in (u, v) = (log(x/alpha), log(alpha)); the result
  // is exp(p(u,v) / q(u,v)) with p, q cubic in v.
  const auto u = std::log(x / alpha);
  const auto v = std::log(alpha);
  static const scalar_t coef_uv[3][8] = {
    {0.16009398, -0.094634809, 0.025146376, -0.0030648343,
     1, 0.32668115, 0.10406089, 0.0014179084},
    {0.53487893, 0.1298071, 0.065735949, -0.0015649758,
     0.16639465, 0.020070113, -0.0035938915, -0.00058392623},
    {0.040121004, -0.0065914022, -0.0026286047, -0.0013441777,
     0.017050642, -0.0021309326, 0.00085092367, -1.5247877e-07},
  };
  // Collapse the u-direction first: coef_v[i] is quadratic in u.
  scalar_t coef_v[8];
  for (int i = 0; i < 8; ++i) {
    coef_v[i] = coef_uv[0][i] + u * (coef_uv[1][i] + u * coef_uv[2][i]);
  }
  // Horner evaluation of numerator p and denominator q in v.
  const auto p = coef_v[0] + v * (coef_v[1] + v * (coef_v[2] + v * coef_v[3]));
  const auto q = coef_v[4] + v * (coef_v[5] + v * (coef_v[6] + v * coef_v[7]));
  return std::exp(p / q);
}
182107} // namespace
183108
184109namespace at {
@@ -198,28 +123,51 @@ Tensor _standard_gamma_grad_cpu(const Tensor& self, const Tensor& output) {
198123 AT_DISPATCH_FLOATING_TYPES (self.type (), " _standard_gamma_grad" , [&] {
199124 CPU_tensor_apply3<scalar_t , scalar_t , scalar_t >(ret, self, output,
200125 [](scalar_t & ret_val, const scalar_t & self_val, const scalar_t &output_val) {
201- ret_val = standard_gamma_grad_one (self_val, output_val);
126+ ret_val = standard_gamma_grad_one< scalar_t , double > (self_val, output_val);
202127 }
203128 );
204129 });
205130 return ret;
206131}
207132
// CUDA stub: the standard-gamma reparameterized gradient has no CUDA
// implementation here, so this always raises. Kept so dispatch resolves.
Tensor _standard_gamma_grad_cuda(const Tensor& self, const Tensor& output) {
  AT_ERROR("_standard_gamma_grad is not implemented for CUDA types");
}
/*
 * This section is the CPU counterpart to Distributions.cu.
 */
211136
212137Tensor _s_poisson_cpu (const Tensor& lambda, Generator *gen) {
213138 Tensor ret = at::zeros (lambda.type (), lambda.sizes ());
214- auto lambda_ = lambda.toType (ScalarType::Double);
215139 AT_DISPATCH_FLOATING_TYPES (ret.type (), " poisson" , [&] {
216140 THGenerator* generator = get_generator (gen);
217- CPU_tensor_apply2<scalar_t , double >(ret, lambda_,
218- [generator](scalar_t & ret_val, const double & lambda){
219- ret_val = sample_poisson (lambda, generator);
141+ std::lock_guard<std::mutex> lock (generator->mutex );
142+ CPU_tensor_apply2<scalar_t , scalar_t >(ret, lambda,
143+ [generator](scalar_t & ret_val, const scalar_t & lambda){
144+ ret_val = static_cast <scalar_t >(sample_poisson (static_cast <double >(lambda), generator));
220145 }
221146 );
222- });
147+ });
148+ return ret;
149+ }
150+
151+ Tensor _s_gamma_cpu (const Tensor& alpha, Generator *gen) {
152+ Tensor ret = alpha.type ().zeros (alpha.sizes ());
153+ AT_DISPATCH_FLOATING_TYPES (ret.type (), " gamma" , [&] {
154+ THGenerator* generator = get_generator (gen);
155+ std::lock_guard<std::mutex> lock (generator->mutex );
156+ CPU_tensor_apply2<scalar_t , scalar_t >(ret, alpha,
157+ [generator](scalar_t & ret_val, const scalar_t & alpha){
158+ BaseSampler<double > standard_uniform ([generator] () {
159+ return THRandom_standard_uniform (generator);
160+ });
161+ BaseSampler<double > standard_normal ([generator] () {
162+ return THRandom_normal (generator, 0.0 , 1.0 );
163+ });
164+ auto sample = sample_gamma<scalar_t , double >(alpha, standard_uniform, standard_normal);
165+ ret_val = std::max (std::numeric_limits<scalar_t >::min (), (scalar_t ) sample);
166+ }
167+ );
168+ });
169+
223170 return ret;
224171}
172+
225173}} // namespace at::native
0 commit comments