80 changes: 32 additions & 48 deletions aten/src/ATen/native/cuda/Distributions.cu
@@ -307,45 +307,6 @@ void bernoulli_tensor_cuda_kernel(
);
}

template<typename scalar_t>
void bernoulli_scalar_cuda_kernel(
at::Tensor& ret, double p_,
std::pair<uint64_t, uint64_t> seeds) {
float p = static_cast<float>(p_);
// The template argument `4` below indicates that we want to operate on four
// elements at a time. See NOTE [ CUDA_tensor_applyN helpers ] for details.
at::cuda::CUDA_tensor_apply1<scalar_t, 4>(
ret, [seeds, p] __device__(
int n, scalar_t& v1, scalar_t& v2, scalar_t& v3, scalar_t& v4) {
curandStatePhilox4_32_10_t state;
curand_init(
seeds.first,
blockIdx.x * blockDim.x + threadIdx.x,
seeds.second,
&state);
// See Note [Register spilling in curand call for CUDA < 10]
float4 rand = curand_uniform4(&state);
switch (n) {
case 4: {
v4 = static_cast<scalar_t>(rand.w <= p);
// fallthrough
}
case 3: {
v3 = static_cast<scalar_t>(rand.z <= p);
// fallthrough
}
case 2: {
v2 = static_cast<scalar_t>(rand.y <= p);
// fallthrough
}
case 1: {
v1 = static_cast<scalar_t>(rand.x <= p);
}
}
}
);
}

template<typename scalar_t>
void dirichlet_scalar_cuda_kernel(
at::Tensor& ret,
@@ -412,15 +373,6 @@ Tensor& bernoulli_tensor_cuda_(Tensor &self, const Tensor& p_, Generator* gen) {
return self;
}

Tensor& bernoulli_scalar_cuda_(Tensor &self, double p, Generator* gen) {
TORCH_CHECK(0 <= p && p <= 1, "bernoulli_ expects p to be in [0, 1], but got p=", p);
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, self.scalar_type(), "bernoulli_scalar_cuda_", [&] {
auto seeds = next_philox_seed(gen, 10);
bernoulli_scalar_cuda_kernel<scalar_t>(self, p, seeds);
});
return self;
}

void uniform_kernel_cuda(TensorIterator& iter, double from_, double to_, Generator* gen_) {
auto gen = check_generator<CUDAGenerator>(gen_, &globalContext().defaultGenerator(kCUDA));
AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "uniform_cuda", [&] {
@@ -644,6 +596,31 @@ void log_normal_kernel_cuda(TensorIterator& iter, double mean_, double std_, Gen
});
}

void bernoulli_scalar_cuda_kernel(TensorIterator& iter, double p_, Generator* gen_) {
auto gen = check_generator<CUDAGenerator>(gen_, &globalContext().defaultGenerator(kCUDA));
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, iter.dtype(), "bernoulli_scalar_cuda_", [&] {
if (std::is_same<scalar_t, double>::value) {
// define lambda for bernoulli transformation
auto bernoulli_func = [p_] __device__ (double rand) {
return static_cast<scalar_t>(rand <= p_);
};
distribution_nullary_kernel<scalar_t, double, curand4_engine_calls/2>(iter,
gen,
[] __device__ (curandStatePhilox4_32_10_t* state) { return curand_uniform2_double(state); },
bernoulli_func);
} else {
auto p = static_cast<float>(p_);
auto bernoulli_func = [p] __device__ (float rand) {
return static_cast<scalar_t>(rand <= p);
};
distribution_nullary_kernel<scalar_t, float, curand4_engine_calls>(iter,
gen,
[] __device__ (curandStatePhilox4_32_10_t* state) { return curand_uniform4(state); },
bernoulli_func);
}
});
}
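For reference, a standalone sketch of the same Bernoulli-from-uniform transform outside the ATen/TensorIterator machinery (not part of this diff; the kernel name, launch configuration, and seed below are made up for illustration). As in the float path above, each Philox call yields four uniforms in (0, 1] via curand_uniform4, and an element becomes 1 exactly when its draw is <= p; the double branch instead draws two doubles per call with curand_uniform2_double, which is why it uses an unroll factor of curand4_engine_calls/2.

#include <cstdio>
#include <cuda_runtime.h>
#include <curand_kernel.h>

// Plain CUDA illustration of the per-element transform used above.
__global__ void bernoulli_demo(float* out, int n, float p, uint64_t seed) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  curandStatePhilox4_32_10_t state;
  curand_init(seed, /*subsequence=*/tid, /*offset=*/0, &state);
  float4 r = curand_uniform4(&state);   // four uniforms in (0, 1]
  float draws[4] = {r.x, r.y, r.z, r.w};
  for (int k = 0; k < 4; ++k) {         // each thread covers four elements
    int idx = tid * 4 + k;
    if (idx < n) out[idx] = static_cast<float>(draws[k] <= p);
  }
}

int main() {
  const int n = 16;
  float* out;
  cudaMallocManaged(&out, n * sizeof(float));
  bernoulli_demo<<<1, (n + 3) / 4>>>(out, n, 0.25f, /*seed=*/42);
  cudaDeviceSynchronize();
  for (int i = 0; i < n; ++i) printf("%g ", out[i]);
  printf("\n");
  cudaFree(out);
  return 0;
}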

Tensor& uniform_cuda_(Tensor& self, double from, double to, Generator* gen) {
auto iter = TensorIterator::nullary_op(self);
uniform_kernel_cuda(*iter, from, to, gen);
@@ -749,4 +726,11 @@ Tensor& log_normal_cuda_(Tensor& self, double mean, double std, Generator* gen)
return self;
}

Tensor& bernoulli_scalar_cuda_(Tensor &self, double p, Generator* gen) {
TORCH_CHECK(0 <= p && p <= 1, "bernoulli_ expects p to be in [0, 1], but got p=", p);
auto iter = TensorIterator::nullary_op(self);
bernoulli_scalar_cuda_kernel(*iter, p, gen);
return self;
}
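A minimal usage sketch (not part of the diff), assuming the standard ATen C++ API; the tensor size and probability are arbitrary. On a CUDA tensor, the in-place scalar overload of bernoulli_ should route to bernoulli_scalar_cuda_ above.

#include <ATen/ATen.h>

void bernoulli_scalar_example() {
  // 8-element float tensor on the GPU; contents start uninitialized.
  at::Tensor t = at::empty({8}, at::device(at::kCUDA).dtype(at::kFloat));
  // In-place scalar Bernoulli fill: each element becomes 0 or 1 with P(1) = 0.3.
  t.bernoulli_(0.3);
}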

}} // namespace at::native