
Conversation

@nickgg (Contributor) commented Sep 15, 2020

Unifies a number of partial solutions to thread and block dimension extent masking, including the NoThreadIdxWriter and my last fix #44325. The NoThreadIdxWriter is gone in favour of tracking the current loop extents and masking any statements whose rank is lower than the launch parameters in any Block or Thread dimension, which handles both the "no" and "smaller" axis-binding cases.

For example, it will transform the following:

```
for i in 0..10 // blockIdx.x
  for j in 0..10 // threadIdx.x
    do thing(i, j);
  for k in 0..5 // threadIdx.x
    do other thing(i, k);
```

Into:

```
do thing(blockIdx.x, threadIdx.x);
if (threadIdx.x < 5) {
  do other thing(blockIdx.x, threadIdx.x);
}
```

It also handles the case where statements are not bound by any axis, e.g.

```
do outer thing;
for i in 0..10 // blockIdx.x
  for j in 0..10 // threadIdx.x
    do thing(i, j);
  do other thing(i);
```

will become:

```
if (blockIdx.x < 1) {
  if (threadIdx.x < 1) {
    do outer thing;
  }
}
syncthreads();
do thing(blockIdx.x, threadIdx.x);
syncthreads();
if (threadIdx.x < 1) {
  do other thing(blockIdx.x);
}
```
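
To make the masking concrete, here is a hedged sketch of roughly how the second example above could look as a CUDA kernel body. The kernel name and output buffers are invented for illustration and are not taken from the codegen output:

```cuda
// Hedged, illustrative sketch only; the real CudaCodeGen output differs in detail.
__global__ void masked_kernel(float* out_a, float* out_b, float* out_c) {
  // "do outer thing": not bound to any axis, so it runs in exactly one thread.
  if (blockIdx.x < 1) {
    if (threadIdx.x < 1) {
      out_a[0] = 0.0f;
    }
  }
  __syncthreads();

  // "do thing(i, j)": fully bound to blockIdx.x and threadIdx.x, so no mask is needed.
  out_b[blockIdx.x * blockDim.x + threadIdx.x] = 1.0f;
  __syncthreads();

  // "do other thing(i)": bound only to blockIdx.x, so the thread dimension is masked.
  if (threadIdx.x < 1) {
    out_c[blockIdx.x] = 2.0f;
  }
}
```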

@nickgg requested a review from apaszke as a code owner (September 15, 2020 20:19)
@facebook-github-bot added the oncall: jit label (Sep 15, 2020)
dr-ci bot commented Sep 15, 2020

💊 CI failures summary and remediations

As of commit b6a686f (more details on the Dr. CI page):



1 failure not recognized by patterns:

Job: CircleCI pytorch_linux_xenial_cuda10_2_cudnn7_py3_ge_config_profiling_test
Step: Spin up environment

❄️ 3 failures tentatively classified as flaky, but reruns have not yet been triggered to confirm:

See CircleCI build pytorch_xla_linux_bionic_py3_6_clang9_test (1/3)

Step: "Run tests" (full log | diagnosis details | 🔁 rerun) ❄️

Sep 16 19:17:02 CondaHTTPError: HTTP 000 CONNECTION FAILED for url
```
Sep 16 19:15:47 ++ [[ pytorch-xla-linux-bionic-py3.6-clang9-test == *pytorch-linux-xenial-cuda10.1-cudnn7-py3* ]]
Sep 16 19:15:47 ++ [[ pytorch-xla-linux-bionic-py3.6-clang9-test == *pytorch-linux-trusty-py3.6-gcc7* ]]
Sep 16 19:15:47 ++ [[ pytorch-xla-linux-bionic-py3.6-clang9-test == *pytorch_macos* ]]
Sep 16 19:15:47 ++ BUILD_TEST_LIBTORCH=0
Sep 16 19:15:47 ++ [[ pytorch-xla-linux-bionic-py3.6-clang9-test == *pytorch-xla-linux-bionic* ]]
Sep 16 19:15:47 ++ which conda
Sep 16 19:15:47 /opt/conda/bin/conda
Sep 16 19:15:47 ++ conda install -q -y cmake
Sep 16 19:17:02 Collecting package metadata (current_repodata.json): ...working... failed
Sep 16 19:17:02 CondaHTTPError: HTTP 000 CONNECTION FAILED for url <https://repo.anaconda.com/pkgs/main/noarch/current_repodata.json>
Sep 16 19:17:02 Elapsed: -
Sep 16 19:17:02 An HTTP error occurred when trying to retrieve this URL.
Sep 16 19:17:02 HTTP errors are often intermittent, and a simple retry will get you on your way.
Sep 16 19:17:02 If your current network has https://www.anaconda.com blocked, please file
Sep 16 19:17:02 a support request with your network engineering team.
Sep 16 19:17:02 'https://repo.anaconda.com/pkgs/main/noarch'
```

See CircleCI build pytorch_linux_bionic_py3_6_clang9_build (2/3)

Step: "Build" (full log | diagnosis details | 🔁 rerun) ❄️

Sep 16 18:37:29 CondaHTTPError: HTTP 000 CONNECTION FAILED for url
```
Sep 16 18:37:29   rhash              pkgs/main/linux-64::rhash-1.3.8-h1ba5d50_0
Sep 16 18:37:29 CondaHTTPError: HTTP 000 CONNECTION FAILED for url <https://repo.anaconda.com/pkgs/main/linux-64/cmake-3.14.0-h52cb24c_0.conda>
Sep 16 18:37:29 Elapsed: -
Sep 16 18:37:29 An HTTP error occurred when trying to retrieve this URL.
Sep 16 18:37:29 HTTP errors are often intermittent, and a simple retry will get you on your way.
Sep 16 18:37:29 CondaHTTPError: HTTP 000 CONNECTION FAILED for url <https://repo.anaconda.com/pkgs/main/linux-64/krb5-1.18.2-h173b8e3_0.conda>
Sep 16 18:37:29 Elapsed: -
Sep 16 18:37:29 An HTTP error occurred when trying to retrieve this URL.
Sep 16 18:37:29 HTTP errors are often intermittent, and a simple retry will get you on your way.
Sep 16 18:37:29 =================== sccache compilation log ===================
Sep 16 18:37:29 + cleanup
Sep 16 18:37:29 + retcode=1
Sep 16 18:37:29 + set +x
```

See CircleCI build pytorch_linux_bionic_rocm3_7_py3_6_build (3/3)

Step: "Build" (full log | diagnosis details | 🔁 rerun) ❄️

Sep 16 18:38:55 CondaHTTPError: HTTP 000 CONNECTION FAILED for url
```
Sep 16 18:22:51 ++ conda install -q -y cmake
Sep 16 18:23:45 Collecting package metadata (current_repodata.json): ...working... done
Sep 16 18:23:46 Solving environment: ...working... done
Sep 16 18:38:55 CondaHTTPError: HTTP 000 CONNECTION FAILED for url <https://repo.anaconda.com/pkgs/main/linux-64/cmake-3.14.0-h52cb24c_0.conda>
Sep 16 18:38:55 Elapsed: -
Sep 16 18:38:55 An HTTP error occurred when trying to retrieve this URL.
Sep 16 18:38:55 HTTP errors are often intermittent, and a simple retry will get you on your way.
Sep 16 18:38:55 CondaHTTPError: HTTP 000 CONNECTION FAILED for url <https://repo.anaconda.com/pkgs/main/linux-64/krb5-1.18.2-h173b8e3_0.conda>
Sep 16 18:38:55 Elapsed: -
Sep 16 18:38:55 An HTTP error occurred when trying to retrieve this URL.
Sep 16 18:38:55 HTTP errors are often intermittent, and a simple retry will get you on your way.
Sep 16 18:38:55 ## Package Plan ##
Sep 16 18:38:55   environment location: /opt/conda
```


codecov bot commented Sep 15, 2020

Codecov Report

Merging #44733 into master will decrease coverage by 0.00%.
The diff coverage is n/a.


```
@@            Coverage Diff             @@
##           master   #44733      +/-   ##
==========================================
- Coverage   68.08%   68.08%   -0.01%
==========================================
  Files         384      384
  Lines       49774    49768       -6
==========================================
- Hits        33890    33883       -7
- Misses      15884    15885       +1
```

| Impacted Files | Coverage Δ |
|---|---|
| torch/fx/symbolic_trace.py | 93.69% <0.00%> (-1.66%) ⬇️ |
| torch/fx/graph.py | 96.38% <0.00%> (-0.29%) ⬇️ |
| torch/optim/lr_scheduler.py | 88.77% <0.00%> (+0.04%) ⬆️ |
| torch/fx/proxy.py | 93.10% <0.00%> (+0.44%) ⬆️ |
| torch/nn/parallel/distributed.py | 41.75% <0.00%> (+0.61%) ⬆️ |

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Last update b85568a...b6a686f.

@zheng-xq (Contributor) left a comment:

A few minor comments. Thanks!

Contributor:

Minor: I am a bit troubled by the term "rank" here. Unless it appears somewhere in the CUDA guide, most places use "rank" to refer to the rank of a tensor. I would prefer to pick a different name here, but it is up to you.

Contributor (Author):

Yeah, happy to - any suggestions?

Contributor (Author):

I want to avoid "dimension", "extent" and "size" because they're overloaded as well. I landed on "reach" ??

Contributor:

Minor: this is fine with me. syncthreads can be part of the mask as long as all threads go through the same branch, but that is a corner case we can ignore for now.

Contributor (Author):

Yeah, but then we'd need analysis to know that all threads went through the branch - this way it should just work.
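
For readers following along, here is a hedged CUDA sketch of the constraint being discussed (the kernel and buffers are invented for illustration, not codegen output): __syncthreads() must be reached by every thread in a block, so placing it inside a branch that only some threads take can deadlock, which is why the barriers are emitted outside the masks.

```cuda
// Illustrative only: why barriers stay outside masked regions.
__global__ void barrier_outside_mask(float* a, float* b) {
  int t = threadIdx.x;

  // Unsafe: if the barrier lived inside the mask, threads with t >= 5 would
  // never reach it and the block could hang.
  //   if (t < 5) { a[t] = t; __syncthreads(); }

  // Safe pattern: mask only the work, keep the barrier at a point every
  // thread in the block executes.
  if (t < 5) {
    a[t] = static_cast<float>(t);
  }
  __syncthreads();

  b[t] = a[t % 5];
}
```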

Contributor:

Minor: fine for now. But I could imagine an Alloc that turns into local registers and resides within the masks. We can always handle those cases when we get to them.

Contributor (Author):

Yeah, I think the way to handle this is a pass that analyses usages of temporary buffers and removes allocates for buffers that are not needed. In that case registerizing accesses into scalars would remove accesses to the original buf and we could eliminate it entirely.

The case we definitely need to handle at this point is where the Buf should be shared across threads, and the allocate is turned into a variable definition which needs to stay in scope for all its usages.
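
As a hedged illustration of that shared-buffer case (the kernel and names are invented for this sketch, not taken from the pass), the lowered definition has to sit at block scope, outside any thread mask, so every consumer can still see it:

```cuda
// Illustrative only: an Alloc lowered to a shared-memory definition must stay
// in scope for all of its uses, so it is declared outside the thread masks.
__global__ void shared_buf_example(float* out) {
  __shared__ float tmp[5];  // block-scope definition, visible to all threads

  int t = threadIdx.x;
  if (t < 5) {
    tmp[t] = 2.0f * static_cast<float>(t);  // masked producer
  }
  __syncthreads();

  out[t] = tmp[t % 5];  // every thread reads the shared buffer
}
```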

Contributor:

In the commit message, could you list a few generated code examples from before and after your changes, so it is easy to see the effect of your change?

Contributor (Author):

Examples from the tests you mean?

Contributor (Author):

I've updated the comment on this PR.

@facebook-github-bot (Contributor) left a comment:

@nickgg has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot

@nickgg merged this pull request in 82ab167.

xuzhao9 pushed a commit that referenced this pull request Sep 18, 2020
…44733)

Summary:
Unifies a number of partial solutions to thread and block dimension extent masking, including the NoThreadIdxWriter and my last fix #44325. The NoThreadIdxWriter is gone in favour of tracking the current loop extents and masking any statements whose rank is lower than the launch parameters in any Block or Thread dimension, which handles both the "no" and "smaller" axis-binding cases. (The examples are the same as in the pull request description above.)

Pull Request resolved: #44733

Reviewed By: mruberry

Differential Revision: D23736878

Pulled By: nickgg

fbshipit-source-id: 52d08626ae8043d53eb937843466874d479a6768
facebook-github-bot pushed a commit that referenced this pull request Sep 21, 2020
Summary:
A previous fix for masking Cuda dimensions (#44733) changed how thread synchronization barriers are inserted in the Cuda CodeGen, causing the CudaSharedMemReduce_1 test to become flaky and ultimately be disabled.

The issue is working out where these barriers must be inserted. Solving this optimally is very hard, and I think not possible without dependency analysis we don't have, so I've changed our logic to be quite pessimistic: we'll insert barriers before and after any blocks that have thread dimensions masked, even between blocks that have no data dependencies. This should be correct, but it's an area where we could improve performance. To address this somewhat, I've added a simplifier pass that removes obviously unnecessary syncThreads.

To avoid this test being flaky again, I've added a check against the generated code to ensure there is a syncThread in the right place.

Also fixed a couple of non-functional clarity issues in the generated code: added the missing newline after Stores in the CudaPrinter, and prevented the PrioritizeLoad mutator from pulling out loads contained within simple Let statements (such as those produced by the Registerizer).

Pull Request resolved: #44909

Reviewed By: agolynski

Differential Revision: D23800565

Pulled By: nickgg

fbshipit-source-id: bddef1f40d8d461da965685f01d00b468d8a2c2f
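
As a hedged sketch of the pessimistic barrier placement described in that summary (illustrative code, not actual pass output), masked blocks end up bracketed by barriers, and the added simplifier can drop the obviously redundant back-to-back ones:

```cuda
// Illustrative only: pessimistic barriers around masked blocks, plus a
// redundant adjacent barrier that a simplifier pass can safely remove.
__global__ void pessimistic_barriers(float* a, float* b) {
  int t = threadIdx.x;

  if (t < 5) {
    a[t] = 1.0f;        // first masked block
  }
  __syncthreads();      // barrier emitted after the masked block
  // __syncthreads();   // a second barrier with nothing in between is
                        // obviously unnecessary and can be dropped

  if (t < 3) {
    b[t] = a[t];        // next masked block, guarded by the barrier above
  }
}
```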

Labels: Merged, oncall: jit
