[CuTe] Change the logic of pycute manipulation ops like coalesce, complement from co-lex to lex #162690
fduwjj wants to merge 7 commits into gh/fduwjj/195/base
Conversation
[ghstack-poisoned]
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/162690
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 3fc2173 with merge base 5babb4d.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
… lexico" cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci [ghstack-poisoned]
…alesce, complement from co-lex to lex"
PyTorch tensor iteration (.view, contiguous, broadcasting) and NumPy array indexing all follow lexicographic (row-major) order. In lexicographic (lex) order on (i0, i1, …, i{k-1}), the leftmost index (which has the largest stride) changes slowest and the rightmost index changes fastest, and the last dim is usually contiguous (see the sketch after the change list below).
However, the original pycute is entirely co-lex based. After porting its code into PyTorch and making some cosmetic changes, we now make it lex so that we can use it for use cases such as device mesh internal bookkeeping.
Changes included in this PR:
1. Changed all ported APIs, including prefix_product (stride inference), idx2crd, crd2idx, coalesce, composition, complement, right_inverse and left_inverse, to make sure they work in the lex way.
2. Added more unit test cases for some of the APIs mentioned above, since the existing unit tests do not have full coverage.
3. Fixed one bug inside composition which could lead to an infinite recursive call.
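To make the ordering concrete, here is a minimal, self-contained sketch of lex-order stride inference and coordinate/index conversion. This is only an illustration on flat (non-nested) integer tuples, not the ported pycute code; the helper names row_major_strides, crd2idx and idx2crd merely mirror the APIs listed above.

```python
import math

def row_major_strides(shape):
    # lex (row-major) strides: the stride of dim i is the product of all dims to its right
    return tuple(math.prod(shape[i + 1:]) for i in range(len(shape)))

def crd2idx(crd, shape):
    # linearize a coordinate in lex order: the rightmost coordinate varies fastest
    return sum(c * d for c, d in zip(crd, row_major_strides(shape)))

def idx2crd(idx, shape):
    # invert crd2idx for a dense lex layout
    return tuple((idx // d) % s for s, d in zip(shape, row_major_strides(shape)))

assert row_major_strides((2, 3, 4)) == (12, 4, 1)
assert crd2idx((0, 0, 1), (2, 3, 4)) == 1    # rightmost dim is contiguous
assert crd2idx((1, 0, 0), (2, 3, 4)) == 12   # leftmost dim has the largest stride
assert idx2crd(13, (2, 3, 4)) == (1, 0, 1)
```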
cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta ezyang msaroufim dcci
[ghstack-poisoned]
# Exclusive prefix product with output congruent to input a
# Exclusive prefix product with output congruent to input a (lexicographic)
It's not a prefix product anymore, right? It's a suffix_product now.
It's interesting that you decided to reverse the output. So it's also like a reversed suffix product.
Agree, we should name it suffix_product. No, it is not a reversed suffix product; I just don't want to do insert(0, suffix_product(a[i], current_init)).
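To make the naming point concrete, here is a small sketch of the two directions (illustrative only, on flat tuples rather than pycute's nested IntTuples): the exclusive prefix product yields co-lex (column-major) strides, while the exclusive suffix product yields lex (row-major) strides.

```python
def exclusive_prefix_product(shape, init=1):
    # co-lex / column-major strides: the leftmost mode is contiguous
    out, acc = [], init
    for s in shape:
        out.append(acc)
        acc *= s
    return tuple(out)

def exclusive_suffix_product(shape, init=1):
    # lex / row-major strides: the rightmost mode is contiguous
    out, acc = [], init
    for s in reversed(shape):
        out.append(acc)
        acc *= s
    return tuple(reversed(out))

assert exclusive_prefix_product((2, 3, 4)) == (1, 2, 6)    # co-lex
assert exclusive_suffix_product((2, 3, 4)) == (12, 4, 1)   # lex
```

Building the result by appending and reversing at the end, as in the sketch, is one way to avoid the repeated insert(0, ...) mentioned above.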
if is_tuple(shape) and is_tuple(stride):  # "int" tuple tuple
    assert len(shape) == len(stride)
    return tuple(idx2crd(idx, s, d) for s, d in zip(shape, stride))
# Process from left to right for lexicographic ordering (opposite of crd2idx)
Hmm? The old code processed left to right too??
I think you meant to say "right to left" here.
But you don't actually process right-to-left. It's probably dramatically simpler to go right-to-left.
You are right... we don't even need to change this line. As long as it is lexicographic, processing from left to right or from right to left does not matter that much.
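A quick sketch of why the traversal direction does not matter for idx2crd once the strides are lexicographic (assuming flat shapes and dense row-major strides; both helpers below are illustrative, not the PR's code):

```python
import math

def idx2crd_left_to_right(idx, shape):
    # peel the slowest-varying (leftmost) dim first by dividing by its row-major stride
    crd = []
    for i in range(len(shape)):
        stride = math.prod(shape[i + 1:])
        crd.append(idx // stride)
        idx %= stride
    return tuple(crd)

def idx2crd_right_to_left(idx, shape):
    # peel the fastest-varying (rightmost) dim first with mod/div
    crd = []
    for s in reversed(shape):
        crd.append(idx % s)
        idx //= s
    return tuple(reversed(crd))

shape = (2, 3, 4)
for idx in range(2 * 3 * 4):
    assert idx2crd_left_to_right(idx, shape) == idx2crd_right_to_left(idx, shape)
```

Right-to-left only needs mod/div by the extents themselves, which is why it tends to be the simpler formulation.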
# replace our shape-1 with anything
elif result_shape[-1] == 1:
    result_shape[-1] = shape
    result_stride[-1] = stride
So it looks like the way the old code worked was by "absorbing" new shape/stride into shape-1 as we iterate over the shape/stride. Because you didn't change the iteration order above (you are still going left-to-right), this strategy no longer works and you had to invert everything. It does seem plausible you made enough changes to make it work, but there's probably a simpler way to do it.
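For reference, here is a minimal sketch of what a right-to-left lex coalesce could look like, along the lines of the suggestion above. It is an illustration on flat shape/stride tuples (not the code in this PR): size-1 modes are dropped, and a mode merges into its right neighbor when stride_left == shape_right * stride_right.

```python
def coalesce_lex(shape, stride):
    # iterate right to left so merging always happens into the already-processed tail
    result_shape, result_stride = [], []
    for s, d in zip(reversed(shape), reversed(stride)):
        if s == 1:
            continue  # size-1 modes contribute nothing
        if result_shape and d == result_shape[0] * result_stride[0]:
            # the current (left) mode lines up with the mode to its right: merge them
            result_shape[0] *= s
        else:
            result_shape.insert(0, s)
            result_stride.insert(0, d)
    if not result_shape:
        return (1,), (1,)  # everything was size-1
    return tuple(result_shape), tuple(result_stride)

# (2, 1, 3) with row-major strides (3, 3, 1) coalesces to a single dense mode of size 6
assert coalesce_lex((2, 1, 3), (3, 3, 1)) == ((6,), (1,))
assert coalesce_lex((2, 3, 4), (12, 4, 1)) == ((24,), (1,))
```

Iterating right-to-left keeps the most recently processed (rightmost) mode at the front of the partial result, so the merge check is a single comparison against result_shape[0] and result_stride[0].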
torch/distributed/_pycute/layout.py (outdated)
result_stride.pop()
prev_shape = result_shape.pop()
result_shape.append(shape * prev_shape)
result_stride.append(stride)
Yeah, it doesn't feel like it should look like this. Instead, it feels like we should have processed shape, stride in reversed order at the loop above
elif is_int(layout):
    return Layout(layout)
return right_inverse(make_layout(layout, complement(layout)))  # type: ignore[arg-type]
return right_inverse(make_layout(complement(layout), layout))  # type: ignore[arg-type]
Same answer as for right_inverse: I found that as long as we swap the order, i == inv_layout(layout(i)) holds in the lexicographic case. You might ask "can you give me a proof using math?" I don't have one, tbh, and since we don't use them, I decided to cut corners here. pycute has two UTs, test/distributed/_pycute/test_left_inverse.py and test/distributed/_pycute/test_right_inverse.py, which test the logic pretty thoroughly IIUC.
current_idx = shape * stride

result_shape.reverse()
result_stride.reverse()
TBF I guess we don't care that much about these, as we have no use for them right now.
Yes, you are right, but this change will make CI happy (passing UT) in the lexicographic case.
Just need to make sure the UTs aren't actually explicitly checking for colex.
OK, one more point here: the definition of right_inverse is that an index maps back to itself under layout(inv_layout(i)), and the UT is pretty thorough on this case. So I think with this change it will work in the lexicographic case. Why? Because the map from crd to idx now uses lexicographic order.
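Assuming the ported module keeps pycute's public names (Layout and right_inverse importable from torch.distributed._pycute; the import path and the call convention below are assumptions based on the files touched in this PR, not a documented interface), that property can be spot-checked like this:

```python
# Hedged usage sketch: import path, exported names and the "call a Layout with a
# flat index" convention are all assumptions, not confirmed API.
from torch.distributed._pycute import Layout, right_inverse

layout = Layout((2, 3, 4), (12, 4, 1))  # a dense lex (row-major) layout
inv = right_inverse(layout)

# right_inverse(A) should satisfy A(inv(i)) == i for every i in the image of A
for i in range(2 * 3 * 4):
    assert layout(inv(i)) == i
```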
# Reverse the lists because we build lists in reverse order (append to end), this way it is more efficient.
result_shape.reverse()
result_stride.reverse()
Everything else here looks plausible but I am going to have to sweat the math lol
Sure, basically we just do it in the reverse order from what is described here: https://docs.nvidia.com/cutlass/media/docs/cpp/cute/02_layout_algebra.html. I also tried to read the proof in https://leimao.github.io/article/CuTe-Layout-Algebra/, for example Definition 2.13 (Composition, Restricted Case), from right to left. At the end of the day the divisors are the same; it's just the division order that changes.
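One way to sanity-check the direction without redoing the proof is a brute-force functional check. The sketch below is self-contained and illustrative (flat shapes, lex strides, a hypothetical layout_fn helper): it treats each layout as the function it denotes and composition as plain function composition.

```python
def layout_fn(shape, stride):
    """The function i -> sum_k crd_k * stride_k, with crd enumerated in lex order over shape."""
    def f(i):
        idx, rem = 0, i
        for s, d in zip(reversed(shape), reversed(stride)):
            idx += (rem % s) * d
            rem //= s
        return idx
    return f

# Composition is just function composition: (A o B)(i) = A(B(i)).
A = layout_fn((2, 4), (1, 2))   # a non-dense 2-D layout in lex order
B = layout_fn((4,), (2,))       # picks every other point of A's 8-element domain

composed = [A(B(i)) for i in range(4)]
assert composed == [0, 4, 1, 5]

# Pointwise this agrees with the lex layout (2, 2):(1, 4).
C = layout_fn((2, 2), (1, 4))
assert composed == [C(i) for i in range(4)]
```

Any candidate lex composition has to reproduce exactly this pointwise behavior, so a check like this is a cheap cross-validation for the unit tests.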
ezyang left a comment
While I think some of the changes here can be done more simply, it does look plausible. Not sure if you think the test coverage is good enough, since you certainly had to add more tests!
@ezyang thanks for the really fast review. Let me see if we can make the parts you commented on simpler. I think I can ask an LLM (in another PR) to generate more UTs for this code.
def prefix_product(a: IntTuple, init: IntTuple = 1) -> IntTuple:
    if is_tuple(a):
        if is_tuple(init):  # tuple tuple
            assert len(a) == len(init)
With all these length asserts, you may want to create a zip_strict wrapper. I think we already have one in the codebase somewhere.
I can add a TODO here so we can do the cosmetic change in another PR.
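For reference, a minimal sketch of such a wrapper (the name zip_strict is illustrative; on Python 3.10+ it can simply delegate to zip(..., strict=True)):

```python
import sys
from itertools import zip_longest

def zip_strict(*iterables):
    """zip() that raises if the iterables have different lengths."""
    if sys.version_info >= (3, 10):
        return zip(*iterables, strict=True)
    sentinel = object()
    def gen():
        for combo in zip_longest(*iterables, fillvalue=sentinel):
            if sentinel in combo:
                raise ValueError("zip_strict() arguments have different lengths")
            yield combo
    return gen()

assert list(zip_strict((2, 3), (12, 4))) == [(2, 12), (3, 4)]
```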
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
…plement from co-lex to lex (pytorch#162690)
PyTorch tensor iteration (.view, contiguous, broadcasting) and NumPy array indexing all follow lexicographic (row-major) order. In lexicographic (lex) order on (i0, i1, …, i{k-1}), the leftmost index (which has the largest stride) changes slowest and the rightmost index changes fastest, and the last dim is usually contiguous.
However, the original pycute is entirely co-lex based. After porting its code into PyTorch and making some cosmetic changes, we now make it lex so that we can use it for use cases such as device mesh internal bookkeeping.
Changes included in this PR:
1. Changed all ported APIs, including prefix_product (stride inference; renamed to suffix_product), idx2crd, crd2idx, coalesce, composition, complement, right_inverse and left_inverse, to make sure they work in the lex way.
2. Added more unit test cases for some of the APIs mentioned above, since the existing unit tests do not have full coverage.
3. Fixed one bug inside composition which could lead to an infinite recursive call.
Pull Request resolved: pytorch#162690
Approved by: https://github.com/ezyang
ghstack dependencies: pytorch#162413, pytorch#162534, pytorch#162414