Commit 24e958a
Move bernoulli into ATen (#10273)
Summary:
+ #10236 : torch.bernoulli's out kwarg is broken
fixed in moving `bernoulli_out` to ATen
+ #9917 : BUG torch.bernoulli(p.expand(shape)) is broken
fixed in moving all `bernoulli` ops in ATen to use the modern apply utils methods
+ #10357 : torch.bernoulli inconsistent gpu/cpu results
fixed by adding CUDA asserts
In order to use `curand_uniform4`, I made some changes to `CUDAApplyUtils.cuh`. Specifically, I introduced an optional template parameter `int step` to the `CUDA_tensor_applyN` methods, representing that we want to process `step` values at each time for each of the `N` tensors.
The calling convention for `step = 1` (default) isn't changed. But if `step > 1`, the given lambda `op` must take in `int n` as its first argument, representing the number of valid values, because there may not be full `step` values at the boundary. E.g., here is what the `bernoulli(self, p_tensor)` call look like:
```cpp
// The template argument `4` below indicates that we want to operate on four
// element at each time. See NOTE [ CUDA_tensor_applyN helpers ] for details.
at::cuda::CUDA_tensor_apply2<scalar_t, prob_t, 4>(
ret, p,
[seeds] __device__(
int n, scalar_t& v1, scalar_t& v2, scalar_t& v3, scalar_t& v4,
const prob_t& p1, const prob_t& p2, const prob_t& p3, const prob_t& p4) {
curandStatePhilox4_32_10_t state;
curand_init(
seeds.first,
blockIdx.x * blockDim.x + threadIdx.x,
seeds.second,
&state);
float4 rand = curand_uniform4(&state);
switch (n) {
case 4: {
assert(0 <= p4 && p4 <= 1);
v4 = static_cast<scalar_t>(rand.w <= p4);
}
case 3: {
assert(0 <= p3 && p3 <= 1);
v3 = static_cast<scalar_t>(rand.z <= p3);
}
case 2: {
assert(0 <= p2 && p2 <= 1);
v2 = static_cast<scalar_t>(rand.y <= p2);
}
case 1: {
assert(0 <= p1 && p1 <= 1);
v1 = static_cast<scalar_t>(rand.x <= p1);
}
}
}
);
```
Benchmarking on `torch.rand(200, 300, 400)` 20 times, each time with 20 loops:
post patch
```
➜ ~ numactl --cpunodebind 1 --membind 1 -- taskset -c 12,13,14,15,16,17,18,19,20,21,22,23 env CUDA_LAUNCH_BLOCKING=1 python bern.py
torch.bernoulli(x)
6.841588497161865 +- 0.05413117632269859
torch.bernoulli(xc)
0.05963418632745743 +- 0.0008014909108169377
x.bernoulli_()
0.4024486541748047 +- 0.0021550932433456182
xc.bernoulli_()
0.02167394384741783 +- 2.3818030967959203e-05
```
pre-patch
```
➜ ~ numactl --cpunodebind 1 --membind 1 -- taskset -c 12,13,14,15,16,17,18,19,20,21,22,23 env CUDA_LAUNCH_BLOCKING=1 python bern.py
torch.bernoulli(x)
12.394511222839355 +- 0.0966421514749527
torch.bernoulli(xc)
0.08970972150564194 +- 0.0038722590543329716
x.bernoulli_()
1.654480218887329 +- 0.02364428900182247
xc.bernoulli_()
0.058352887630462646 +- 0.003094920190051198
```
Pull Request resolved: #10273
Differential Revision: D9831294
Pulled By: SsnL
fbshipit-source-id: 65e0655a36b90d5278b675d35cb53277516040881 parent cf5a21e commit 24e958a
File tree
37 files changed
+1098
-623
lines changed- aten/src
- ATen
- core
- cpu/vec256
- cuda
- native
- cpu
- cuda
- THC
- generic
- TH
- generic
- vector
- test
- tools/autograd
- torch
- csrc/jit
- passes
- nn
37 files changed
+1098
-623
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
207 | 207 | | |
208 | 208 | | |
209 | 209 | | |
210 | | - | |
| 210 | + | |
211 | 211 | | |
212 | 212 | | |
213 | 213 | | |
| |||
220 | 220 | | |
221 | 221 | | |
222 | 222 | | |
223 | | - | |
| 223 | + | |
224 | 224 | | |
225 | 225 | | |
226 | 226 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
3218 | 3218 | | |
3219 | 3219 | | |
3220 | 3220 | | |
3221 | | - | |
3222 | | - | |
3223 | | - | |
3224 | | - | |
3225 | | - | |
3226 | | - | |
3227 | | - | |
3228 | | - | |
3229 | | - | |
3230 | | - | |
3231 | | - | |
3232 | | - | |
3233 | | - | |
3234 | | - | |
3235 | | - | |
3236 | | - | |
3237 | | - | |
3238 | | - | |
3239 | | - | |
3240 | | - | |
3241 | | - | |
3242 | | - | |
3243 | | - | |
3244 | | - | |
3245 | | - | |
3246 | | - | |
3247 | | - | |
3248 | | - | |
3249 | | - | |
3250 | | - | |
3251 | | - | |
3252 | | - | |
3253 | 3221 | | |
3254 | 3222 | | |
3255 | 3223 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
441 | 441 | | |
442 | 442 | | |
443 | 443 | | |
444 | | - | |
445 | | - | |
446 | | - | |
| 444 | + | |
447 | 445 | | |
448 | | - | |
449 | | - | |
| 446 | + | |
| 447 | + | |
450 | 448 | | |
451 | 449 | | |
452 | 450 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
605 | 605 | | |
606 | 606 | | |
607 | 607 | | |
608 | | - | |
609 | | - | |
610 | | - | |
611 | | - | |
612 | | - | |
613 | | - | |
614 | | - | |
615 | | - | |
| 608 | + | |
| 609 | + | |
616 | 610 | | |
617 | 611 | | |
618 | 612 | | |
619 | 613 | | |
620 | 614 | | |
621 | 615 | | |
622 | 616 | | |
623 | | - | |
624 | | - | |
| 617 | + | |
| 618 | + | |
625 | 619 | | |
626 | 620 | | |
627 | 621 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
397 | 397 | | |
398 | 398 | | |
399 | 399 | | |
400 | | - | |
401 | | - | |
402 | | - | |
| 400 | + | |
403 | 401 | | |
404 | 402 | | |
405 | | - | |
| 403 | + | |
406 | 404 | | |
407 | 405 | | |
408 | 406 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
438 | 438 | | |
439 | 439 | | |
440 | 440 | | |
| 441 | + | |
| 442 | + | |
| 443 | + | |
| 444 | + | |
| 445 | + | |
| 446 | + | |
| 447 | + | |
| 448 | + | |
| 449 | + | |
| 450 | + | |
441 | 451 | | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
208 | 208 | | |
209 | 209 | | |
210 | 210 | | |
| 211 | + | |
| 212 | + | |
| 213 | + | |
| 214 | + | |
| 215 | + | |
| 216 | + | |
| 217 | + | |
| 218 | + | |
| 219 | + | |
| 220 | + | |
| 221 | + | |
| 222 | + | |
| 223 | + | |
| 224 | + | |
| 225 | + | |
| 226 | + | |
| 227 | + | |
| 228 | + | |
| 229 | + | |
| 230 | + | |
| 231 | + | |
| 232 | + | |
| 233 | + | |
| 234 | + | |
| 235 | + | |
| 236 | + | |
| 237 | + | |
| 238 | + | |
| 239 | + | |
| 240 | + | |
| 241 | + | |
| 242 | + | |
211 | 243 | | |
212 | 244 | | |
213 | 245 | | |
| |||
0 commit comments