
[CuTe] Change the logic of pycute manipulation ops like coalesce, complement from co-lex to lex #162690

Closed
fduwjj wants to merge 7 commits into gh/fduwjj/195/base from gh/fduwjj/195/head

Conversation

@fduwjj (Contributor) commented Sep 11, 2025

Stack from ghstack (oldest at bottom):

PyTorch tensor iteration (.view, contiguous, broadcasting) and NumPy array indexing both follow lexicographic (row-major) order. In lexicographic (lex) order on (i0, i1, …, i{k-1}), the leftmost index has the largest stride and changes slowest, the rightmost index changes fastest, and the last dimension is usually contiguous.

The original pycute, however, is based entirely on co-lex order. After porting its code into PyTorch (with some cosmetic changes), we now make it lex so that it can serve use cases such as DeviceMesh internal bookkeeping.

Changes included in this PR:

  1. Changed all ported APIs, including prefix_product (stride inference, renamed to suffix_product), idx2crd, crd2idx, coalesce, composition, complement, right_inverse, and left_inverse, to make them work in lex order.
  2. Added more unit tests for the APIs above, since the existing tests did not have full coverage.
  3. Fixed a bug inside composition that led to an infinite recursive call.
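The renamed suffix_product can be sketched as follows (a hedged illustration, not the actual pycute code): in lex (row-major) order, each stride is the product of all dimensions to its right, so the last dimension is contiguous.

```python
from typing import Tuple

def suffix_product(shape: Tuple[int, ...]) -> Tuple[int, ...]:
    """Infer lex (row-major) strides: each stride is the product of the
    dimensions to its right, so the last dimension gets stride 1."""
    strides, acc = [], 1
    for dim in reversed(shape):
        strides.append(acc)
        acc *= dim
    return tuple(reversed(strides))

print(suffix_product((2, 3, 4)))  # (12, 4, 1)
```

A co-lex prefix_product would instead accumulate from the left, giving (1, 2, 6) for the same shape.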

cc @H-Huang @awgu @wanchaol @fegin @wz337 @wconstab @d4l3k @pragupta @ezyang @msaroufim @dcci


pytorch-bot bot commented Sep 11, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/162690

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 3fc2173 with merge base 5babb4d:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

pytorch-bot bot added the "oncall: distributed" (Add this issue/PR to distributed oncall triage queue) label Sep 11, 2025
fduwjj added a commit that referenced this pull request Sep 11, 2025
@fduwjj fduwjj requested review from ezyang, fegin and tianyu-l September 11, 2025 05:14
@fduwjj fduwjj changed the title from "[CuTe] Change the logic of coalesce, complement from co to lexico" to "[CuTe] Change the logic of pycute manipulation ops like coalesce, complement from co-lex to lex" Sep 11, 2025
fduwjj added a commit that referenced this pull request Sep 11, 2025


# Exclusive prefix product with output congruent to input a
# Exclusive prefix product with output congruent to input a (lexicographic)
Contributor:

It's not a prefix product anymore right? It's a suffix_product now

Contributor:

It's interesting that you decided to reverse the output. So it's also like a reversed suffix product

Contributor (Author):

Agreed, we should name it suffix_product. No, it is not a reversed suffix product; I just didn't want to do insert(0, suffix_product(a[i], current_init)).

if is_tuple(shape) and is_tuple(stride):  # "int" tuple tuple
    assert len(shape) == len(stride)
    return tuple(idx2crd(idx, s, d) for s, d in zip(shape, stride))
# Process from left to right for lexicographic ordering (opposite of crd2idx)
Contributor:

hmm? The old code processed left to right too??

Contributor:

I think you meant to say "right to left" here.

Contributor:

But you don't actually process right-to-left. It's probably dramatically simpler to go right-to-left

@fduwjj (Author) commented Sep 12, 2025:

You are right; we don't even need to change this line. As long as the order is lex, processing left to right or right to left does not matter much here.
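The thread's conclusion can be checked with a minimal lex-order sketch of idx2crd/crd2idx (illustrative flat-tuple versions, not the ported code): each coordinate depends only on its own shape/stride pair, so the processing direction is irrelevant.

```python
def idx2crd(idx, shape, stride):
    # Each mode is computed independently from its own (shape, stride)
    # pair, so left-to-right vs right-to-left processing is equivalent.
    return tuple((idx // d) % s for s, d in zip(shape, stride))

def crd2idx(crd, stride):
    return sum(c * d for c, d in zip(crd, stride))

shape, stride = (2, 3, 4), (12, 4, 1)  # lex (row-major) strides
assert all(crd2idx(idx2crd(i, shape, stride), stride) == i for i in range(24))
print(idx2crd(13, shape, stride))  # (1, 0, 1)
```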

# replace our shape-1 with anything
elif result_shape[-1] == 1:
    result_shape[-1] = shape
    result_stride[-1] = stride
Contributor:

Why did we lose this case?

Contributor:

So it looks like the way the old code worked was by "absorbing" new shape/stride into shape-1 as we iterate over the shape/stride. Because you didn't change the iteration order above (you are still going left-to-right), this strategy no longer works and you had to invert everything. It does seem plausible you made enough changes to make it work, but there's probably a simpler way to do it.

result_stride.pop()
prev_shape = result_shape.pop()
result_shape.append(shape * prev_shape)
result_stride.append(stride)
Contributor:

Yeah, it doesn't feel like it should look like this. Instead, it feels like we should have processed shape, stride in reversed order at the loop above
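The reviewer's right-to-left suggestion can be sketched concretely. This is a hedged, simplified coalesce for flat shape/stride tuples (illustrative only, not the PR's implementation): processing modes from the right, a new mode merges into the current head whenever its stride equals the head's shape times the head's stride, i.e. the two modes are contiguous in lex order.

```python
def coalesce(shape, stride):
    """Merge adjacent modes in lex order, processing right to left."""
    out_shape, out_stride = [shape[-1]], [stride[-1]]
    for s, d in zip(reversed(shape[:-1]), reversed(stride[:-1])):
        if out_shape[0] == 1:                     # shape-1 head absorbs anything
            out_shape[0], out_stride[0] = s, d
        elif s == 1:                              # skip shape-1 modes
            continue
        elif d == out_shape[0] * out_stride[0]:   # contiguous: merge into head
            out_shape[0] *= s
        else:                                     # not mergeable: new head
            out_shape.insert(0, s)
            out_stride.insert(0, d)
    return tuple(out_shape), tuple(out_stride)

print(coalesce((2, 3, 4), (12, 4, 1)))  # ((24,), (1,))
```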

elif is_int(layout):
    return Layout(layout)
return right_inverse(make_layout(layout, complement(layout)))  # type: ignore[arg-type]
return right_inverse(make_layout(complement(layout), layout))  # type: ignore[arg-type]
Contributor:

What's going on here?

@fduwjj (Author) commented Sep 12, 2025:

Same answer as for right_inverse: I found that as long as we swap the order, i == inv_layout(layout(i)) holds in the lex case. You might ask for a mathematical proof; I don't have one, to be honest, and since we don't use these ops yet, I decided to cut corners here. Pycute has two unit tests, test/distributed/_pycute/test_left_inverse.py and test/distributed/_pycute/test_right_inverse.py, which test this logic pretty thoroughly IIUC.

current_idx = shape * stride

result_shape.reverse()
result_stride.reverse()
Contributor:

This is also surprising

Contributor:

TBF I guess we don't care that much about these, as we have no use of them right now

Contributor (Author):

Yes, you are right, but this change makes CI happy (passing unit tests) in the lexicographic case.

Contributor:

Just need to make sure the UTs aren't actually explicitly checking for colex.

Contributor (Author):

OK, one more point: the definition of right_inverse is that an index maps back to itself under layout(inv_layout(i)), and the unit tests are pretty thorough on this case. So I think with this change it will work in the lex case, because the crd-to-idx mapping now uses lex order.
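The layout(inv_layout(i)) == i property being discussed can be checked numerically. Below is a hedged sketch (a brute-force inverse table, not the PR's right_inverse algorithm): a layout is treated as a function from a lex linear index to a memory offset, and the round-trip is verified on a small bijective layout.

```python
def lex_strides(shape):
    # Row-major (lex) strides: suffix products of the shape.
    strides, acc = [], 1
    for s in reversed(shape):
        strides.append(acc)
        acc *= s
    return tuple(reversed(strides))

def layout_fn(shape, stride):
    """A layout as a function: lex linear index -> memory offset."""
    lex = lex_strides(shape)
    def f(i):
        crd = tuple((i // d) % s for s, d in zip(shape, lex))
        return sum(c * d for c, d in zip(crd, stride))
    return f

# A permuted 3x4 layout: offsets i*1 + j*3 cover [0, 12) exactly once.
A = layout_fn((3, 4), (1, 3))
inv = {A(i): i for i in range(12)}  # brute-force right-inverse table
assert all(A(inv[x]) == x for x in range(12))
```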


# Reverse the lists: we build them back-to-front, since appending is cheaper than inserting at the front.
result_shape.reverse()
result_stride.reverse()
Contributor:

Everything else here looks plausible but I am going to have to sweat the math lol

Contributor (Author):

Sure. Basically we just do it in the reverse order from what is described here: https://docs.nvidia.com/cutlass/media/docs/cpp/cute/02_layout_algebra.html

And I tried to read the proof in https://leimao.github.io/article/CuTe-Layout-Algebra/, for example Definition 2.13 (Composition, Restricted Case), from right to left. At the end of the day the divisor is the same; only the division order changes.

@ezyang (Contributor) left a comment:

While I think some of the changes here can be done more simply, it does look plausible. Not sure if you think the test coverage is good enough, since you certainly had to add more tests!

@fduwjj (Author) commented Sep 11, 2025

@ezyang thanks for the really fast review. Let me see if we can make the parts you commented on simpler. I can also ask an LLM (in another PR) to generate more unit tests for this code.

def prefix_product(a: IntTuple, init: IntTuple = 1) -> IntTuple:
    if is_tuple(a):
        if is_tuple(init):  # tuple tuple
            assert len(a) == len(init)
Collaborator:

With all these length asserts, may want to create a zip_strict wrapper. I think we already have one in the codebase somewhere

Contributor (Author):

I can add a TODO here so we can do the cosmetic change in another PR.

fduwjj added a commit that referenced this pull request Sep 12, 2025
@fduwjj fduwjj added the ciflow/trunk Trigger trunk jobs on your pull request label Sep 12, 2025
@fduwjj (Author) commented Sep 16, 2025

@pytorchbot merge

@pytorchmergebot (Collaborator):

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team


pytorchmergebot pushed a commit that referenced this pull request Sep 16, 2025
markc-614 pushed a commit to markc-614/pytorch that referenced this pull request Sep 17, 2025
Pull Request resolved: pytorch#162690
Approved by: https://github.com/ezyang
ghstack dependencies: pytorch#162413, pytorch#162534, pytorch#162414
mansiag05 pushed a commit to mansiag05/pytorch that referenced this pull request Sep 22, 2025
cleonard530 pushed a commit to cleonard530/pytorch that referenced this pull request Sep 22, 2025
dsashidh pushed a commit to dsashidh/pytorch that referenced this pull request Sep 26, 2025

Labels

ciflow/trunk (Trigger trunk jobs on your pull request), Merged, oncall: distributed (Add this issue/PR to distributed oncall triage queue), release notes: DeviceMesh


4 participants