
Commit eadac84

syed-ahmed authored and facebook-github-bot committed
Speedup bernoulli_scalar_cuda_kernel with grid-stride loop (#21300)
Summary:
Pull Request resolved: #21300
ghimport-source-id: c314c28

Reviewed By: jerryzh168
Differential Revision: D15632935
Pulled By: ezyang
fbshipit-source-id: 9bb24f17d78151bf50942905c967bdcfe1ff00cb
1 parent c82bf8e commit eadac84
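
For readers unfamiliar with the technique named in the title: a grid-stride loop lets a single launch of a fixed-size grid cover a tensor of any length, because each thread advances by the total grid width rather than owning one fixed slot. A minimal sketch of the pattern (illustrative only; `scale_grid_stride` is a hypothetical kernel, not part of this diff):

```cuda
// Minimal grid-stride loop (illustrative, not the kernel in this commit):
// a fixed-size grid covers a buffer of any length n because each thread
// processes elements i, i + stride, i + 2*stride, ...
__global__ void scale_grid_stride(float* out, const float* in, float a, int n) {
  int stride = blockDim.x * gridDim.x;
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += stride) {
    out[i] = a * in[i];
  }
}
```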

File tree

1 file changed (+32, -48 lines)

aten/src/ATen/native/cuda/Distributions.cu

Lines changed: 32 additions & 48 deletions
@@ -307,45 +307,6 @@ void bernoulli_tensor_cuda_kernel(
   );
 }

-template<typename scalar_t>
-void bernoulli_scalar_cuda_kernel(
-    at::Tensor& ret, double p_,
-    std::pair<uint64_t, uint64_t> seeds) {
-  float p = static_cast<float>(p_);
-  // The template argument `4` below indicates that we want to operate on four
-  // element at each time. See NOTE [ CUDA_tensor_applyN helpers ] for details.
-  at::cuda::CUDA_tensor_apply1<scalar_t, 4>(
-      ret, [seeds, p] __device__(
-          int n, scalar_t& v1, scalar_t& v2, scalar_t& v3, scalar_t& v4) {
-        curandStatePhilox4_32_10_t state;
-        curand_init(
-            seeds.first,
-            blockIdx.x * blockDim.x + threadIdx.x,
-            seeds.second,
-            &state);
-        // See Note [Register spilling in curand call for CUDA < 10]
-        float4 rand = curand_uniform4(&state);
-        switch (n) {
-          case 4: {
-            v4 = static_cast<scalar_t>(rand.w <= p);
-            // fallthrough
-          }
-          case 3: {
-            v3 = static_cast<scalar_t>(rand.z <= p);
-            // fallthrough
-          }
-          case 2: {
-            v2 = static_cast<scalar_t>(rand.y <= p);
-            // fallthrough
-          }
-          case 1: {
-            v1 = static_cast<scalar_t>(rand.x <= p);
-          }
-        }
-      }
-  );
-}
-
 template<typename scalar_t>
 void dirichlet_scalar_cuda_kernel(
     at::Tensor& ret,
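
The removed kernel relied on `at::cuda::CUDA_tensor_apply1<scalar_t, 4>`: each thread drew four uniforms with a single Philox call and wrote up to four elements, with the intentional `switch` fallthrough handling a short tail. A hedged standalone sketch of that shape (hypothetical names; the real helper also handles non-contiguous tensors):

```cuda
#include <curand_kernel.h>

// Sketch of the removed 4-per-thread pattern (illustrative, float output).
// One Philox call yields four uniforms; `n` is how many of the four output
// slots are still in bounds for this thread.
__global__ void bernoulli4(float* out, int64_t numel, float p,
                           uint64_t seed, uint64_t offset) {
  int64_t tid = blockIdx.x * (int64_t)blockDim.x + threadIdx.x;
  int64_t base = tid * 4;
  if (base >= numel) return;
  curandStatePhilox4_32_10_t state;
  curand_init(seed, tid, offset, &state);  // counter-based: no global RNG state
  float4 r = curand_uniform4(&state);      // four uniforms per engine call
  int n = (numel - base < 4) ? (int)(numel - base) : 4;
  switch (n) {                             // intentional fallthrough, as above
    case 4: out[base + 3] = (r.w <= p);
    case 3: out[base + 2] = (r.z <= p);
    case 2: out[base + 1] = (r.y <= p);
    case 1: out[base + 0] = (r.x <= p);
  }
}
```

With this layout the grid size is tied to `numel / 4`; the replacement later in the diff decouples launch size from tensor size with a grid-stride loop.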
@@ -412,15 +373,6 @@ Tensor& bernoulli_tensor_cuda_(Tensor &self, const Tensor& p_, Generator* gen) {
   return self;
 }

-Tensor& bernoulli_scalar_cuda_(Tensor &self, double p, Generator* gen) {
-  TORCH_CHECK(0 <= p && p <= 1, "bernoulli_ expects p to be in [0, 1], but got p=", p);
-  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, self.scalar_type(), "bernoulli_scalar_cuda_", [&] {
-    auto seeds = next_philox_seed(gen, 10);
-    bernoulli_scalar_cuda_kernel<scalar_t>(self, p, seeds);
-  });
-  return self;
-}
-
 void uniform_kernel_cuda(TensorIterator& iter, double from_, double to_, Generator* gen_) {
   auto gen = check_generator<CUDAGenerator>(gen_, &globalContext().defaultGenerator(kCUDA));
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "uniform_cuda", [&] {
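
Both the removed wrapper and its replacement dispatch on dtype via `AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, ...)`, which instantiates the lambda once per supported scalar type with `scalar_t` bound at compile time. Conceptually it boils down to a switch like this hedged sketch (not the real macro expansion; `dispatch_by_dtype` is hypothetical):

```cuda
#include <cstdio>
#include <stdexcept>

enum class ScalarType { Float, Double, Half };

// Hedged sketch of what an AT_DISPATCH_* macro boils down to (not the real
// expansion): map the runtime dtype to a compile-time type and invoke the
// generic lambda once with that type bound.
template <typename F>
void dispatch_by_dtype(ScalarType t, F&& body) {
  switch (t) {
    case ScalarType::Float:  body(float{});  break;   // scalar_t = float
    case ScalarType::Double: body(double{}); break;   // scalar_t = double
    default: throw std::runtime_error("dtype not supported in this sketch");
  }
}

int main() {
  dispatch_by_dtype(ScalarType::Float, [](auto tag) {
    using scalar_t = decltype(tag);                   // what AT_DISPATCH binds
    std::printf("sizeof(scalar_t) = %zu\n", sizeof(scalar_t));
  });
}
```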
@@ -644,6 +596,31 @@ void log_normal_kernel_cuda(TensorIterator& iter, double mean_, double std_, Gen
   });
 }

+void bernoulli_scalar_cuda_kernel(TensorIterator& iter, double p_, Generator* gen_) {
+  auto gen = check_generator<CUDAGenerator>(gen_, &globalContext().defaultGenerator(kCUDA));
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, iter.dtype(), "bernoulli_scalar_cuda_", [&] {
+    if (std::is_same<scalar_t, double>::value) {
+      // define lambda for bernoulli transformation
+      auto bernoulli_func = [p_] __device__ (double rand) {
+        return static_cast<scalar_t>(rand <= p_);
+      };
+      distribution_nullary_kernel<scalar_t, double, curand4_engine_calls/2>(iter,
+        gen,
+        [] __device__ (curandStatePhilox4_32_10_t* state) { return curand_uniform2_double(state); },
+        bernoulli_func);
+    } else {
+      auto p = static_cast<float>(p_);
+      auto bernoulli_func = [p] __device__ (float rand) {
+        return static_cast<scalar_t>(rand <= p);
+      };
+      distribution_nullary_kernel<scalar_t, float, curand4_engine_calls>(iter,
+        gen,
+        [] __device__ (curandStatePhilox4_32_10_t* state) { return curand_uniform4(state); },
+        bernoulli_func);
+    }
+  });
+}
+
 Tensor& uniform_cuda_(Tensor& self, double from, double to, Generator* gen) {
   auto iter = TensorIterator::nullary_op(self);
   uniform_kernel_cuda(*iter, from, to, gen);
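
The new kernel goes through `distribution_nullary_kernel`, the shared grid-stride RNG path: each thread seeds its own Philox subsequence, draws a vector of uniforms per engine call (`curand_uniform4` gives four floats, `curand_uniform2_double` gives two doubles, hence the `curand4_engine_calls/2` unroll for the double branch), applies the transform lambda, and strides on. A hedged standalone sketch of the loop shape (illustrative names, float path only; the real template is generic over dtype and transform, and keeps Philox counters aligned across threads for reproducibility):

```cuda
#include <curand_kernel.h>

// Sketch of the grid-stride loop inside distribution_nullary_kernel
// (illustrative, not the real template).
__global__ void bernoulli_grid_stride(float* out, int numel, float p,
                                      uint64_t seed, uint64_t offset) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  int grid_width = blockDim.x * gridDim.x;
  curandStatePhilox4_32_10_t state;
  curand_init(seed, idx, offset, &state);   // one Philox subsequence per thread
  for (int base = idx; base < numel; base += grid_width * 4) {
    float4 r = curand_uniform4(&state);     // four uniforms per engine call
    #pragma unroll
    for (int k = 0; k < 4; ++k) {
      int li = base + grid_width * k;       // coalesced: consecutive threads
      if (li < numel) {                     // write consecutive elements
        out[li] = static_cast<float>((&r.x)[k] <= p);
      }
    }
  }
}
```

Because the loop strides by the grid width, the launch size no longer needs to grow with the tensor, which is where the speedup in the title comes from.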
@@ -749,4 +726,11 @@ Tensor& log_normal_cuda_(Tensor& self, double mean, double std, Generator* gen) {
   return self;
 }

+Tensor& bernoulli_scalar_cuda_(Tensor &self, double p, Generator* gen) {
+  TORCH_CHECK(0 <= p && p <= 1, "bernoulli_ expects p to be in [0, 1], but got p=", p);
+  auto iter = TensorIterator::nullary_op(self);
+  bernoulli_scalar_cuda_kernel(*iter, p, gen);
+  return self;
+}
+
 }} // namespace at::native
