Skip to content

Commit 012069c

Browse files
ezyang authored and facebook-github-bot committed
Revert D15454048: Move THCTensor_{normal, normal_means, normal_stddevs, normal_means_stddevs} to ATen
Differential Revision: D15454048 Original commit changeset: 8bfc57bf015b fbshipit-source-id: 98c562ab4cf7a00e9041b2aa50eb7fb0f0c48f69
1 parent dc8f306 commit 012069c

File tree

6 files changed: +53 additions, -92 deletions

aten/src/ATen/Declarations.cwrap

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2625,6 +2625,7 @@
26252625
- floating_point
26262626
backends:
26272627
- CPU
2628+
- CUDA
26282629
return: argument 0
26292630
variants:
26302631
- function
@@ -2664,6 +2665,7 @@
26642665
- floating_point
26652666
backends:
26662667
- CPU
2668+
- CUDA
26672669
cname: normal
26682670
variants: function
26692671
return: self

aten/src/ATen/native/cuda/Distributions.cu

Lines changed: 0 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
#include <ATen/native/Distributions.h>
1515
#include <ATen/native/cuda/Loops.cuh>
1616
#include <ATen/native/TensorIterator.h>
17-
#include <ATen/LegacyTHFunctionsCUDA.h>
1817

1918
#include <THC/THCGeneral.h>
2019
#include <THC/THCTensorRandom.h>
@@ -121,22 +120,6 @@ __global__ void distribution_elementwise_grid_stride_kernel(int numel,
121120
}
122121
}
123122

124-
/**
125-
* distribution_nullary_kernel is analogous to gpu_nullary_kernel in
126-
* ATen/native/cuda/Loops.cuh. Like gpu_nullary_kernel, it uses
127-
* TensorIterator to launch a kernel. However, the differences are
128-
* - it launches a grid-stride loop based kernel. The kernel is not
129-
* generic like elementwise_kernel in Loops.cuh and is specialized
130-
* for the distribution kernels here.
131-
* - For big size tensors, we can launch multiple kernels recursively
132-
* (i.e. if (!iter.can_use_32bit_indexing())) and hence, the philox
133-
* offset calculation is done in this function.
134-
*
135-
* FIXME: Can we specialize elementwise_kernel and launch_kernel in Loops.cuh
136-
* to have grid-stride loop kernel and then use that to launch our distribution
137-
* kernels? Note that we need a grid-stride loop kernel because, we found by testing
138-
* that it achieves peak effective bandwidth.
139-
*/
140123
template<typename scalar_t,
141124
typename accscalar_t,
142125
int unroll_factor,
@@ -492,30 +475,6 @@ void random_kernel_cuda(TensorIterator& iter, uint64_t range, int64_t base, Gene
492475
});
493476
}
494477

495-
void normal_kernel_cuda(TensorIterator& iter, double mean_, double std_, Generator* gen_) {
496-
auto gen = check_generator<CUDAGenerator>(gen_, &globalContext().defaultGenerator(kCUDA));
497-
AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "normal_cuda", [&] {
498-
using accscalar_t = at::acc_type<scalar_t, true>;
499-
auto mean = static_cast<accscalar_t>(mean_);
500-
auto std = static_cast<accscalar_t>(std_);
501-
// define lambda to multiply std and add mean
502-
auto normal_func = [mean, std] __device__ (accscalar_t rand) {
503-
return static_cast<scalar_t>(rand * std + mean);
504-
};
505-
if (std::is_same<scalar_t, double>::value) {
506-
distribution_nullary_kernel<scalar_t, accscalar_t, curand4_engine_calls/2>(iter,
507-
gen,
508-
[] __device__ (curandStatePhilox4_32_10_t* state) { return curand_normal2_double(state); },
509-
normal_func);
510-
} else {
511-
distribution_nullary_kernel<scalar_t, accscalar_t, curand4_engine_calls>(iter,
512-
gen,
513-
[] __device__ (curandStatePhilox4_32_10_t* state) { return curand_normal4(state); },
514-
normal_func);
515-
}
516-
});
517-
}
518-
519478
Tensor& uniform_cuda_(Tensor& self, double from, double to, Generator* gen) {
520479
auto iter = TensorIterator::nullary_op(self);
521480
uniform_kernel_cuda(*iter, from, to, gen);
@@ -551,48 +510,4 @@ Tensor& capped_random_cuda_(Tensor& self, int64_t to, Generator* gen) {
551510
return clamped_random_cuda_(self, 0, to, gen);
552511
}
553512

554-
Tensor& normal_cuda_(Tensor& self, double mean, double std, Generator* gen) {
555-
TORCH_CHECK(std > 0.0, "normal_ expects std > 0.0, but found std=", std);
556-
auto iter = TensorIterator::nullary_op(self);
557-
normal_kernel_cuda(*iter, mean, std, gen);
558-
return self;
559-
}
560-
561-
Tensor& normal_out_cuda(Tensor& output, const Tensor& mean, double std, Generator* gen) {
562-
normal_cuda_(output, 0, std, gen);
563-
output.add_(mean);
564-
return output;
565-
}
566-
567-
Tensor& normal_out_cuda(Tensor& output, double mean, const Tensor& std, Generator* gen) {
568-
normal_cuda_(output, 0, 1, gen);
569-
auto mean_tensor = at::full({1}, mean, output.options());
570-
at::native::legacy::cuda::_th_addcmul_out(output, mean_tensor, output, std, 1);
571-
return output;
572-
}
573-
574-
Tensor& normal_out_cuda(Tensor& output, const Tensor& mean, const Tensor& std, Generator* gen) {
575-
normal_cuda_(output, 0, 1, gen);
576-
at::native::legacy::cuda::_th_addcmul_out(output, mean, output, std, 1);
577-
return output;
578-
}
579-
580-
Tensor normal_cuda(const Tensor& mean, double std, Generator* gen) {
581-
Tensor ret = at::empty(mean.sizes(), mean.options());
582-
normal_out_cuda(ret, mean, std, gen);
583-
return ret;
584-
}
585-
586-
Tensor normal_cuda(double mean, const Tensor& std, Generator* gen) {
587-
Tensor ret = at::empty(std.sizes(), std.options());
588-
normal_out_cuda(ret, mean, std, gen);
589-
return ret;
590-
}
591-
592-
Tensor normal_cuda(const Tensor& mean, const Tensor& std, Generator* gen) {
593-
Tensor ret = at::empty(mean.sizes(), mean.options());
594-
normal_out_cuda(ret, mean, std, gen);
595-
return ret;
596-
}
597-
598513
}} // namespace at::native

aten/src/ATen/native/native_functions.yaml

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3181,7 +3181,7 @@
31813181
variants: method
31823182
dispatch:
31833183
CPU: legacy::cpu::_th_normal_
3184-
CUDA: normal_cuda_
3184+
CUDA: legacy::cuda::_th_normal_
31853185

31863186
- func: cauchy_(Tensor(a!) self, float median=0, float sigma=1, *, Generator? generator=None) -> Tensor(a!)
31873187
variants: method
@@ -3923,32 +3923,32 @@
39233923
- func: normal(Tensor mean, float std=1, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
39243924
dispatch:
39253925
CPU: legacy::cpu::_th_normal_out
3926-
CUDA: normal_out_cuda
3926+
CUDA: legacy::cuda::_th_normal_out
39273927

39283928
- func: normal(Tensor mean, float std=1, *, Generator? generator=None) -> Tensor
39293929
dispatch:
39303930
CPU: legacy::cpu::_th_normal
3931-
CUDA: normal_cuda
3931+
CUDA: legacy::cuda::_th_normal
39323932

39333933
- func: normal(float mean, Tensor std, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
39343934
dispatch:
39353935
CPU: legacy::cpu::_th_normal_out
3936-
CUDA: normal_out_cuda
3936+
CUDA: legacy::cuda::_th_normal_out
39373937

39383938
- func: normal(float mean, Tensor std, *, Generator? generator=None) -> Tensor
39393939
dispatch:
39403940
CPU: legacy::cpu::_th_normal
3941-
CUDA: normal_cuda
3941+
CUDA: legacy::cuda::_th_normal
39423942

39433943
- func: normal(Tensor mean, Tensor std, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
39443944
dispatch:
39453945
CPU: legacy::cpu::_th_normal_out
3946-
CUDA: normal_out_cuda
3946+
CUDA: legacy::cuda::_th_normal_out
39473947

39483948
- func: normal(Tensor mean, Tensor std, *, Generator? generator=None) -> Tensor
39493949
dispatch:
39503950
CPU: legacy::cpu::_th_normal
3951-
CUDA: normal_cuda
3951+
CUDA: legacy::cuda::_th_normal
39523952

39533953
- func: alias(Tensor(a) self) -> Tensor(a)
39543954
variants: method, function

aten/src/THC/THCTensorRandom.cu

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,12 +129,16 @@ __global__ void NAME(curandStateMtgp32 *state, int size, T *result, ARG1, ARG2)
129129
} \
130130
}
131131

132+
GENERATE_KERNEL2(generate_normal, float, double mean, double stdv, float, curand_normal, (x * stdv) + mean)
133+
GENERATE_KERNEL2(generate_normal, double, double mean, double stdv, double, curand_normal_double, (x * stdv) + mean)
134+
132135
GENERATE_KERNEL1(generate_exponential, float, double lambda, float, curand_uniform, (float)(-1. / lambda * log(x)))
133136
GENERATE_KERNEL1(generate_exponential, double, double lambda, double, curand_uniform_double, (double)(-1. / lambda * log(x)))
134137

135138
GENERATE_KERNEL2(generate_cauchy, float, double median, double sigma, float, curand_uniform, (float)(median + sigma * tan(M_PI*(x-0.5))))
136139
GENERATE_KERNEL2(generate_cauchy, double, double median, double sigma, double, curand_uniform_double, (double)(median + sigma * tan(M_PI*(x-0.5))))
137140

141+
GENERATE_KERNEL2(generate_normal, at::Half, double mean, double stdv, float, curand_normal, (ScalarConvert<float, at::Half>::to((x * stdv) + mean)))
138142
GENERATE_KERNEL1(generate_exponential, at::Half, double lambda, float, curand_uniform, (ScalarConvert<float, at::Half>::to((float)(-1. / lambda * log(x)))))
139143
GENERATE_KERNEL2(generate_cauchy, at::Half, double median, double sigma, float, curand_uniform, (ScalarConvert<float, at::Half>::to((float)(median + sigma * tan(M_PI*(x-0.5))))))
140144

aten/src/THC/generic/THCTensorRandom.cu

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,42 @@
88

99
#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_HALF)
1010

11+
void THCTensor_(normal)(THCState* state, THCTensor *self_, double mean, double stdv)
12+
{
13+
THCAssertSameGPU(THCTensor_(checkGPU)(state, 1, self_));
14+
ptrdiff_t size = THCTensor_(nElement)(state, self_);
15+
if (size == 0) return;
16+
THCGenerator* gen = THCRandom_getGenerator(state);
17+
THCTensor *self = THCTensor_(newContiguous)(state, self_);
18+
scalar_t *data = THCTensor_(data)(state, self);
19+
20+
generate_normal<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
21+
gen->state.gen_states, size, data, mean, stdv);
22+
23+
THCTensor_(freeCopyTo)(state, self, self_);
24+
};
25+
26+
void THCTensor_(normal_means)(THCState *state, THCTensor *self, THCTensor *means, double stddev) {
27+
THCTensor_(resizeAs)(state, self, means);
28+
THCTensor_(normal)(state, self, 0, stddev);
29+
THCTensor_(cadd)(state, self, self, ScalarConvert<int, scalar_t>::to(1), means);
30+
}
31+
32+
void THCTensor_(normal_stddevs)(THCState *state, THCTensor *self, double mean, THCTensor *stddevs)
33+
{
34+
THCTensor_(resizeAs)(state, self, stddevs);
35+
THCTensor_(normal)(state, self, 0, 1);
36+
THCTensor_(cmul)(state, self, self, stddevs);
37+
THCTensor_(add)(state, self, self, ScalarConvert<double, scalar_t>::to(mean));
38+
}
39+
40+
void THCTensor_(normal_means_stddevs)(THCState *state, THCTensor *self, THCTensor *means, THCTensor *stddevs)
41+
{
42+
THCTensor_(resizeAs)(state, self, means);
43+
THCTensor_(normal)(state, self, 0, 1);
44+
THCTensor_(cmul)(state, self, self, stddevs);
45+
THCTensor_(cadd)(state, self, self, ScalarConvert<int, scalar_t>::to(1), means);
46+
}
1147

1248
void THCTensor_(logNormal)(THCState* state, THCTensor *self_, double mean, double stdv)
1349
{

aten/src/THC/generic/THCTensorRandom.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@
44

55
#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_HALF)
66

7+
THC_API void THCTensor_(normal)(struct THCState *state, THCTensor *self, double mean, double stdv);
8+
THC_API void THCTensor_(normal_means)(struct THCState *state, THCTensor *self, THCTensor *means, double stddev);
9+
THC_API void THCTensor_(normal_stddevs)(struct THCState *state, THCTensor *self, double mean, THCTensor *stddevs);
10+
THC_API void THCTensor_(normal_means_stddevs)(struct THCState *state, THCTensor *self, THCTensor *means, THCTensor *stddevs);
711
THC_API void THCTensor_(logNormal)(struct THCState *state, THCTensor *self, double mean, double stdv);
812
THC_API void THCTensor_(exponential)(struct THCState *state, THCTensor *self, double lambda);
913
THC_API void THCTensor_(cauchy)(struct THCState *state, THCTensor *self, double median, double sigma);

0 commit comments

Comments (0)