[NNC] Fix masking for all block and thread dimensions in CudaCodeGen #44733
Conversation
💊 CI failures summary and remediations (as of commit b6a686f; more details on the Dr. CI page):
- 1 failure not recognized by patterns
- ❄️ 3 failures tentatively classified as flaky, but reruns have not yet been triggered to confirm
Codecov Report

```
@@            Coverage Diff             @@
##           master   #44733      +/-   ##
==========================================
- Coverage   68.08%   68.08%   -0.01%
==========================================
  Files         384      384
  Lines       49774    49768       -6
==========================================
- Hits        33890    33883       -7
- Misses      15884    15885       +1
==========================================
```

Continue to review the full report at Codecov.
zheng-xq left a comment
A few minor comments. Thanks!
Minor: I am a bit troubled by the term "rank" here. Unless you see it somewhere in the CUDA guide, most places use "rank" to refer to the rank of a tensor. I would prefer to pick a different name here, but it is up to you.
Yeah, happy to - any suggestions?
I want to avoid "dimension", "extent" and "size" because they're overloaded as well. I landed on "reach"?
Minor: this is fine with me. But syncthreads can be part of the mask, as long as all threads go through the same branch. That is a corner case we can ignore for now, though.
Yeah, but then we'd need analysis to know that all threads went through the branch - this way it should just work.
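To make the trade-off above concrete, here is a small illustrative model (not PyTorch/NNC code; the function and its arguments are invented for this sketch) of why the codegen keeps `__syncthreads()` outside of thread-dimension masks: a barrier is only safe if every thread in the block reaches it, so a barrier inside a mask like `if (threadIdx.x < 5)` deadlocks unless the mask happens to cover all threads — which is exactly the case that would need extra analysis to prove.

```python
# Hypothetical model of CUDA barrier safety under masking.
# A __syncthreads() is safe only if all threads in the block reach it.

def block_reaches_barrier(num_threads, mask_extent, barrier_inside_mask):
    """Return True if every thread in the block reaches the barrier."""
    if not barrier_inside_mask:
        return True  # barrier outside the mask: every thread executes it
    # Barrier inside the mask: only threads with tid < mask_extent arrive.
    arrived = sum(1 for tid in range(num_threads) if tid < mask_extent)
    return arrived == num_threads

# Barrier outside the mask is always safe.
print(block_reaches_barrier(10, 5, barrier_inside_mask=False))   # True
# Inside a mask that excludes some threads: those threads never arrive.
print(block_reaches_barrier(10, 5, barrier_inside_mask=True))    # False
# The corner case above: the mask covers all threads, so it is still safe.
print(block_reaches_barrier(10, 10, barrier_inside_mask=True))   # True
```

Keeping barriers outside all masks avoids needing to prove the third case, at the cost of ruling it out.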
Minor: fine for now. But I could imagine an Alloc that turns into local registers and resides within the masks. We can always handle those cases when we get to that.
Yeah, I think the way to handle this is a pass that analyses usages of temporary buffers and removes allocates for buffers that are not needed. In that case registerizing accesses into scalars would remove accesses to the original buf and we could eliminate it entirely.
The case we definitely need to handle at this point is where the Buf should be shared across threads, and the allocate is turned into a variable definition which needs to stay in scope for all its usages.
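The pass described above could be sketched roughly as follows (a hypothetical illustration with invented names, not the NNC API): after registerization rewrites accesses into scalars, count the loads/stores that remain per buffer and drop the Allocate/Free pairs for buffers with no accesses left.

```python
# Hypothetical dead-allocation elimination pass over a flat statement list.
# Statements are modeled as (op, buf) tuples, e.g. ('alloc', 'A'), ('load', 'A').

def eliminate_dead_allocations(stmts):
    """Remove alloc/free pairs for buffers that are never loaded or stored."""
    uses = {}
    for op, buf in stmts:
        if op in ('load', 'store'):
            uses[buf] = uses.get(buf, 0) + 1
    return [(op, buf) for op, buf in stmts
            if op not in ('alloc', 'free') or uses.get(buf, 0) > 0]

prog = [('alloc', 'A'), ('alloc', 'B'),
        ('store', 'B'), ('load', 'B'),   # B is still accessed
        ('free', 'A'), ('free', 'B')]    # A was fully registerized away
print(eliminate_dead_allocations(prog))
# [('alloc', 'B'), ('store', 'B'), ('load', 'B'), ('free', 'B')]
```

The shared-buffer case from the comment above is the one this sketch ignores: there the allocate must stay, as a variable definition scoped over all its uses.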
test/cpp/tensorexpr/test_cuda.cpp
Outdated
In the commit message, could you list a few generated code examples before and after your changes? So it is easy to see the effect of your change.
Examples from the tests you mean?
I've updated the comment on this PR.
Force-pushed: b6d4976 → 2344b86 (Compare)
Force-pushed: 2344b86 → b6a686f (Compare)
facebook-github-bot left a comment
@nickgg has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
…44733)

Summary: Unifies a number of partial solutions to the thread and block dimension extent masking, including the NoThreadIdxWriter and my last fix #44325. The NoThreadIdxWriter is gone in favour of tracking the current loop extents and masking any statements that have a lower rank than the launch parameters in any Block or Thread dimension, which handles both the "no" and "smaller" axis binding cases.

For example it will transform the following:

```
for i in 0..10    // blockIdx.x
  for j in 0..10  // threadIdx.x
    do thing(i, j);
  for k in 0..5   // threadIdx.x
    do other thing(i, k);
```

Into:

```
do thing(blockIdx.x, threadIdx.x);
if (threadIdx.x < 5) {
  do other thing(blockIdx.x, threadIdx.x);
}
```

And handle the case where statements are not bound by any axis, e.g.

```
do outer thing;
for i in 0..10    // blockIdx.x
  for j in 0..10  // threadIdx.x
    do thing(i, j);
  do other thing(i);
```

will become:

```
if (blockIdx.x < 1) {
  if (threadIdx.x < 1) {
    do outer thing;
  }
}
syncthreads();
do thing(blockIdx.x, threadIdx.x);
syncthreads();
if (threadIdx.x < 1) {
  do other thing(blockIdx.x);
}
```

Pull Request resolved: #44733
Reviewed By: mruberry
Differential Revision: D23736878
Pulled By: nickgg
fbshipit-source-id: 52d08626ae8043d53eb937843466874d479a6768
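The masking rule in the commit message above can be modeled with a short sketch (illustrative only; the names and representation are invented, not the actual CudaCodeGen): for each statement, compare the extent it is bound to in each block/thread dimension against the kernel launch extent (the maximum over all loops bound to that dimension), and emit a guard wherever the statement's extent is smaller. A statement bound to no axis behaves as if its extent were 1.

```python
# Hypothetical model of per-statement mask computation.

def mask_for(stmt_extents, launch_extents):
    """Both args map a dimension name to its extent.

    A dimension missing from stmt_extents means the statement is not
    bound to that axis, which is treated as extent 1.
    """
    guards = []
    for dim, launch in launch_extents.items():
        bound = stmt_extents.get(dim, 1)
        if bound < launch:
            guards.append(f"{dim} < {bound}")
    return guards

launch = {"blockIdx.x": 10, "threadIdx.x": 10}
# `do thing(i, j)` is bound to the full extents: no mask needed.
print(mask_for({"blockIdx.x": 10, "threadIdx.x": 10}, launch))  # []
# `do other thing(i, k)` has a smaller threadIdx.x extent: masked.
print(mask_for({"blockIdx.x": 10, "threadIdx.x": 5}, launch))
# ['threadIdx.x < 5']
# `do outer thing` is bound to no axis at all: masked in both dimensions.
print(mask_for({}, launch))
# ['blockIdx.x < 1', 'threadIdx.x < 1']
```

This reproduces both transformed examples above: the smaller-axis case gets `if (threadIdx.x < 5)`, and the unbound case gets nested `< 1` guards so only one thread executes it.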
Summary: A previous fix for masking Cuda dimensions (#44733) changed the behaviour of inserting thread synchronization barriers in the Cuda CodeGen, causing CudaSharedMemReduce_1 to be flaky and ultimately disabled.

The issue is working out where these barriers must be inserted - solving this optimally is very hard, and I think not possible without dependency analysis we don't have, so I've changed our logic to be quite pessimistic. We'll insert barriers before and after any blocks that have thread dimensions masked (even between blocks that have no data dependencies). This should be correct, but it's an area where we could improve performance. To address this somewhat I've added a simplifier pass that removes obviously unnecessary syncThreads.

To avoid this test being flaky again, I've added a check against the generated code to ensure there is a syncThread in the right place.

Also fixed a couple of non-functional clarity issues in the generated code: added the missing newline after Stores in the CudaPrinter, and prevented the PrioritizeLoad mutator from pulling out loads contained within simple Let statements (such as those produced by the Registerizer).

Pull Request resolved: #44909
Reviewed By: agolynski
Differential Revision: D23800565
Pulled By: nickgg
fbshipit-source-id: bddef1f40d8d461da965685f01d00b468d8a2c2f
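A minimal sketch of the simplifier pass described in the summary above (hypothetical code, not the NNC implementation, which operates on the IR rather than strings): with the pessimistic insertion strategy, barriers frequently end up back to back, and the second of two adjacent barriers with nothing between them is obviously unnecessary.

```python
# Hypothetical redundant-barrier elimination over a flat statement list.

def remove_redundant_syncs(stmts):
    """Drop a syncthreads immediately preceded by another syncthreads."""
    out = []
    for s in stmts:
        if s == "syncthreads" and out and out[-1] == "syncthreads":
            continue  # back-to-back barrier: the second one is a no-op
        out.append(s)
    return out

print(remove_redundant_syncs(
    ["store A", "syncthreads", "syncthreads", "load A"]))
# ['store A', 'syncthreads', 'load A']
```

A smarter pass could also remove barriers proven unnecessary by dependency analysis, but as the summary notes, that analysis is not available, so only the obvious cases are cleaned up.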