14 | 14 | #include <ATen/native/Distributions.h> |
15 | 15 | #include <ATen/native/cuda/Loops.cuh> |
16 | 16 | #include <ATen/native/TensorIterator.h> |
| 17 | +#include <ATen/LegacyTHFunctionsCUDA.h> |
17 | 18 |
18 | 19 | #include <THC/THCGeneral.h> |
19 | 20 | #include <THC/THCTensorRandom.h> |
@@ -120,6 +121,22 @@ __global__ void distribution_elementwise_grid_stride_kernel(int numel, |
120 | 121 | } |
121 | 122 | } |
122 | 123 |
| 124 | +/** |
| 125 | + * distribution_nullary_kernel is analogous to gpu_nullary_kernel in |
| 126 | + * ATen/native/cuda/Loops.cuh. Like gpu_nullary_kernel, it uses |
| 127 | + * TensorIterator to launch a kernel. However, the differences are |
| 128 | + * - it launches a grid-stride loop based kernel. The kernel is not |
| 129 | + * generic like elementwise_kernel in Loops.cuh and is specialized |
| 130 | + * for the distribution kernels here. |
| 131 | + * - For large tensors, we may launch multiple kernels recursively |
| 132 | + * (i.e. when !iter.can_use_32bit_indexing()), and hence the philox |
| 133 | + * offset calculation is done in this function. |
| 134 | + * |
| 135 | + * FIXME: Can we specialize elementwise_kernel and launch_kernel in Loops.cuh |
| 136 | + * to have a grid-stride loop kernel and then use that to launch our distribution |
| 137 | + * kernels? Note that we need a grid-stride loop kernel because we found, by |
| 138 | + * testing, that it achieves peak effective bandwidth. |
| 139 | + */ |
123 | 140 | template<typename scalar_t, |
124 | 141 | typename accscalar_t, |
125 | 142 | int unroll_factor, |
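
The grid-stride loop the new comment refers to is the standard CUDA idiom sketched below. This is illustrative only, with hypothetical names; it is not the actual distribution_elementwise_grid_stride_kernel in this file.

// Minimal sketch of a grid-stride loop: each thread starts at its global
// index and advances by the total number of launched threads, so a
// fixed-size grid covers a tensor of any length.
template <typename scalar_t, typename func_t>
__global__ void grid_stride_loop_sketch(int numel, scalar_t* out, func_t f) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  int stride = blockDim.x * gridDim.x;
  for (int i = idx; i < numel; i += stride) {
    out[i] = f(i);  // f produces one value per output element
  }
}
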
@@ -475,6 +492,30 @@ void random_kernel_cuda(TensorIterator& iter, uint64_t range, int64_t base, Gene |
475 | 492 | }); |
476 | 493 | } |
477 | 494 |
| 495 | +void normal_kernel_cuda(TensorIterator& iter, double mean_, double std_, Generator* gen_) { |
| 496 | + auto gen = check_generator<CUDAGenerator>(gen_, &globalContext().defaultGenerator(kCUDA)); |
| 497 | + AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "normal_cuda", [&] { |
| 498 | + using accscalar_t = at::acc_type<scalar_t, true>; |
| 499 | + auto mean = static_cast<accscalar_t>(mean_); |
| 500 | + auto std = static_cast<accscalar_t>(std_); |
| 501 | + // define a lambda that multiplies by std and adds mean |
| 502 | + auto normal_func = [mean, std] __device__ (accscalar_t rand) { |
| 503 | + return static_cast<scalar_t>(rand * std + mean); |
| 504 | + }; |
| 505 | + if (std::is_same<scalar_t, double>::value) { |
| 506 | + distribution_nullary_kernel<scalar_t, accscalar_t, curand4_engine_calls/2>(iter, |
| 507 | + gen, |
| 508 | + [] __device__ (curandStatePhilox4_32_10_t* state) { return curand_normal2_double(state); }, |
| 509 | + normal_func); |
| 510 | + } else { |
| 511 | + distribution_nullary_kernel<scalar_t, accscalar_t, curand4_engine_calls>(iter, |
| 512 | + gen, |
| 513 | + [] __device__ (curandStatePhilox4_32_10_t* state) { return curand_normal4(state); }, |
| 514 | + normal_func); |
| 515 | + } |
| 516 | + }); |
| 517 | +} |
| 518 | + |
478 | 519 | Tensor& uniform_cuda_(Tensor& self, double from, double to, Generator* gen) { |
479 | 520 | auto iter = TensorIterator::nullary_op(self); |
480 | 521 | uniform_kernel_cuda(*iter, from, to, gen); |
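
For context on the hunk above: curand_normal4 produces four single-precision normals per Philox call, while curand_normal2_double produces only two doubles, which is why the double branch halves the unroll factor (curand4_engine_calls/2). Below is a hedged standalone sketch of the same pattern with hypothetical names and a fixed per-thread offset, rather than the iterator-driven philox offset calculation used in this file.

#include <curand_kernel.h>

// Hypothetical kernel: per-thread Philox state, four normals per
// curand_normal4 call, then the same rand * std + mean affine transform
// that normal_func applies above.
__global__ void philox_normal_sketch(float* out, int numel,
                                     unsigned long long seed,
                                     unsigned long long offset,
                                     float mean, float std) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  curandStatePhilox4_32_10_t state;
  curand_init(seed, idx, offset, &state);
  int stride = blockDim.x * gridDim.x * 4;
  for (int i = idx * 4; i < numel; i += stride) {
    float4 r = curand_normal4(&state);      // four standard-normal samples
    float vals[4] = {r.x, r.y, r.z, r.w};
    for (int j = 0; j < 4 && i + j < numel; ++j) {
      out[i + j] = vals[j] * std + mean;    // matches normal_func above
    }
  }
}
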
@@ -510,4 +551,48 @@ Tensor& capped_random_cuda_(Tensor& self, int64_t to, Generator* gen) { |
510 | 551 | return clamped_random_cuda_(self, 0, to, gen); |
511 | 552 | } |
512 | 553 |
| 554 | +Tensor& normal_cuda_(Tensor& self, double mean, double std, Generator* gen) { |
| 555 | + TORCH_CHECK(std > 0.0, "normal_ expects std > 0.0, but found std=", std); |
| 556 | + auto iter = TensorIterator::nullary_op(self); |
| 557 | + normal_kernel_cuda(*iter, mean, std, gen); |
| 558 | + return self; |
| 559 | +} |
| 560 | + |
| 561 | +Tensor& normal_out_cuda(Tensor& output, const Tensor& mean, double std, Generator* gen) { |
| 562 | + normal_cuda_(output, 0, std, gen); |
| 563 | + output.add_(mean); |
| 564 | + return output; |
| 565 | +} |
| 566 | + |
| 567 | +Tensor& normal_out_cuda(Tensor& output, double mean, const Tensor& std, Generator* gen) { |
| 568 | + normal_cuda_(output, 0, 1, gen); |
| 569 | + auto mean_tensor = at::full({1}, mean, output.options()); |
| 570 | + at::native::legacy::cuda::_th_addcmul_out(output, mean_tensor, output, std, 1); |
| 571 | + return output; |
| 572 | +} |
| 573 | + |
| 574 | +Tensor& normal_out_cuda(Tensor& output, const Tensor& mean, const Tensor& std, Generator* gen) { |
| 575 | + normal_cuda_(output, 0, 1, gen); |
| 576 | + at::native::legacy::cuda::_th_addcmul_out(output, mean, output, std, 1); |
| 577 | + return output; |
| 578 | +} |
| 579 | + |
| 580 | +Tensor normal_cuda(const Tensor& mean, double std, Generator* gen) { |
| 581 | + Tensor ret = at::empty(mean.sizes(), mean.options()); |
| 582 | + normal_out_cuda(ret, mean, std, gen); |
| 583 | + return ret; |
| 584 | +} |
| 585 | + |
| 586 | +Tensor normal_cuda(double mean, const Tensor& std, Generator* gen) { |
| 587 | + Tensor ret = at::empty(std.sizes(), std.options()); |
| 588 | + normal_out_cuda(ret, mean, std, gen); |
| 589 | + return ret; |
| 590 | +} |
| 591 | + |
| 592 | +Tensor normal_cuda(const Tensor& mean, const Tensor& std, Generator* gen) { |
| 593 | + Tensor ret = at::empty(mean.sizes(), mean.options()); |
| 594 | + normal_out_cuda(ret, mean, std, gen); |
| 595 | + return ret; |
| 596 | +} |
| 597 | + |
513 | 598 | }} // namespace at::native |
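
As a hedged usage sketch (not part of this commit), the new overloads would typically be reached through the ATen frontend, assuming they are registered as the CUDA implementations of normal, normal_out, and normal_:

#include <ATen/ATen.h>

void normal_usage_example() {
  auto mean = at::zeros({4, 4}, at::kCUDA);
  auto std  = at::ones({4, 4}, at::kCUDA);

  auto a = at::normal(mean, /*std=*/1.0);  // -> normal_cuda(const Tensor&, double, Generator*)
  auto b = at::normal(/*mean=*/0.0, std);  // -> normal_cuda(double, const Tensor&, Generator*)
  auto c = at::normal(mean, std);          // -> normal_cuda(const Tensor&, const Tensor&, Generator*)

  auto d = at::empty({4, 4}, at::kCUDA);
  d.normal_(/*mean=*/0.0, /*std=*/1.0);    // in-place -> normal_cuda_
}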