
Conversation

@t-vi (Collaborator) commented Apr 23, 2018

Things I could use feedback on (in addition to any other areas for improvement you spot):

  • I accidentally moved digamma to ATen native before noticing that it was already available for CUDA and CPU in TH/THC. Should I expose the new digamma and drop the old, or should I back out the new digamma and use TH/THC?
  • My impression is that the digamma half implementation previously used float for intermediate results. I do not do this yet. Should I? (See the sketch after this list.)
  • One of the poisson tests seems to have started failing, but I'm not entirely sure what I changed to cause that.
  • I'm not super happy about the #ifdefs I needed to get the code to work in Distributions.h on both CUDA and CPU.
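A minimal sketch of the accumulate-at-float pattern in question (the function name compute_in_acc_type and the toy formula are placeholders, not code from this PR):

#include <type_traits>

#include <ATen/ATen.h>

// Sketch only: compute in float when scalar_t is at::Half and cast back to
// scalar_t just once for the final result.
template <typename scalar_t>
scalar_t compute_in_acc_type(scalar_t x) {
  using acc_t = typename std::conditional<
      std::is_same<scalar_t, at::Half>::value, float, scalar_t>::type;
  acc_t xa = static_cast<acc_t>(x);
  acc_t result = xa * xa / (xa + static_cast<acc_t>(1));  // stand-in for the real series
  return static_cast<scalar_t>(result);
}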

@apaszke (Contributor) left a comment

LGTM. Two minor nits. Did you need to pull in any new implementations of those math functions, or are they just coming from other places in our codebase?

return curand_normal(&state);
});
auto sample = sample_gamma<float>(alpha, standard_uniform, standard_normal);
ret_val = ::max(THCNumerics<scalar_t>::min(), scalar_cast<scalar_t>(sample));

return THRandom_normal(generator, 0.0, 1.0);
});
auto sample = sample_gamma<double>(alpha, standard_uniform, standard_normal);
ret_val = std::max(std::numeric_limits<scalar_t>::min(), (scalar_t) sample);
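One detail worth noting here: std::numeric_limits<scalar_t>::min() is the smallest positive normal value, not the most negative one, so this clamp keeps the returned sample strictly positive. A tiny self-contained check of that (illustrative only):

#include <algorithm>
#include <cassert>
#include <limits>

int main() {
  // numeric_limits<double>::min() is ~2.2e-308, i.e. positive, so the clamp
  // pushes a zero sample up to the smallest positive normal value.
  double sample = 0.0;
  double clamped = std::max(std::numeric_limits<double>::min(), sample);
  assert(clamped > 0.0);
  return 0;
}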

@t-vi (Collaborator, Author) commented Apr 23, 2018

Regarding new functions: no, I just moved things around from aten/src/TH and aten/src/ATen/native/Distributions.cpp. The copyright notice came from Rachit's patch, but the code below it was actually in Distributions.cpp before.

@fritzo (Collaborator) left a comment

Thanks for implementing this!

}

// Use a Rice saddle point expansion for large alpha.
if (alpha > 8.0f) {

// Computes the reparameterized gradient -(d/dalpha cdf(x;alpha)) / pdf(x;alpha)
// for random number x drawn from a standard Gamma distribution Gamma(alpha).
template <typename scalar_t>
deviceforcuda scalar_t standard_gamma_grad_one(scalar_t alpha, scalar_t x) {
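For context (a gloss on the comment above, not part of the review thread): the formula follows from implicitly differentiating the CDF. Holding the underlying uniform draw u fixed and writing x = F^{-1}(u; alpha), differentiating F(x; alpha) = u with respect to alpha gives dF/dalpha + p(x; alpha) * dx/dalpha = 0, hence dx/dalpha = -(dF/dalpha) / p(x; alpha), which is exactly the quantity standard_gamma_grad_one is documented to return.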

// Boost alpha for higher acceptance probability.
if (alpha < 1.0) {
scale *= std::pow(1 - standard_uniform.sample(), static_cast<scalar_t>(1.0) / alpha);
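The boost relies on the identity that if X ~ Gamma(alpha + 1) and U ~ Uniform(0, 1) are independent, then X * U^(1/alpha) ~ Gamma(alpha). A minimal sketch of that rescaling (sample_small_alpha and gamma_at_least_one are placeholder names, not the PR's code):

#include <cmath>
#include <random>

// Sketch: sample Gamma(alpha) for alpha < 1 by drawing from Gamma(alpha + 1)
// and rescaling by U^(1/alpha). gamma_at_least_one stands in for the
// Marsaglia-Tsang sampler sketched further below.
template <typename Sampler>
double sample_small_alpha(double alpha, std::mt19937& rng, Sampler gamma_at_least_one) {
  std::uniform_real_distribution<double> unif(0.0, 1.0);
  double boosted = gamma_at_least_one(alpha + 1.0, rng);
  // unif(rng) is in [0, 1), so 1 - unif(rng) stays strictly positive, which is
  // why the hunk above uses 1 - standard_uniform.sample() as well.
  return boosted * std::pow(1.0 - unif(rng), 1.0 / alpha);
}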

@rachtsingh (Contributor) commented:

It looks good to me! I too am not terribly happy with the #ifdef situation but it does work - thanks for following up on Adam's suggestion and getting this finished. The build failure looks spurious?


// This implements the acceptance-rejection method of Marsaglia and Tsang (2000)
// doi:10.1145/358407.358414
const scalar_t d = alpha - 1.0 / 3.0;
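For readers without the paper at hand, the core of Marsaglia and Tsang (2000) squeezes a Gamma(alpha, 1) sample (alpha >= 1) out of one normal and one uniform draw per iteration. A self-contained sketch of that loop in plain C++ (not the templated BaseSampler code in this PR):

#include <cmath>
#include <random>

// Marsaglia & Tsang (2000) acceptance-rejection sampler for Gamma(alpha, 1), alpha >= 1.
double marsaglia_tsang_gamma(double alpha, std::mt19937& rng) {
  std::normal_distribution<double> normal(0.0, 1.0);
  std::uniform_real_distribution<double> unif(0.0, 1.0);
  const double d = alpha - 1.0 / 3.0;
  const double c = 1.0 / std::sqrt(9.0 * d);
  while (true) {
    double x, v;
    do {
      x = normal(rng);
      v = 1.0 + c * x;
    } while (v <= 0.0);
    v = v * v * v;
    const double u = unif(rng);
    // Cheap squeeze test first, then the exact log acceptance test.
    if (u < 1.0 - 0.0331 * x * x * x * x) return d * v;
    if (std::log(u) < 0.5 * x * x + d * (1.0 - v + std::log(v))) return d * v;
  }
}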

*/
template <typename scalar_t>
deviceforcuda static inline scalar_t digamma_one(scalar_t x) {
using acc_scalar_t = typename std::conditional<std::is_same<scalar_t, at::Half>::value, float, scalar_t>::type;

Tensor _s_gamma_cuda(const Tensor& alpha, Generator* gen) {
Tensor ret = alpha.type().tensor(alpha.sizes());
auto alpha_ = alpha.toType(ScalarType::Float);
AT_DISPATCH_FLOATING_TYPES(ret.type(), "gamma", [&] {

@ezyang (Contributor) commented Apr 24, 2018

@pytorchbot retest this please

@t-vi (Collaborator, Author) commented Apr 24, 2018

While looking into adding an increment parameter to next_philox_seed, I started thinking about thread safety: I think there is a race condition where two threads might grab the same state and then each do an atomic add (in other words, it would be correct to have the two assignments in next_philox_seed happen as one atomic operation). Is that rare enough to neglect, given the overhead of having to lock?

@ngimel (Collaborator) commented Apr 24, 2018

@t-vi but that's what is done now: atomicAdd returns the "old" value, which is what the thread then uses. It is also thread-safe (i.e., two threads are guaranteed to get two different "old" values). Am I missing something?
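To make the guarantee concrete, here is a minimal host-side sketch of the fetch-and-add pattern being discussed (illustrative only; the real generator state and next_philox_seed live elsewhere in the codebase):

#include <atomic>
#include <cstdint>
#include <utility>

// Each caller atomically bumps a shared offset and keeps the value from
// *before* the bump. fetch_add is a single read-modify-write, so no two
// threads can ever observe the same "old" offset.
std::pair<uint64_t, uint64_t> next_philox_seed_sketch(
    std::atomic<uint64_t>& offset, uint64_t seed, uint64_t increment) {
  const uint64_t old_offset = offset.fetch_add(increment);
  return {seed, old_offset};
}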

@t-vi (Collaborator, Author) commented Apr 24, 2018

Ah yes. I was confused. Thank you!

@apaszke (Contributor) left a comment

LGTM. Needs some final fixes and should be good to merge.

AT_DISPATCH_FLOATING_TYPES(ret.type(), "poisson", [&] {
poisson_cuda_kernel<scalar_t>(ret, lambda_, next_philox_seed(gen));
AT_DISPATCH_FLOATING_TYPES_AND_HALF(ret.type(), "poisson", [&] {
poisson_cuda_kernel<cuda::type<scalar_t>>(ret, lambda, next_philox_seed(gen, 20));

BaseSampler<scalar_t> standard_normal([generator] () {
return THRandom_normal(generator, 0.0, 1.0);
});
auto sample = sample_gamma<scalar_t, scalar_t>(alpha, standard_uniform, standard_normal);


void THTensor_(standard_gamma)(THTensor *self, THGenerator *_generator, THTensor *alpha)
{
std::lock_guard<std::mutex> lock(_generator->mutex);

#if (defined(__CUDACC_VER_MAJOR__) && (__CUDACC_VER_MAJOR__ < 9))
template<typename R, typename T>
deviceforcuda R cast_wrapper(T v) { return scalar_cast<R>(v); }
#else

@ezyang merged commit c10da63 into pytorch:master Apr 26, 2018
Jorghi12 pushed a commit to wsttiger/pytorch that referenced this pull request May 10, 2018
* Refactor standard_gamma and implement CUDA gamma sampling

* Attempt fixes for AT_CUDA_ENABLED changes

* Gamma cuda and cpu forward as ATen native

* implement standard_gamma_grad_cuda

* update native_test.cpp, try to fix windows and various cuda version compiles

* searching a windows fix via CI... use std:: for math

* casting some constants in the calculation, compute at float for half precision

* whitespace fixes

* add acctype to do half->float computation, include HALF in generation, cast locally rather than tensors

* fix cuda8 half compilation

* always use scalar_cast with CUDACC, lock CPU generator, CPU acctype = double

Thank you for your review comments!
weiyangfb pushed a commit to weiyangfb/pytorch that referenced this pull request Jun 11, 2018