@ssnl ssnl commented Aug 6, 2018

Fixes

+ pytorch/pytorch#10236 : torch.bernoulli's out kwarg is broken; fixed by moving bernoulli_out to ATen
+ pytorch/pytorch#9917 : BUG: torch.bernoulli(p.expand(shape)) is broken; fixed by moving all bernoulli ops in ATen to use the modern apply utils methods
+ pytorch/pytorch#10357 : torch.bernoulli gives inconsistent GPU/CPU results; fixed by adding CUDA asserts
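For illustration, the call patterns covered by these fixes look roughly like this (a minimal Python sketch, not code from this PR):

```python
import torch

p = torch.full((1,), 0.5)

# pytorch/pytorch#9917: sampling from an expanded (non-contiguous) probability tensor
samples = torch.bernoulli(p.expand(3, 4))

# pytorch/pytorch#10236: the out= keyword argument of torch.bernoulli
out = torch.empty(3, 4)
torch.bernoulli(torch.rand(3, 4), out=out)

# pytorch/pytorch#10357: probabilities outside [0, 1] should now fail
# consistently on CPU and CUDA (via device-side asserts on CUDA).
```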

Notable changes:

In order to use curand_uniform4, I made some changes to CUDAApplyUtils.cuh. Specifically, I introduced an optional template parameter int step to the CUDA_tensor_applyN methods, indicating that we want to process step values at a time for each of the N tensors.

The calling convention for step = 1 (default) isn't changed. But if step > 1, the given lambda op must take int n as its first argument, representing the number of valid values, because there may be fewer than step valid values at the boundary. For example, here is what the bernoulli(self, p_tensor) call looks like:

  // The template argument `4` below indicates that we want to operate on four
  // elements at a time. See NOTE [ CUDA_tensor_applyN helpers ] for details.
  at::cuda::CUDA_tensor_apply2<scalar_t, prob_t, 4>(
      ret, p,
      [seeds] __device__(
          int n, scalar_t& v1, scalar_t& v2, scalar_t& v3, scalar_t& v4,
          const prob_t& p1, const prob_t& p2, const prob_t& p3, const prob_t& p4) {
        curandStatePhilox4_32_10_t state;
        curand_init(
            seeds.first,
            blockIdx.x * blockDim.x + threadIdx.x,
            seeds.second,
            &state);
        float4 rand = curand_uniform4(&state);
        // Note: the cases below intentionally fall through, so that exactly
        // `n` outputs are written when fewer than 4 values remain at the boundary.
        switch (n) {
          case 4: {
            assert(0 <= p4 && p4 <= 1);
            v4 = static_cast<scalar_t>(rand.w <= p4);
          }
          case 3: {
            assert(0 <= p3 && p3 <= 1);
            v3 = static_cast<scalar_t>(rand.z <= p3);
          }
          case 2: {
            assert(0 <= p2 && p2 <= 1);
            v2 = static_cast<scalar_t>(rand.y <= p2);
          }
          case 1: {
            assert(0 <= p1 && p1 <= 1);
            v1 = static_cast<scalar_t>(rand.x <= p1);
          }
        }
      }
    );

Benchmarking

Benchmarking on torch.rand(200, 300, 400), 20 runs with 20 loops each:

post-patch

➜  ~ numactl --cpunodebind 1 --membind 1 -- taskset -c 12,13,14,15,16,17,18,19,20,21,22,23 env CUDA_LAUNCH_BLOCKING=1 python bern.py
torch.bernoulli(x)
6.841588497161865 +- 0.05413117632269859
torch.bernoulli(xc)
0.05963418632745743 +- 0.0008014909108169377
x.bernoulli_()
0.4024486541748047 +- 0.0021550932433456182
xc.bernoulli_()
0.02167394384741783 +- 2.3818030967959203e-05

pre-patch

➜  ~ numactl --cpunodebind 1 --membind 1 -- taskset -c 12,13,14,15,16,17,18,19,20,21,22,23 env CUDA_LAUNCH_BLOCKING=1 python bern.py
torch.bernoulli(x)
12.394511222839355 +- 0.0966421514749527
torch.bernoulli(xc)
0.08970972150564194 +- 0.0038722590543329716
x.bernoulli_()
1.654480218887329 +- 0.02364428900182247
xc.bernoulli_()
0.058352887630462646 +- 0.003094920190051198
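
For reference, here is a hypothetical sketch of what a timing script along the lines of the bern.py used above might look like; the actual script is not part of this PR, and the bench helper is purely illustrative:

```python
# Hypothetical sketch of a bern.py-style timing script (not the actual script).
# The command lines above set CUDA_LAUNCH_BLOCKING=1, which makes CUDA kernel
# launches synchronous, so plain wall-clock timing is meaningful here.
import timeit

import torch

x = torch.rand(200, 300, 400)   # CPU tensor
xc = x.cuda()                   # CUDA tensor

def bench(label, fn, repeat=20, loops=20):
    # Each measurement runs `fn` `loops` times; repeat the measurement
    # `repeat` times and report mean +- standard deviation (in seconds).
    times = torch.tensor(timeit.repeat(fn, number=loops, repeat=repeat))
    print(label)
    print(f"{times.mean().item()} +- {times.std().item()}")

bench("torch.bernoulli(x)", lambda: torch.bernoulli(x))
bench("torch.bernoulli(xc)", lambda: torch.bernoulli(xc))
bench("x.bernoulli_()", lambda: x.bernoulli_())
bench("xc.bernoulli_()", lambda: xc.bernoulli_())
```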


@ssnl ssnl changed the title from "Move bernoulli into ATen" to "[wip] Move bernoulli into ATen" on Aug 6, 2018


@ssnl ssnl force-pushed the bern_ctor branch 3 times, most recently from ed451c1 to 70ae05c on August 10, 2018 19:40


@ssnl ssnl force-pushed the bern_ctor branch 4 times, most recently from 3350db8 to af1686a on August 14, 2018 21:38
fmassa commented Aug 28, 2018

ping @ssnl

ssnl commented Aug 28, 2018

Yeah, I'll fix the Windows error.

@ssnl ssnl force-pushed the bern_ctor branch 3 times, most recently from 8b50449 to aefa656 on August 29, 2018 21:58


@ssnl ssnl changed the title from "[wip] Move bernoulli into ATen" to "[ready] Move bernoulli into ATen" on Sep 12, 2018
@ssnl ssnl force-pushed the bern_ctor branch 4 times, most recently from 56eff49 to a3fa723 on September 13, 2018 16:09
@facebook-github-bot facebook-github-bot left a comment

SsnL has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

zdevito pushed a commit to zdevito/ATen that referenced this pull request on Sep 20, 2018
Summary:
+ pytorch/pytorch#10236 : `torch.bernoulli`'s `out` kwarg is broken;
  fixed by moving `bernoulli_out` to ATen
+ pytorch/pytorch#9917 : BUG: `torch.bernoulli(p.expand(shape))` is broken;
  fixed by moving all `bernoulli` ops in ATen to use the modern apply utils methods
+ pytorch/pytorch#10357 : `torch.bernoulli` gives inconsistent GPU/CPU results;
  fixed by adding CUDA asserts

In order to use `curand_uniform4`, I made some changes to `CUDAApplyUtils.cuh`. Specifically, I introduced an optional template parameter `int step` to the `CUDA_tensor_applyN` methods, indicating that we want to process `step` values at a time for each of the `N` tensors.

The calling convention for `step = 1` (default) isn't changed. But if `step > 1`, the given lambda `op` must take `int n` as its first argument, representing the number of valid values, because there may be fewer than `step` valid values at the boundary. For example, here is what the `bernoulli(self, p_tensor)` call looks like:
```cpp

  // The template argument `4` below indicates that we want to operate on four
  // elements at a time. See NOTE [ CUDA_tensor_applyN helpers ] for details.
  at::cuda::CUDA_tensor_apply2<scalar_t, prob_t, 4>(
      ret, p,
      [seeds] __device__(
          int n, scalar_t& v1, scalar_t& v2, scalar_t& v3, scalar_t& v4,
          const prob_t& p1, const prob_t& p2, const prob_t& p3, const prob_t& p4) {
        curandStatePhilox4_32_10_t state;
        curand_init(
            seeds.first,
            blockIdx.x * blockDim.x + threadIdx.x,
            seeds.second,
            &state);
        float4 rand = curand_uniform4(&state);
        // Note: the cases below intentionally fall through, so that exactly
        // `n` outputs are written when fewer than 4 values remain at the boundary.
        switch (n) {
          case 4: {
            assert(0 <= p4 && p4 <= 1);
            v4 = static_cast<scalar_t>(rand.w <= p4);
          }
          case 3: {
            assert(0 <= p3 && p3 <= 1);
            v3 = static_cast<scalar_t>(rand.z <= p3);
          }
          case 2: {
            assert(0 <= p2 && p2 <= 1);
            v2 = static_cast<scalar_t>(rand.y <= p2);
          }
          case 1: {
            assert(0 <= p1 && p1 <= 1);
            v1 = static_cast<scalar_t>(rand.x <= p1);
          }
        }
      }
    );
```

Benchmarking on `torch.rand(200, 300, 400)`, 20 runs with 20 loops each:

post-patch
```
➜  ~ numactl --cpunodebind 1 --membind 1 -- taskset -c 12,13,14,15,16,17,18,19,20,21,22,23 env CUDA_LAUNCH_BLOCKING=1 python bern.py
torch.bernoulli(x)
6.841588497161865 +- 0.05413117632269859
torch.bernoulli(xc)
0.05963418632745743 +- 0.0008014909108169377
x.bernoulli_()
0.4024486541748047 +- 0.0021550932433456182
xc.bernoulli_()
0.02167394384741783 +- 2.3818030967959203e-05

```

pre-patch
```
➜  ~ numactl --cpunodebind 1 --membind 1 -- taskset -c 12,13,14,15,16,17,18,19,20,21,22,23 env CUDA_LAUNCH_BLOCKING=1 python bern.py
torch.bernoulli(x)
12.394511222839355 +- 0.0966421514749527
torch.bernoulli(xc)
0.08970972150564194 +- 0.0038722590543329716
x.bernoulli_()
1.654480218887329 +- 0.02364428900182247
xc.bernoulli_()
0.058352887630462646 +- 0.003094920190051198

```
Pull Request resolved: pytorch/pytorch#10273

Differential Revision: D9831294

Pulled By: SsnL

fbshipit-source-id: 65e0655a36b90d5278b675d35cb5327751604088
@ssnl ssnl deleted the bern_ctor branch on September 21, 2018 20:35