Skip to content

Conversation

@nickgg
Copy link
Contributor

@nickgg nickgg commented Sep 8, 2020

Fix an issue where loops of different sizes are bound to the same Cuda dimension / metavar.

Coming soon more info and tests...

@nickgg nickgg requested a review from apaszke as a code owner September 8, 2020 20:40
@facebook-github-bot facebook-github-bot added the oncall: jit Add this issue/PR to JIT oncall triage queue label Sep 8, 2020
@dr-ci
Copy link

dr-ci bot commented Sep 8, 2020

💊 CI failures summary and remediations

As of commit cdce535 (more details on the Dr. CI page):


  • 1/1 failures possibly* introduced in this PR
    • 1/1 non-CircleCI failure(s)

ci.pytorch.org: 1 failed


This comment was automatically generated by Dr. CI (expand for details).Follow this link to opt-out of these comments for your Pull Requests.

Please report bugs/suggestions on the GitHub issue tracker or post in the (internal) Dr. CI Users group.

See how this bot performed.

This comment has been revised 10 times.

@codecov
Copy link

codecov bot commented Sep 8, 2020

Codecov Report

Merging #44325 into master will not change coverage.
The diff coverage is n/a.

Impacted file tree graph

@@           Coverage Diff           @@
##           master   #44325   +/-   ##
=======================================
  Coverage   67.98%   67.98%           
=======================================
  Files         384      384           
  Lines       49596    49596           
=======================================
  Hits        33718    33718           
  Misses      15878    15878           

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update a9754fb...cdce535. Read the comment docs.

@bertmaher bertmaher self-requested a review September 9, 2020 19:25
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lowercase i in threadidx.y looks like a typo

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I need to rebind u in visual mode, easy to screw up case. Good catch.

Copy link
Contributor

@facebook-github-bot facebook-github-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@nickgg has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

Copy link
Contributor

@facebook-github-bot facebook-github-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@nickgg has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot
Copy link
Contributor

@nickgg merged this pull request in 64b4307.

facebook-github-bot pushed a commit that referenced this pull request Sep 16, 2020
…44733)

Summary:
Unifies a number of partial solutions to the thread and block dimension extent masking, including the NoThreadIdxWriter and my last fix #44325. The NoThreadIdxWriter is gone in favour of tracking the current loop extents and masking any statements that have a lower rank than the launch parameters in any Block or Thread dimension, which handles both the "no" and "smaller" axis binding cases.

For example it will transform the following:
```
for i in 0..10 // blockIdx.x
  for j in 0..10 // threadIdx.x
    do thing(i, j);
  for k in 0..5 // threadIdx.x
    do other thing(i, k);
```

Into:
```
do thing(blockIdx.x, threadIdx.x);
if (threadIdx.x < 5) {
  do other thing(blockIdx.x, threadIdx.x);
}
```

And handle the case where statements are not bound by any axis, eg.
```
do outer thing;
for i in 0..10 // blockIdx.x
  for j in 0..10 // threadIdx.x
    do thing(i, j);
  do other thing(i);
```

will become:

```
if (blockIdx.x < 1) {
  if (threadIdx.x < 1) {
    do outer thing;
  }
}
syncthreads();
do thing(blockIdx.x, threadIdx.x);
syncthreads();
if (threadIdx.x < 1) {
  do other thing(blockIdx.x);
}
```

Pull Request resolved: #44733

Reviewed By: mruberry

Differential Revision: D23736878

Pulled By: nickgg

fbshipit-source-id: 52d08626ae8043d53eb937843466874d479a6768
xuzhao9 pushed a commit that referenced this pull request Sep 18, 2020
)

Summary:
Fix an issue where loops of different sizes are bound to the same Cuda dimension / metavar.

Coming soon more info and tests...

Pull Request resolved: #44325

Reviewed By: colesbury

Differential Revision: D23628859

Pulled By: nickgg

fbshipit-source-id: 3621850a4cc38a790b62ad168d32e7a0e2462fad
xuzhao9 pushed a commit that referenced this pull request Sep 18, 2020
…44733)

Summary:
Unifies a number of partial solutions to the thread and block dimension extent masking, including the NoThreadIdxWriter and my last fix #44325. The NoThreadIdxWriter is gone in favour of tracking the current loop extents and masking any statements that have a lower rank than the launch parameters in any Block or Thread dimension, which handles both the "no" and "smaller" axis binding cases.

For example it will transform the following:
```
for i in 0..10 // blockIdx.x
  for j in 0..10 // threadIdx.x
    do thing(i, j);
  for k in 0..5 // threadIdx.x
    do other thing(i, k);
```

Into:
```
do thing(blockIdx.x, threadIdx.x);
if (threadIdx.x < 5) {
  do other thing(blockIdx.x, threadIdx.x);
}
```

And handle the case where statements are not bound by any axis, eg.
```
do outer thing;
for i in 0..10 // blockIdx.x
  for j in 0..10 // threadIdx.x
    do thing(i, j);
  do other thing(i);
```

will become:

```
if (blockIdx.x < 1) {
  if (threadIdx.x < 1) {
    do outer thing;
  }
}
syncthreads();
do thing(blockIdx.x, threadIdx.x);
syncthreads();
if (threadIdx.x < 1) {
  do other thing(blockIdx.x);
}
```

Pull Request resolved: #44733

Reviewed By: mruberry

Differential Revision: D23736878

Pulled By: nickgg

fbshipit-source-id: 52d08626ae8043d53eb937843466874d479a6768
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Merged oncall: jit Add this issue/PR to JIT oncall triage queue

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants