Skip to content

Commit 24e958a

Browse files
ssnl authored and facebook-github-bot committed
Move bernoulli into ATen (#10273)
Summary: + #10236 : torch.bernoulli's out kwarg is broken fixed in moving `bernoulli_out` to ATen + #9917 : BUG torch.bernoulli(p.expand(shape)) is broken fixed in moving all `bernoulli` ops in ATen to use the modern apply utils methods + #10357 : torch.bernoulli inconsistent gpu/cpu results fixed by adding CUDA asserts In order to use `curand_uniform4`, I made some changes to `CUDAApplyUtils.cuh`. Specifically, I introduced an optional template parameter `int step` to the `CUDA_tensor_applyN` methods, representing that we want to process `step` values at each time for each of the `N` tensors. The calling convention for `step = 1` (default) isn't changed. But if `step > 1`, the given lambda `op` must take in `int n` as its first argument, representing the number of valid values, because there may not be full `step` values at the boundary. E.g., here is what the `bernoulli(self, p_tensor)` call look like: ```cpp // The template argument `4` below indicates that we want to operate on four // element at each time. See NOTE [ CUDA_tensor_applyN helpers ] for details. 
at::cuda::CUDA_tensor_apply2<scalar_t, prob_t, 4>( ret, p, [seeds] __device__( int n, scalar_t& v1, scalar_t& v2, scalar_t& v3, scalar_t& v4, const prob_t& p1, const prob_t& p2, const prob_t& p3, const prob_t& p4) { curandStatePhilox4_32_10_t state; curand_init( seeds.first, blockIdx.x * blockDim.x + threadIdx.x, seeds.second, &state); float4 rand = curand_uniform4(&state); switch (n) { case 4: { assert(0 <= p4 && p4 <= 1); v4 = static_cast<scalar_t>(rand.w <= p4); } case 3: { assert(0 <= p3 && p3 <= 1); v3 = static_cast<scalar_t>(rand.z <= p3); } case 2: { assert(0 <= p2 && p2 <= 1); v2 = static_cast<scalar_t>(rand.y <= p2); } case 1: { assert(0 <= p1 && p1 <= 1); v1 = static_cast<scalar_t>(rand.x <= p1); } } } ); ``` Benchmarking on `torch.rand(200, 300, 400)` 20 times, each time with 20 loops: post patch ``` ➜ ~ numactl --cpunodebind 1 --membind 1 -- taskset -c 12,13,14,15,16,17,18,19,20,21,22,23 env CUDA_LAUNCH_BLOCKING=1 python bern.py torch.bernoulli(x) 6.841588497161865 +- 0.05413117632269859 torch.bernoulli(xc) 0.05963418632745743 +- 0.0008014909108169377 x.bernoulli_() 0.4024486541748047 +- 0.0021550932433456182 xc.bernoulli_() 0.02167394384741783 +- 2.3818030967959203e-05 ``` pre-patch ``` ➜ ~ numactl --cpunodebind 1 --membind 1 -- taskset -c 12,13,14,15,16,17,18,19,20,21,22,23 env CUDA_LAUNCH_BLOCKING=1 python bern.py torch.bernoulli(x) 12.394511222839355 +- 0.0966421514749527 torch.bernoulli(xc) 0.08970972150564194 +- 0.0038722590543329716 x.bernoulli_() 1.654480218887329 +- 0.02364428900182247 xc.bernoulli_() 0.058352887630462646 +- 0.003094920190051198 ``` Pull Request resolved: #10273 Differential Revision: D9831294 Pulled By: SsnL fbshipit-source-id: 65e0655a36b90d5278b675d35cb5327751604088
1 parent cf5a21e commit 24e958a

37 files changed

+1098
-623
lines changed

aten/src/ATen/CPUApplyUtils.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ inline std::string _all_equal_numel_error(at::ArrayRef<Tensor> tensors) {
207207
for (size_t i = 0; i < tensors.size() - 1; i++) {
208208
oss << tensors[i].sizes() << ", ";
209209
}
210-
oss << "and " << tensors[tensors.size() - 1]
210+
oss << "and " << tensors[tensors.size() - 1].sizes()
211211
<< " to have the same number of elements, but got ";
212212
for (size_t i = 0; i < tensors.size() - 1; i++) {
213213
oss << tensors[i].numel() << ", ";
@@ -220,7 +220,7 @@ inline std::string _all_equal_numel_error(at::ArrayRef<Tensor> tensors) {
220220
inline bool _apply_preamble(ArrayRef<Tensor> tensors) {
221221
checkBackend("CPU_tensor_apply", tensors, Backend::CPU);
222222
if (!_all_equal_numel(tensors))
223-
throw std::runtime_error(_all_equal_numel_error(tensors));
223+
AT_ERROR(_all_equal_numel_error(tensors));
224224
// An empty tensor has no elements
225225
for (auto& t : tensors)
226226
if (t.numel() == 0)

aten/src/ATen/Declarations.cwrap

Lines changed: 0 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -3218,38 +3218,6 @@
32183218
kwarg_only: True
32193219
- double p
32203220
]]
3221-
[[
3222-
name: _bernoulli_
3223-
backends:
3224-
- CPU
3225-
- CUDA
3226-
cname: bernoulli
3227-
return: self
3228-
variants: function
3229-
arguments:
3230-
- THTensor* self
3231-
- arg: THGenerator* generator
3232-
default: nullptr
3233-
kwarg_only: True
3234-
- double p
3235-
]]
3236-
[[
3237-
name: _th_bernoulli
3238-
types:
3239-
- Float
3240-
- Double
3241-
return: argument 0
3242-
variants: function
3243-
cname: bernoulli_Tensor
3244-
arguments:
3245-
- arg: THTensor* output
3246-
output: True
3247-
resize: self
3248-
- arg: THGenerator* generator
3249-
default: nullptr
3250-
kwarg_only: True
3251-
- THTensor* self
3252-
]]
32533221
[[
32543222
name: _dirichlet_grad
32553223
types:

aten/src/ATen/core/Tensor.h

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -441,12 +441,10 @@ struct AT_API Tensor {
441441
Tensor & atan_();
442442
Tensor baddbmm(const Tensor & batch1, const Tensor & batch2, Scalar beta=1, Scalar alpha=1) const;
443443
Tensor & baddbmm_(const Tensor & batch1, const Tensor & batch2, Scalar beta=1, Scalar alpha=1);
444-
Tensor bernoulli(const Tensor & p, Generator * generator=nullptr) const;
445-
Tensor bernoulli(double p, Generator * generator=nullptr) const;
446-
Tensor bernoulli() const;
444+
Tensor bernoulli(Generator * generator=nullptr) const;
447445
Tensor & bernoulli_(const Tensor & p, Generator * generator=nullptr);
448-
Tensor & bernoulli_(double p, Generator * generator=nullptr);
449-
Tensor & bernoulli_();
446+
Tensor & bernoulli_(double p=0.5, Generator * generator=nullptr);
447+
Tensor bernoulli(double p, Generator * generator=nullptr) const;
450448
Tensor bincount(const Tensor & weights={}, int64_t minlength=0) const;
451449
Tensor bmm(const Tensor & mat2) const;
452450
Tensor ceil() const;

aten/src/ATen/core/TensorMethods.h

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -605,23 +605,17 @@ inline Tensor Tensor::baddbmm(const Tensor & batch1, const Tensor & batch2, Scal
605605
inline Tensor & Tensor::baddbmm_(const Tensor & batch1, const Tensor & batch2, Scalar beta, Scalar alpha) {
606606
return type().baddbmm_(*this, batch1, batch2, beta, alpha);
607607
}
608-
inline Tensor Tensor::bernoulli(const Tensor & p, Generator * generator) const {
609-
return type().bernoulli(*this, p, generator);
610-
}
611-
inline Tensor Tensor::bernoulli(double p, Generator * generator) const {
612-
return type().bernoulli(*this, p, generator);
613-
}
614-
inline Tensor Tensor::bernoulli() const {
615-
return type().bernoulli(*this);
608+
inline Tensor Tensor::bernoulli(Generator * generator) const {
609+
return type().bernoulli(*this, generator);
616610
}
617611
inline Tensor & Tensor::bernoulli_(const Tensor & p, Generator * generator) {
618612
return type().bernoulli_(*this, p, generator);
619613
}
620614
inline Tensor & Tensor::bernoulli_(double p, Generator * generator) {
621615
return type().bernoulli_(*this, p, generator);
622616
}
623-
inline Tensor & Tensor::bernoulli_() {
624-
return type().bernoulli_(*this);
617+
inline Tensor Tensor::bernoulli(double p, Generator * generator) const {
618+
return type().bernoulli(*this, p, generator);
625619
}
626620
inline Tensor Tensor::bincount(const Tensor & weights, int64_t minlength) const {
627621
return type().bincount(*this, weights, minlength);

aten/src/ATen/core/Type.h

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -397,12 +397,10 @@ struct AT_API Type {
397397
virtual Tensor & atan_(Tensor & self) const = 0;
398398
virtual Tensor baddbmm(const Tensor & self, const Tensor & batch1, const Tensor & batch2, Scalar beta, Scalar alpha) const = 0;
399399
virtual Tensor & baddbmm_(Tensor & self, const Tensor & batch1, const Tensor & batch2, Scalar beta, Scalar alpha) const = 0;
400-
virtual Tensor bernoulli(const Tensor & self, const Tensor & p, Generator * generator) const = 0;
401-
virtual Tensor bernoulli(const Tensor & self, double p, Generator * generator) const = 0;
402-
virtual Tensor bernoulli(const Tensor & self) const = 0;
400+
virtual Tensor bernoulli(const Tensor & self, Generator * generator) const = 0;
403401
virtual Tensor & bernoulli_(Tensor & self, const Tensor & p, Generator * generator) const = 0;
404402
virtual Tensor & bernoulli_(Tensor & self, double p, Generator * generator) const = 0;
405-
virtual Tensor & bernoulli_(Tensor & self) const = 0;
403+
virtual Tensor bernoulli(const Tensor & self, double p, Generator * generator) const = 0;
406404
virtual Tensor bincount(const Tensor & self, const Tensor & weights, int64_t minlength) const = 0;
407405
virtual Tensor bmm(const Tensor & self, const Tensor & mat2) const = 0;
408406
virtual Tensor ceil(const Tensor & self) const = 0;

aten/src/ATen/cpu/vec256/vec256_base.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -438,4 +438,14 @@ interleave2(const Vec256<T>& a, const Vec256<T>& b) {
438438
Vec256<T>::loadu(static_cast<void*>(buffer2)));
439439
}
440440

441+
template <typename src_T, typename dst_T>
442+
void convert(const src_T *src, dst_T *dst, int64_t n) {
443+
#pragma unroll
444+
for (int64_t i = 0; i < n; i++) {
445+
*dst = static_cast<dst_T>(*src);
446+
src++;
447+
dst++;
448+
}
449+
}
450+
441451
}}}

aten/src/ATen/cpu/vec256/vec256_int.h

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,38 @@ struct Vec256<int32_t> : public Vec256i {
208208
}
209209
};
210210

211+
template <>
212+
void convert(const int32_t *src, float *dst, int64_t n) {
213+
int64_t i;
214+
// int32_t and float have same size
215+
#pragma unroll
216+
for (i = 0; i <= (n - Vec256<int32_t>::size); i += Vec256<int32_t>::size) {
217+
auto input_vec = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + i));
218+
auto output_vec = _mm256_cvtepi32_ps(input_vec);
219+
_mm256_storeu_ps(reinterpret_cast<float*>(dst + i), output_vec);
220+
}
221+
#pragma unroll
222+
for (; i < n; i++) {
223+
dst[i] = static_cast<float>(src[i]);
224+
}
225+
}
226+
227+
template <>
228+
void convert(const int32_t *src, double *dst, int64_t n) {
229+
int64_t i;
230+
// int32_t has half the size of double
231+
#pragma unroll
232+
for (i = 0; i <= (n - Vec256<double>::size); i += Vec256<double>::size) {
233+
auto input_128_vec = _mm_loadu_si128(reinterpret_cast<const __m128i*>(src + i));
234+
auto output_vec = _mm256_cvtepi32_pd(input_128_vec);
235+
_mm256_storeu_pd(reinterpret_cast<double*>(dst + i), output_vec);
236+
}
237+
#pragma unroll
238+
for (; i < n; i++) {
239+
dst[i] = static_cast<double>(src[i]);
240+
}
241+
}
242+
211243
template <>
212244
struct Vec256<int16_t> : public Vec256i {
213245
static constexpr int size = 16;

0 commit comments

Comments (0)