@@ -307,45 +307,6 @@ void bernoulli_tensor_cuda_kernel(
   );
 }

-template <typename scalar_t>
-void bernoulli_scalar_cuda_kernel(
-    at::Tensor& ret, double p_,
-    std::pair<uint64_t, uint64_t> seeds) {
-  float p = static_cast<float>(p_);
-  // The template argument `4` below indicates that we want to operate on four
-  // elements at a time. See NOTE [ CUDA_tensor_applyN helpers ] for details.
-  at::cuda::CUDA_tensor_apply1<scalar_t, 4>(
-      ret, [seeds, p] __device__ (
-          int n, scalar_t& v1, scalar_t& v2, scalar_t& v3, scalar_t& v4) {
-        curandStatePhilox4_32_10_t state;
-        curand_init(
-            seeds.first,
-            blockIdx.x * blockDim.x + threadIdx.x,
-            seeds.second,
-            &state);
-        // See Note [Register spilling in curand call for CUDA < 10]
-        float4 rand = curand_uniform4(&state);
-        switch (n) {
-          case 4: {
-            v4 = static_cast<scalar_t>(rand.w <= p);
-            // fallthrough
-          }
-          case 3: {
-            v3 = static_cast<scalar_t>(rand.z <= p);
-            // fallthrough
-          }
-          case 2: {
-            v2 = static_cast<scalar_t>(rand.y <= p);
-            // fallthrough
-          }
-          case 1: {
-            v1 = static_cast<scalar_t>(rand.x <= p);
-          }
-        }
-      }
-  );
-}
-
 template <typename scalar_t>
 void dirichlet_scalar_cuda_kernel(
     at::Tensor& ret,
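For context, the deleted kernel follows cuRAND's counter-based Philox pattern: every thread seeds a curandStatePhilox4_32_10_t with the same seed, its global thread id as the subsequence, and a shared offset, then draws four uniforms per curand_uniform4 call; the switch fallthrough covers a tail of fewer than four elements. A minimal standalone sketch of that pattern, with a hypothetical kernel name and a plain contiguous output instead of CUDA_tensor_apply1:

#include <cstdint>
#include <curand_kernel.h>

__global__ void bernoulli_philox_sketch(float* out, int n, float p,
                                        uint64_t seed, uint64_t offset) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  curandStatePhilox4_32_10_t state;
  // Same seed everywhere; the thread id selects an independent subsequence.
  curand_init(seed, idx, offset, &state);
  float4 rand = curand_uniform4(&state);  // four uniforms in (0, 1] per call
  int base = idx * 4;                     // each thread owns four elements
  if (base + 0 < n) out[base + 0] = (rand.x <= p);
  if (base + 1 < n) out[base + 1] = (rand.y <= p);
  if (base + 2 < n) out[base + 2] = (rand.z <= p);
  if (base + 3 < n) out[base + 3] = (rand.w <= p);
}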
@@ -412,15 +373,6 @@ Tensor& bernoulli_tensor_cuda_(Tensor &self, const Tensor& p_, Generator* gen) {
   return self;
 }

-Tensor& bernoulli_scalar_cuda_(Tensor& self, double p, Generator* gen) {
-  TORCH_CHECK(0 <= p && p <= 1, "bernoulli_ expects p to be in [0, 1], but got p=", p);
-  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, self.scalar_type(), "bernoulli_scalar_cuda_", [&] {
-    auto seeds = next_philox_seed(gen, 10);
-    bernoulli_scalar_cuda_kernel<scalar_t>(self, p, seeds);
-  });
-  return self;
-}
-
 void uniform_kernel_cuda(TensorIterator& iter, double from_, double to_, Generator* gen_) {
   auto gen = check_generator<CUDAGenerator>(gen_, &globalContext().defaultGenerator(kCUDA));
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "uniform_cuda", [&] {
@@ -644,6 +596,31 @@ void log_normal_kernel_cuda(TensorIterator& iter, double mean_, double std_, Gen
   });
 }

+void bernoulli_scalar_cuda_kernel(TensorIterator& iter, double p_, Generator* gen_) {
+  auto gen = check_generator<CUDAGenerator>(gen_, &globalContext().defaultGenerator(kCUDA));
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, iter.dtype(), "bernoulli_scalar_cuda_", [&] {
+    if (std::is_same<scalar_t, double>::value) {
+      // define lambda for bernoulli transformation
+      auto bernoulli_func = [p_] __device__ (double rand) {
+        return static_cast<scalar_t>(rand <= p_);
+      };
+      distribution_nullary_kernel<scalar_t, double, curand4_engine_calls/2>(iter,
+        gen,
+        [] __device__ (curandStatePhilox4_32_10_t* state) { return curand_uniform2_double(state); },
+        bernoulli_func);
+    } else {
+      auto p = static_cast<float>(p_);
+      auto bernoulli_func = [p] __device__ (float rand) {
+        return static_cast<scalar_t>(rand <= p);
+      };
+      distribution_nullary_kernel<scalar_t, float, curand4_engine_calls>(iter,
+        gen,
+        [] __device__ (curandStatePhilox4_32_10_t* state) { return curand_uniform4(state); },
+        bernoulli_func);
+    }
+  });
+}
+
 Tensor& uniform_cuda_(Tensor& self, double from, double to, Generator* gen) {
   auto iter = TensorIterator::nullary_op(self);
   uniform_kernel_cuda(*iter, from, to, gen);
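The replacement routes through TensorIterator and the shared distribution_nullary_kernel template instead of CUDA_tensor_apply1. Its third template argument is the unroll factor: one Philox call yields four floats via curand_uniform4 but only two doubles via curand_uniform2_double, hence curand4_engine_calls for the float branch and curand4_engine_calls/2 for the double branch. Below is a simplified sketch of the grid-stride shape such a kernel takes; the names are hypothetical, and the real ATen template additionally handles non-contiguous iterators and Philox offset bookkeeping:

#include <cstdint>
#include <curand_kernel.h>

template <typename scalar_t, typename accscalar_t, int unroll,
          typename RandFunc, typename TransformFunc>
__global__ void nullary_sketch(scalar_t* out, int64_t n,
                               uint64_t seed, uint64_t offset,
                               RandFunc rand_func, TransformFunc transform) {
  int64_t idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x;
  curandStatePhilox4_32_10_t state;
  curand_init(seed, idx, offset, &state);
  int64_t stride = (int64_t)blockDim.x * gridDim.x * unroll;
  for (int64_t base = idx * unroll; base < n; base += stride) {
    auto rand = rand_func(&state);  // a float4 or double2 packet
    accscalar_t* r = reinterpret_cast<accscalar_t*>(&rand);
    #pragma unroll
    for (int i = 0; i < unroll; ++i) {
      if (base + i < n) out[base + i] = transform(r[i]);
    }
  }
}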
@@ -749,4 +726,11 @@ Tensor& log_normal_cuda_(Tensor& self, double mean, double std, Generator* gen)
   return self;
 }

+Tensor& bernoulli_scalar_cuda_(Tensor& self, double p, Generator* gen) {
+  TORCH_CHECK(0 <= p && p <= 1, "bernoulli_ expects p to be in [0, 1], but got p=", p);
+  auto iter = TensorIterator::nullary_op(self);
+  bernoulli_scalar_cuda_kernel(*iter, p, gen);
+  return self;
+}
+
 }} // namespace at::native
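User-visible behavior is unchanged: a scalar-probability bernoulli_ on a CUDA tensor still fills every element independently with P(1) = p. A minimal usage sketch, assuming a CUDA-enabled ATen build:

#include <ATen/ATen.h>
#include <iostream>

int main() {
  at::Tensor t = at::empty({8}, at::kCUDA);
  t.bernoulli_(0.25);  // dispatches to bernoulli_scalar_cuda_ on CUDA tensors
  std::cout << t << std::endl;  // prints after an implicit copy to CPU
  return 0;
}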