implement gamma cuda #6855

Conversation
apaszke left a comment
LGTM. Two minor nits. Did you need to pull in any new implementations of those math functions, or are they just coming from other places in our codebase?
return curand_normal(&state);
});
auto sample = sample_gamma<float>(alpha, standard_uniform, standard_normal);
ret_val = ::max(THCNumerics<scalar_t>::min(), scalar_cast<scalar_t>(sample));
return THRandom_normal(generator, 0.0, 1.0);
});
auto sample = sample_gamma<double>(alpha, standard_uniform, standard_normal);
ret_val = std::max(std::numeric_limits<scalar_t>::min(), (scalar_t) sample);
Regarding new functions: No. I just moved stuff around from aten/src/TH and aten/src/ATen/native/Distributions.cpp. The copyright notice was from Rachit's patch, but the code below it actually was in Distributions.cpp before.
fritzo left a comment
Thanks for implementing this!
}

// Use a Rice saddle point expansion for large alpha.
if (alpha > 8.0f) {
aten/src/ATen/native/Distributions.h (Outdated)
// Computes the reparameterized gradient -(d/dalpha cdf(x;alpha)) / pdf(x;alpha)
// for random number x drawn from a standard Gamma distribution Gamma(alpha).
template <typename scalar_t>
deviceforcuda scalar_t standard_gamma_grad_one(scalar_t alpha, scalar_t x) {
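A side note in case the formula in that comment looks opaque: it follows from implicitly differentiating the inverse-CDF reparameterization. Holding the underlying uniform draw fixed, F(x(\alpha); \alpha) = u is constant, so

\frac{\partial F}{\partial x}\,\frac{dx}{d\alpha} + \frac{\partial F}{\partial \alpha} = 0
\quad\Longrightarrow\quad
\frac{dx}{d\alpha} = -\,\frac{\partial F(x;\alpha)/\partial \alpha}{p(x;\alpha)},

since the density is p(x;\alpha) = \partial F/\partial x. That ratio is exactly the -(d/dalpha cdf)/pdf quantity the function computes.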
aten/src/ATen/native/Distributions.h (Outdated)
// Boost alpha for higher acceptance probability.
if (alpha < 1.0) {
  scale *= std::pow(1 - standard_uniform.sample(), static_cast<scalar_t>(1.0) / alpha);
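In case the boosting comment is unclear: for alpha < 1 one can draw from Gamma(alpha + 1) and multiply by U^(1/alpha) for an independent uniform U, which yields a Gamma(alpha) sample. A minimal host-side sketch of that identity (names are illustrative, not the PR's code):

#include <cmath>
#include <random>

// Sketch of the boosting identity: if Y ~ Gamma(alpha + 1) and U ~ Uniform(0, 1),
// then Y * U^(1/alpha) ~ Gamma(alpha). Only valid for 0 < alpha < 1.
double sample_small_alpha_gamma(double alpha, std::mt19937& rng) {
  std::gamma_distribution<double> boosted(alpha + 1.0, 1.0);
  std::uniform_real_distribution<double> uniform(0.0, 1.0);
  return boosted(rng) * std::pow(1.0 - uniform(rng), 1.0 / alpha);
}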
It looks good to me! I too am not terribly happy with the
aten/src/ATen/native/Distributions.h (Outdated)
// This implements the acceptance-rejection method of Marsaglia and Tsang (2000)
// doi:10.1145/358407.358414
const scalar_t d = alpha - 1.0 / 3.0;
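For readers without the paper to hand, the Marsaglia–Tsang method cited in that comment sets d = alpha − 1/3 and c = 1/sqrt(9d), draws a standard normal x and a uniform u, and accepts d·(1 + c·x)³ after a cheap squeeze test. A minimal host-side sketch under those assumptions (illustrative, not the PR's sample_gamma):

#include <cmath>
#include <random>

// Sketch of Marsaglia & Tsang (2000) for alpha >= 1; returns a Gamma(alpha, 1) sample.
double marsaglia_tsang_gamma(double alpha, std::mt19937& rng) {
  std::normal_distribution<double> normal(0.0, 1.0);
  std::uniform_real_distribution<double> uniform(0.0, 1.0);
  const double d = alpha - 1.0 / 3.0;
  const double c = 1.0 / std::sqrt(9.0 * d);
  while (true) {
    double x, v;
    do {
      x = normal(rng);
      v = 1.0 + c * x;
    } while (v <= 0.0);
    v = v * v * v;
    const double u = uniform(rng);
    // Cheap squeeze test first, exact log test as the fallback.
    if (u < 1.0 - 0.0331 * x * x * x * x) return d * v;
    if (std::log(u) < 0.5 * x * x + d * (1.0 - v + std::log(v))) return d * v;
  }
}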
aten/src/ATen/native/Distributions.h (Outdated)
*/
template <typename scalar_t>
deviceforcuda static inline scalar_t digamma_one(scalar_t x) {
  using acc_scalar_t = typename std::conditional<std::is_same<scalar_t, at::Half>::value, float, scalar_t>::type;
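The std::conditional alias above is the accumulation-type trick mentioned in the commit log: for at::Half the arithmetic runs in float and only the result is narrowed back. A self-contained sketch of the same pattern, with MyHalf as an illustrative stand-in for a 16-bit float type (not ATen's at::Half):

#include <type_traits>

// Stand-in for a half-precision type; at::Half plays this role in the PR.
struct MyHalf {
  float value;
  MyHalf(float v = 0.f) : value(v) {}
  operator float() const { return value; }
};

template <typename scalar_t>
scalar_t square_plus_one(scalar_t x) {
  // Pick an accumulation type: float for the half stand-in, scalar_t otherwise.
  using acc_t = typename std::conditional<
      std::is_same<scalar_t, MyHalf>::value, float, scalar_t>::type;
  acc_t v = static_cast<acc_t>(x);
  return static_cast<scalar_t>(v * v + acc_t(1));  // narrow back only at the end
}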
aten/src/ATen/native/Distributions.h (Outdated)
// Computes the reparameterized gradient -(d/dalpha cdf(x;alpha)) / pdf(x;alpha)
// for random number x drawn from a standard Gamma distribution Gamma(alpha).
template <typename scalar_t>
deviceforcuda scalar_t standard_gamma_grad_one(scalar_t alpha, scalar_t x) {
Tensor _s_gamma_cuda(const Tensor& alpha, Generator* gen) {
  Tensor ret = alpha.type().tensor(alpha.sizes());
  auto alpha_ = alpha.toType(ScalarType::Float);
  AT_DISPATCH_FLOATING_TYPES(ret.type(), "gamma", [&] {
@pytorchbot retest this please
So while looking into adding an increment parameter to next_philox_seed.
@t-vi but that's what is done now, atomicAdd returns the "old" value that is used by a thread. It is also thread-safe (i.e. two threads are guaranteed to get two different "old" values). Am I missing something?
Ah yes. I was confused. Thank you!
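To spell out the pattern discussed in the exchange above: atomicAdd (or its host-side analogue fetch_add) adds an increment to a shared counter and returns the counter's previous value, so each caller reserves a disjoint range of offsets. A minimal host-side sketch under that reading (offset_counter and reserve_offset are illustrative names, not the PR's next_philox_seed):

#include <atomic>
#include <cstdint>

std::atomic<uint64_t> offset_counter{0};  // shared per-generator offset counter

// Reserve `increment` states and return the offset at which this call may start.
// fetch_add returns the *old* value, so concurrent callers get disjoint ranges.
uint64_t reserve_offset(uint64_t increment) {
  return offset_counter.fetch_add(increment);
}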
…, cast locally rather than tensors
apaszke left a comment
LGTM. Needs some final fixes and should be good to merge.
- AT_DISPATCH_FLOATING_TYPES(ret.type(), "poisson", [&] {
-   poisson_cuda_kernel<scalar_t>(ret, lambda_, next_philox_seed(gen));
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(ret.type(), "poisson", [&] {
+   poisson_cuda_kernel<cuda::type<scalar_t>>(ret, lambda, next_philox_seed(gen, 20));
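Conceptually, the dispatch macro in this hunk inspects the tensor's scalar type at runtime and instantiates its body with the matching C++ type as scalar_t. A hand-rolled sketch of that dispatch idea (Dtype, FillOne, and launch_for_dtype are illustrative names, not ATen's macro):

#include <stdexcept>

enum class Dtype { Float, Double, Half };

// Example "kernel" templated on the scalar type; run() stands in for real work.
template <typename scalar_t>
struct FillOne {
  static void run() { /* launch a kernel instantiated for scalar_t here */ }
};

// Hand-rolled analogue of a dtype dispatch: map the runtime dtype to a template instantiation.
template <template <typename> class Kernel>
void launch_for_dtype(Dtype dtype) {
  switch (dtype) {
    case Dtype::Float:  Kernel<float>::run();  break;
    case Dtype::Double: Kernel<double>::run(); break;
    case Dtype::Half:   Kernel<float>::run();  break;  // a real dispatcher passes a half type; float keeps the sketch simple
    default: throw std::runtime_error("unsupported dtype");
  }
}

// usage: launch_for_dtype<FillOne>(Dtype::Half);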
BaseSampler<scalar_t> standard_normal([generator] () {
  return THRandom_normal(generator, 0.0, 1.0);
});
auto sample = sample_gamma<scalar_t, scalar_t>(alpha, standard_uniform, standard_normal);
void THTensor_(standard_gamma)(THTensor *self, THGenerator *_generator, THTensor *alpha)
{
  std::lock_guard<std::mutex> lock(_generator->mutex);
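The lock in that hunk serializes access to the shared CPU generator for the duration of the sampling loop, matching the "lock CPU generator" item in the commit log. A minimal sketch of the same guard pattern (Rng is an illustrative stand-in for THGenerator):

#include <mutex>
#include <random>

struct Rng {
  std::mutex mutex;     // guards the engine below, as THGenerator's mutex guards its state
  std::mt19937 engine;
};

double draw_uniform(Rng& rng) {
  std::lock_guard<std::mutex> lock(rng.mutex);  // released automatically on return
  return std::uniform_real_distribution<double>(0.0, 1.0)(rng.engine);
}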
#if (defined(__CUDACC_VER_MAJOR__) && (__CUDACC_MAJOR_VER < 9))
template<typename R, typename T>
deviceforcuda R cast_wrapper(T v) { return scalar_cast<R>(v); }
#else
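The idea in that hunk is to pick the cast implementation based on the nvcc version. For what it's worth, the macro nvcc actually defines is __CUDACC_VER_MAJOR__, so a self-contained sketch of the pattern would test that same name in both places (my_cast is an illustrative name, not the PR's scalar_cast):

// Sketch: version-gated cast helper. Old nvcc (< 9) could get a special path here;
// other compilers fall through to a plain static_cast.
#if defined(__CUDACC_VER_MAJOR__) && (__CUDACC_VER_MAJOR__ < 9)
template <typename R, typename T>
__host__ __device__ R my_cast(T v) { return static_cast<R>(v); }
#else
template <typename R, typename T>
R my_cast(T v) { return static_cast<R>(v); }
#endif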
… double
Thank you for your review comments!
* Refactor standard_gamma and implement CUDA gamma sampling
* Attempt fixes for AT_CUDA_ENABLED changes
* Gamma cuda and cpu forward as ATen native
* implement standard_gamma_grad_cuda
* update native_test.cpp, try to fix windows and various cuda version compiles
* searching a windows fix via CI... use std:: for math
* casting some constants in the calculation, compute at float for half precision
* whitespace fixes
* add acctype to do half->float computation, include HALF in generation, cast locally rather than tensors
* fix cuda8 half compilation
* always use scalar_cast with CUDACC, lock CPU generator, CPU acctype = double
Thank you for your review comments!
Thank you!
Things I could use feedback on (in addition to anything else you spot that needs improvement):