2 changes: 0 additions & 2 deletions aten/src/ATen/Declarations.cwrap
@@ -2623,7 +2623,6 @@
- floating_point
backends:
- CPU
- CUDA
return: argument 0
variants:
- function
@@ -2663,7 +2662,6 @@
- floating_point
backends:
- CPU
- CUDA
cname: normal
variants: function
return: self
85 changes: 85 additions & 0 deletions aten/src/ATen/native/cuda/Distributions.cu
@@ -14,6 +14,7 @@
#include <ATen/native/Distributions.h>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/TensorIterator.h>
#include <ATen/LegacyTHFunctionsCUDA.h>

#include <THC/THCGeneral.h>
#include <THC/THCTensorRandom.h>
@@ -120,6 +121,22 @@ __global__ void distribution_elementwise_grid_stride_kernel(int numel,
}
}

/**
* distribution_nullary_kernel is analogous to gpu_nullary_kernel in
* ATen/native/cuda/Loops.cuh. Like gpu_nullary_kernel, it uses
* TensorIterator to launch a kernel. However, the differences are:
* - It launches a grid-stride loop based kernel. The kernel is not
*   generic like elementwise_kernel in Loops.cuh and is specialized
*   for the distribution kernels here.
* - For large tensors, the kernel may be launched multiple times
*   recursively (i.e. when !iter.can_use_32bit_indexing()), and hence
*   the philox offset calculation is done in this function.
*
* FIXME: Can we specialize elementwise_kernel and launch_kernel in Loops.cuh
* to have a grid-stride loop kernel and then use that to launch our
* distribution kernels? Note that we need a grid-stride loop kernel because
* we found, by testing, that it achieves peak effective bandwidth.
*/
template<typename scalar_t,
typename accscalar_t,
int unroll_factor,
@@ -475,6 +492,30 @@ void random_kernel_cuda(TensorIterator& iter, uint64_t range, int64_t base, Gene
});
}

void normal_kernel_cuda(TensorIterator& iter, double mean_, double std_, Generator* gen_) {
auto gen = check_generator<CUDAGenerator>(gen_, &globalContext().defaultGenerator(kCUDA));
AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "normal_cuda", [&] {
using accscalar_t = at::acc_type<scalar_t, true>;
auto mean = static_cast<accscalar_t>(mean_);
auto std = static_cast<accscalar_t>(std_);
// define lambda to multiply std and add mean
auto normal_func = [mean, std] __device__ (accscalar_t rand) {
return static_cast<scalar_t>(rand * std + mean);
};
if (std::is_same<scalar_t, double>::value) {
distribution_nullary_kernel<scalar_t, accscalar_t, curand4_engine_calls/2>(iter,
gen,
[] __device__ (curandStatePhilox4_32_10_t* state) { return curand_normal2_double(state); },
normal_func);
} else {
distribution_nullary_kernel<scalar_t, accscalar_t, curand4_engine_calls>(iter,
gen,
[] __device__ (curandStatePhilox4_32_10_t* state) { return curand_normal4(state); },
normal_func);
}
});
}

Tensor& uniform_cuda_(Tensor& self, double from, double to, Generator* gen) {
auto iter = TensorIterator::nullary_op(self);
uniform_kernel_cuda(*iter, from, to, gen);
@@ -510,4 +551,48 @@ Tensor& capped_random_cuda_(Tensor& self, int64_t to, Generator* gen) {
return clamped_random_cuda_(self, 0, to, gen);
}

Tensor& normal_cuda_(Tensor& self, double mean, double std, Generator* gen) {
TORCH_CHECK(std > 0.0, "normal_ expects std > 0.0, but found std=", std);
auto iter = TensorIterator::nullary_op(self);
normal_kernel_cuda(*iter, mean, std, gen);
return self;
}

Tensor& normal_out_cuda(Tensor& output, const Tensor& mean, double std, Generator* gen) {
normal_cuda_(output, 0, std, gen);
output.add_(mean);
return output;
}

Tensor& normal_out_cuda(Tensor& output, double mean, const Tensor& std, Generator* gen) {
normal_cuda_(output, 0, 1, gen);
auto mean_tensor = at::full({1}, mean, output.options());
at::native::legacy::cuda::_th_addcmul_out(output, mean_tensor, output, std, 1);
return output;
}

Tensor& normal_out_cuda(Tensor& output, const Tensor& mean, const Tensor& std, Generator* gen) {
normal_cuda_(output, 0, 1, gen);
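// output now holds standard normal samples; the addcmul below computes mean + output * std elementwise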
at::native::legacy::cuda::_th_addcmul_out(output, mean, output, std, 1);
return output;
}

Tensor normal_cuda(const Tensor& mean, double std, Generator* gen) {
Tensor ret = at::empty(mean.sizes(), mean.options());
normal_out_cuda(ret, mean, std, gen);
return ret;
}

Tensor normal_cuda(double mean, const Tensor& std, Generator* gen) {
Tensor ret = at::empty(std.sizes(), std.options());
normal_out_cuda(ret, mean, std, gen);
return ret;
}

Tensor normal_cuda(const Tensor& mean, const Tensor& std, Generator* gen) {
Tensor ret = at::empty(mean.sizes(), mean.options());
normal_out_cuda(ret, mean, std, gen);
return ret;
}

}} // namespace at::native
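
Aside on the grid-stride comment in this file: the following is a minimal, standalone sketch of the pattern it describes, not code from this PR. The kernel name, launch configuration, and parameter names are made up for illustration; the real kernel is templated over dtype, unroll factor, and the curand/transform functors. The idea is that each thread owns one Philox subsequence, draws four values per curand call, and strides over the tensor, with the philox offset advanced by the caller across launches so successive launches do not repeat random numbers.

#include <cstdint>
#include <curand_kernel.h>

__global__ void normal_grid_stride_sketch(float* out, int64_t numel,
                                          uint64_t seed, uint64_t philox_offset,
                                          float mean, float std) {
  int64_t tid = blockIdx.x * (int64_t)blockDim.x + threadIdx.x;
  curandStatePhilox4_32_10_t state;
  // One subsequence per thread; the offset lets later launches continue the stream.
  curand_init(seed, tid, philox_offset, &state);
  const int64_t stride = (int64_t)gridDim.x * blockDim.x * 4;  // 4 values per curand call
  for (int64_t base = tid * 4; base < numel; base += stride) {
    float4 r = curand_normal4(&state);   // four standard normals
    float v[4] = {r.x, r.y, r.z, r.w};
    for (int k = 0; k < 4; ++k) {
      int64_t i = base + k;
      if (i < numel) {
        out[i] = v[k] * std + mean;      // same scale/shift as normal_func above
      }
    }
  }
}

// e.g. normal_grid_stride_sketch<<<num_blocks, 256>>>(out, numel, seed, offset, 0.f, 1.f);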
14 changes: 7 additions & 7 deletions aten/src/ATen/native/native_functions.yaml
@@ -3181,7 +3181,7 @@
variants: method
dispatch:
CPU: legacy::cpu::_th_normal_
CUDA: legacy::cuda::_th_normal_
CUDA: normal_cuda_

- func: cauchy_(Tensor(a!) self, float median=0, float sigma=1, *, Generator? generator=None) -> Tensor(a!)
variants: method
@@ -3923,32 +3923,32 @@
- func: normal(Tensor mean, float std=1, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
dispatch:
CPU: legacy::cpu::_th_normal_out
CUDA: legacy::cuda::_th_normal_out
CUDA: normal_out_cuda

- func: normal(Tensor mean, float std=1, *, Generator? generator=None) -> Tensor
dispatch:
CPU: legacy::cpu::_th_normal
CUDA: legacy::cuda::_th_normal
CUDA: normal_cuda

- func: normal(float mean, Tensor std, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
dispatch:
CPU: legacy::cpu::_th_normal_out
CUDA: legacy::cuda::_th_normal_out
CUDA: normal_out_cuda

- func: normal(float mean, Tensor std, *, Generator? generator=None) -> Tensor
dispatch:
CPU: legacy::cpu::_th_normal
CUDA: legacy::cuda::_th_normal
CUDA: normal_cuda

- func: normal(Tensor mean, Tensor std, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
dispatch:
CPU: legacy::cpu::_th_normal_out
CUDA: legacy::cuda::_th_normal_out
CUDA: normal_out_cuda

- func: normal(Tensor mean, Tensor std, *, Generator? generator=None) -> Tensor
dispatch:
CPU: legacy::cpu::_th_normal
CUDA: legacy::cuda::_th_normal
CUDA: normal_cuda

- func: alias(Tensor(a) self) -> Tensor(a)
variants: method, function
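
For orientation, a rough usage sketch of the overloads whose CUDA dispatch changes above: on CUDA tensors these calls now reach normal_cuda / normal_out_cuda / normal_cuda_ instead of the legacy TH bindings. The helper function name below is made up; the at:: calls mirror the signatures listed above.

#include <ATen/ATen.h>

void normal_overloads_sketch() {
  auto opts = at::device(at::kCUDA).dtype(at::kFloat);
  auto mean   = at::zeros({1024}, opts);
  auto stddev = at::ones({1024}, opts);

  auto a = at::normal(mean, /*std=*/1.0);     // normal(Tensor mean, float std)  -> normal_cuda
  auto b = at::normal(/*mean=*/0.0, stddev);  // normal(float mean, Tensor std)  -> normal_cuda
  auto c = at::normal(mean, stddev);          // normal(Tensor mean, Tensor std) -> normal_cuda

  auto out = at::empty({1024}, opts);
  at::normal_out(out, mean, stddev);          // out= variant -> normal_out_cuda

  auto t = at::empty({1024}, opts);
  t.normal_(/*mean=*/0.0, /*std=*/1.0);       // in-place method -> normal_cuda_
}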
4 changes: 0 additions & 4 deletions aten/src/THC/THCTensorRandom.cu
@@ -129,16 +129,12 @@ __global__ void NAME(curandStateMtgp32 *state, int size, T *result, ARG1, ARG2)
} \
}

GENERATE_KERNEL2(generate_normal, float, double mean, double stdv, float, curand_normal, (x * stdv) + mean)
GENERATE_KERNEL2(generate_normal, double, double mean, double stdv, double, curand_normal_double, (x * stdv) + mean)

GENERATE_KERNEL1(generate_exponential, float, double lambda, float, curand_uniform, (float)(-1. / lambda * log(x)))
GENERATE_KERNEL1(generate_exponential, double, double lambda, double, curand_uniform_double, (double)(-1. / lambda * log(x)))

GENERATE_KERNEL2(generate_cauchy, float, double median, double sigma, float, curand_uniform, (float)(median + sigma * tan(M_PI*(x-0.5))))
GENERATE_KERNEL2(generate_cauchy, double, double median, double sigma, double, curand_uniform_double, (double)(median + sigma * tan(M_PI*(x-0.5))))

GENERATE_KERNEL2(generate_normal, at::Half, double mean, double stdv, float, curand_normal, (ScalarConvert<float, at::Half>::to((x * stdv) + mean)))
GENERATE_KERNEL1(generate_exponential, at::Half, double lambda, float, curand_uniform, (ScalarConvert<float, at::Half>::to((float)(-1. / lambda * log(x)))))
GENERATE_KERNEL2(generate_cauchy, at::Half, double median, double sigma, float, curand_uniform, (ScalarConvert<float, at::Half>::to((float)(median + sigma * tan(M_PI*(x-0.5))))))

36 changes: 0 additions & 36 deletions aten/src/THC/generic/THCTensorRandom.cu
@@ -8,42 +8,6 @@

#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_HALF)

void THCTensor_(normal)(THCState* state, THCTensor *self_, double mean, double stdv)
{
THCAssertSameGPU(THCTensor_(checkGPU)(state, 1, self_));
ptrdiff_t size = THCTensor_(nElement)(state, self_);
if (size == 0) return;
THCGenerator* gen = THCRandom_getGenerator(state);
THCTensor *self = THCTensor_(newContiguous)(state, self_);
scalar_t *data = THCTensor_(data)(state, self);

generate_normal<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
gen->state.gen_states, size, data, mean, stdv);

THCTensor_(freeCopyTo)(state, self, self_);
};

void THCTensor_(normal_means)(THCState *state, THCTensor *self, THCTensor *means, double stddev) {
THCTensor_(resizeAs)(state, self, means);
THCTensor_(normal)(state, self, 0, stddev);
THCTensor_(cadd)(state, self, self, ScalarConvert<int, scalar_t>::to(1), means);
}

void THCTensor_(normal_stddevs)(THCState *state, THCTensor *self, double mean, THCTensor *stddevs)
{
THCTensor_(resizeAs)(state, self, stddevs);
THCTensor_(normal)(state, self, 0, 1);
THCTensor_(cmul)(state, self, self, stddevs);
THCTensor_(add)(state, self, self, ScalarConvert<double, scalar_t>::to(mean));
}

void THCTensor_(normal_means_stddevs)(THCState *state, THCTensor *self, THCTensor *means, THCTensor *stddevs)
{
THCTensor_(resizeAs)(state, self, means);
THCTensor_(normal)(state, self, 0, 1);
THCTensor_(cmul)(state, self, self, stddevs);
THCTensor_(cadd)(state, self, self, ScalarConvert<int, scalar_t>::to(1), means);
}

void THCTensor_(logNormal)(THCState* state, THCTensor *self_, double mean, double stdv)
{
4 changes: 0 additions & 4 deletions aten/src/THC/generic/THCTensorRandom.h
@@ -4,10 +4,6 @@

#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_HALF)

THC_API void THCTensor_(normal)(struct THCState *state, THCTensor *self, double mean, double stdv);
THC_API void THCTensor_(normal_means)(struct THCState *state, THCTensor *self, THCTensor *means, double stddev);
THC_API void THCTensor_(normal_stddevs)(struct THCState *state, THCTensor *self, double mean, THCTensor *stddevs);
THC_API void THCTensor_(normal_means_stddevs)(struct THCState *state, THCTensor *self, THCTensor *means, THCTensor *stddevs);
THC_API void THCTensor_(logNormal)(struct THCState *state, THCTensor *self, double mean, double stdv);
THC_API void THCTensor_(exponential)(struct THCState *state, THCTensor *self, double lambda);
THC_API void THCTensor_(cauchy)(struct THCState *state, THCTensor *self, double median, double sigma);
1 change: 1 addition & 0 deletions test/test_nn.py
@@ -4779,6 +4779,7 @@ def test_Conv2d_groups_nobias(self):
# See also https://github.com/pytorch/pytorch/pull/18463#issuecomment-476563686
# and https://github.com/pytorch/pytorch/pull/18463#issuecomment-477001024
def test_Conv2d_groups_nobias_v2(self):
torch.manual_seed(123)
dev_dtypes = [("cpu", torch.float)]
if TEST_CUDA:
dev_dtypes += [("cuda", torch.float), ("cuda", torch.half)]