[FlexFlash] CuteDSL flat indexer needs to be colexigraphic in coordinate space #166657
drisspg wants to merge 15 commits into gh/drisspg/217/base
Conversation
Dr. CI: ✅ No failures as of commit cba6288 with merge base 687c15c. See artifacts and rendered test results at hud.pytorch.org/pr/166657.
eellison left a comment:
Can you add some tests with views, to be safe?
```python
original_make_indexer = FixedLayout.make_indexer


def cutedsl_make_indexer(self):
    return _fixed_indexer_cute(self.size, self.stride, self.offset)
```
Is it just FixedLayout.make_indexer that encodes lexicographic order? I see make_reindexer, but I think that works because it just dispatches to FixedLayout in the end.
I think one thing I need to make clearer is that it is not actually the colexicographic part that is the problem.

make_indexer produces an index in pointer space. If I have a tensor with shape (10, 15) and stride (15, 1) and I index at Tensor[9, 12], it does dot(index, stride), so 9 * 15 + 12 * 1 = 147, and this is what is given to the load.

In the current CuteDSL path, we are passing the indexes directly to the on-device cute tensor, which for all intents and purposes behaves like a PyTorch tensor. So it actually wants cute_tensor[9, 12] and then it does the stride dot itself. It also happens to always accept a flat 1D index; in that case the total number of possible indices is prod(shape) = 10 * 15 = 150, and the mapping from 1D index to ND index is colexicographic:

flat = sum_n (index_n * prod_{k < n} size_k)

or 9 * 1 + 12 * 10 = 129.

Does this make sense?
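For concreteness, a minimal sketch of the two conventions described above, using the numbers from the example:

```python
# Contiguous (10, 15) tensor indexed at (9, 12), as in the example above.
shape, stride = (10, 15), (15, 1)
idx = (9, 12)

# Pointer-space index: dot(index, stride), what make_indexer produces today.
ptr_index = sum(i * s for i, s in zip(idx, stride))           # 9*15 + 12*1 = 147

# Colexicographic flat index: dot(index, (1, shape[0], ...)), what the
# on-device cute tensor expects when handed a flat 1D index.
colex_strides = (1, shape[0])                                 # (1, 10)
colex_index = sum(i * c for i, c in zip(idx, colex_strides))  # 9*1 + 12*10 = 129

print(ptr_index, colex_index)  # 147 129
```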
```python
def cutedsl_make_indexer(self):
    return _fixed_indexer_cute(self.size, self.stride, self.offset)


FixedLayout.make_indexer = cutedsl_make_indexer  # type: ignore[assignment]
```
This gives me a little pause, because I'm not sure what other locations in our codebase encode row-major indexing. Another solution would be to wrap the indexes and expressions in Identity, which would make it easier to do the row-major -> column-major transformation.

But I can't actually find anywhere else we do this, as far as codegen goes.
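For reference, a minimal sketch of what a coordinate-space indexer along these lines could look like; the body below is my own illustration, not the PR's actual `_fixed_indexer_cute`:

```python
def _fixed_indexer_cute_sketch(size, stride, offset):
    # Hypothetical stand-in: build column-major "coordinate strides"
    # (1, size[0], size[0]*size[1], ...) from the sizes and ignore the
    # physical strides, since the on-device cute tensor applies those itself.
    coord_strides = []
    running = 1
    for s in size:
        coord_strides.append(running)
        running *= s

    def indexer(index):
        assert len(index) == len(size)
        # Whether `offset` should be folded in here is left open; the thread
        # below expects dlpack to account for the storage offset already.
        return sum(i * c for i, c in zip(index, coord_strides))

    return indexer


# shape (10, 15), index (9, 12) -> 9*1 + 12*10 == 129
print(_fixed_indexer_cute_sketch((10, 15), (15, 1), 0)([9, 12]))
```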
Good call: this IMAs (hits an illegal memory access), but the indexing looks correct to me; I would expect that dlpack handles the offset.
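A quick sanity sketch of that expectation (my own check, not part of the PR): round-tripping a view with a non-zero storage offset through DLPack should still alias the same memory and read the same values:

```python
import torch

base = torch.arange(20, dtype=torch.float32).reshape(4, 5)
view = base[1:, 1:]                    # non-contiguous view, storage offset 6

# torch.from_dlpack accepts anything implementing __dlpack__, including a Tensor.
roundtrip = torch.from_dlpack(view)

print(view.data_ptr() == roundtrip.data_ptr())  # True: offset is preserved
print(torch.equal(view, roundtrip))             # True: same values
```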
test/inductor/test_flex_flash.py (outdated)
```python
    prof_result["found"],
    f"Flash attention kernel unexpectedly found when force_flash=False. Kernels: {prof_result['kernel_names']}",
)
# @dtypes(torch.float16, torch.bfloat16)
```
Uncomment before landing.
The only difference between these two tests is that the captured tensor is a slice with the same shape in case 1. I confirmed this by adding a compile cache clear in between, which fixes the test.

So concretely: we generate a closure for the mask-mod, and the input tensor for aux_tensors is that slice, so it has the same shape but a different stride. Since the mask mod is identical, we produce the same cache key and hit it, but now we are creating the tensors with:

https://github.com/Dao-AILab/flash-attention/blob/0256114fe2381ab293503219bdd9078de3cd26b3/flash_attn/cute/interface.py#L349C1-L351C91

So no alignment or leading dim is specified.
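A minimal sketch of the setup described above (illustrative names and sizes, not the actual test): the two captured tensors share shape and dtype, so a cache key that only reflects shape and dtype collides even though the layouts differ:

```python
import torch

big = torch.randn(8, 32)
case1_bias = big[:, :16]          # a slice: shape (8, 16), stride (32, 1)
case2_bias = torch.randn(8, 16)   # contiguous: shape (8, 16), stride (16, 1)

# A cache key that ignores strides cannot tell the two apart.
key1 = (tuple(case1_bias.shape), case1_bias.dtype)
key2 = (tuple(case2_bias.shape), case2_bias.dtype)

print(key1 == key2)                                # True  -> same cache entry
print(case1_bias.stride() == case2_bias.stride())  # False -> different layout
```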
Made a repro; landing this, since it's easier for the CuteDSL folks to repro and this issue is unrelated to this PR. I updated the tests to have different mask mods so we don't hit the cache in FA.
@pytorchbot merge
Merge started: your change will be merged once all checks pass (ETA 0-4 hours).
Merge failed. Reason: 2 jobs have failed, first few of them are: trunk / win-vs2022-cpu-py3 / build, trunk / win-vs2022-cuda12.8-py3 / build.
@pytorchbot merge
Merge started: your change will be merged once all checks pass (ETA 0-4 hours).
…ate space (pytorch#166657)

Benchmarks on Hopper (combined_comparison plot: https://github.com/user-attachments/assets/d94debd9-920d-4413-b51f-b8e906e4fb01). Note the triton impl is not using max-autotune because I didn't feel like waiting for 90x plots.

Pull Request resolved: pytorch#166657
Approved by: https://github.com/v0i0, https://github.com/mlazos, https://github.com/eellison
ghstack dependencies: pytorch#166359
Stack from ghstack (oldest at bottom):
Benchmarks on Hopper:

Note the triton impl is not using max-autotune because I didn't feel like waiting for 90x plots.
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben