
[FlexFlash] CuteDSL flat indexer needs to be colexicographic in coordinate space #166657

Closed
drisspg wants to merge 15 commits into gh/drisspg/217/base from gh/drisspg/217/head

Conversation

drisspg (Contributor) commented Oct 30, 2025

[ghstack-poisoned]
pytorch-bot bot commented Oct 30, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/166657

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

✅ No Failures

As of commit cba6288 with merge base 687c15c:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

drisspg added 6 commits that referenced this pull request Oct 30, 2025
drisspg changed the title from "Take 2" to "[FlexFlash] CuteDSL flat indexer needs to be colexicographic in coordinate space" Oct 30, 2025
drisspg added a commit that referenced this pull request Oct 30, 2025
drisspg added the release notes: nn label Oct 30, 2025
drisspg added a commit that referenced this pull request Oct 30, 2025
drisspg requested a review from albanD as a code owner October 31, 2025 00:17
drisspg added a commit that referenced this pull request Oct 31, 2025
drisspg requested review from Chillee, eellison, and v0i0 and removed the request for albanD October 31, 2025 00:17
eellison (Contributor) left a comment

Can you add some tests with views, to be safe?

Comment on lines +86 to +89
original_make_indexer = FixedLayout.make_indexer

def cutedsl_make_indexer(self):
    return _fixed_indexer_cute(self.size, self.stride, self.offset)
A reviewer (Contributor) commented:

Is it just FixedLayout.make_indexer that encodes lexicographic order? I see make_reindexer, but I think that works because it just dispatches to FixedLayout in the end.

drisspg (Author) replied:

I think one thing I need to make clearer is that it is not actually the colexicographic part that is the problem.

So make_indexer is producing an index in pointer space.

If I have a tensor with shape (10, 15) and stride (15, 1), and I index at Tensor[9, 12], this does dot(index, stride), so 9 * 15 + 12 * 1 = 147, and this is what is given to the load.

In the current CuteDSL path, we are passing the indices directly to the on-device cute tensor, which for all intents behaves like a pytorch tensor.

So it actually wants cute_tensor[9, 12] and then it does the stride dot itself. It also happens to always accept a flat 1D index, in which case the total number of possible indices is prod(shape) = 10 * 15 = 150 and the mapping from 1D index to ND index is colexicographic.

So that would be sum_n(index_n * prod(shape[:n])), or 9 * 1 + 12 * 10 = 129.

Does this make sense?
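A minimal sketch of the two mappings (illustrative Python, not the actual inductor/CuteDSL helpers):

import math

shape = (10, 15)
stride = (15, 1)  # row-major / C-contiguous
coord = (9, 12)

# What make_indexer produces today: an offset in pointer space,
# i.e. dot(coord, stride).
ptr_offset = sum(c * s for c, s in zip(coord, stride))
assert ptr_offset == 9 * 15 + 12 * 1 == 147

# What the cute tensor's flat 1D index expects: colexicographic order,
# where the leftmost coordinate varies fastest.
colex_flat = sum(c * math.prod(shape[:n]) for n, c in enumerate(coord))
assert colex_flat == 9 * 1 + 12 * 10 == 129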

def cutedsl_make_indexer(self):
    return _fixed_indexer_cute(self.size, self.stride, self.offset)

FixedLayout.make_indexer = cutedsl_make_indexer  # type: ignore[assignment]
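Since original_make_indexer is saved above, the patch can be kept scoped with a restore, e.g. (a sketch; the context-manager name is hypothetical):

from contextlib import contextmanager

@contextmanager
def cutedsl_indexer_patched():
    # Swap in the colexicographic indexer, then restore the saved original.
    saved = FixedLayout.make_indexer
    FixedLayout.make_indexer = cutedsl_make_indexer  # type: ignore[assignment]
    try:
        yield
    finally:
        FixedLayout.make_indexer = saved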
A reviewer (Contributor) commented:

This gives me a little pause, because I'm not sure what other locations in our codebase encode row-major indexing. Another solution would be to wrap the indexes and expressions in Identity, which would make it easier to do the row-major -> column-major transformation.

But I can't actually find anywhere else we do this, as far as codegen goes.
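For reference, a hypothetical sketch of that alternative, tagging index expressions with an opaque wrapper so a later pass can find and reorder them (this Identity is a stand-in class, not an existing inductor helper):

import sympy

class Identity(sympy.Function):
    # Opaque marker: sympy leaves unknown functions unevaluated, so the
    # wrapped index expression stays identifiable for a later rewrite pass.
    nargs = (1,)

i0, i1 = sympy.symbols("i0 i1")
row_major = Identity(i0 * 15 + i1)
# A codegen pass could match Identity(...) nodes and rewrite them into
# the column-major (colexicographic) form, e.g. i0 * 1 + i1 * 10.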

drisspg added a commit that referenced this pull request Nov 1, 2025
drisspg (Author) commented Nov 1, 2025

> Can you add some tests with views, to be safe?

Good call. This IMAs (illegal memory access), but the indexing looks correct to me; I would expect that DLPack handles the offset.

    prof_result["found"],
    f"Flash attention kernel unexpectedly found when force_flash=False. Kernels: {prof_result['kernel_names']}",
)
# @dtypes(torch.float16, torch.bfloat16)
drisspg (Author):

Uncomment before landing.

drisspg (Author) commented Nov 1, 2025

> Good call. This IMAs (illegal memory access), but the indexing looks correct to me; I would expect that DLPack handles the offset.

The indexing is correct; the problem is that the cache is hitting:

TestFlexFlashCUDA.test_flash_attention_with_mask_mod_buffer_cuda_float16
TestFlexFlashCUDA.test_flash_attention_mask_mod_with_view_buffer_cuda_float16

The only difference between these two tests is that the captured tensor is a slice with the same shape in case 1.

Confirmed this by adding a compile cache clear in between, which fixes the test.

So concretely, the closure we generate for the mask_mod is:

    @cute.jit
    def mask_mod(b_idx, h_idx, q_idx, kv_idx, aux_tensors):

        in_ptr8 = aux_tensors[0]
        tmp1 = q_idx
        tmp2 = kv_idx
        tmp3 = operator.ge(tmp1, tmp2)
        tmp4 = h_idx
        tmp5 = ssa_to_indexable(tmp4, cutlass.Int32)
        tmp6 = cute.make_fragment(1, cutlass.Float16)
        tmp6[0] = (in_ptr8[tmp5])
        tmp7 = (tmp6.load()).to(cutlass.Float32)
        tmp8 = operator.gt(tmp7, cute.full_like(tmp7, 0))
        mask_mod_output = tmp3 | tmp8

        return mask_mod_output
        
        

The input tensor in aux_tensors has shape and stride assert_size_stride(arg5_1, (4,), (3,)).

We have an identical mask mod, so we produce the same cache key and hit, but now aux_tensors contains a tensor with shape and stride assert_size_stride(arg5_1, (4,), (1,)) -> the contiguous case.

We are creating the tensors with: https://github.com/Dao-AILab/flash-attention/blob/0256114fe2381ab293503219bdd9078de3cd26b3/flash_attn/cute/interface.py#L349C1-L351C91

So no alignment or leading dim is specified.
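A minimal, hypothetical sketch of the collision (cache_key is a stand-in for whatever key FA derives; the point is that a source-only key cannot distinguish the two layouts):

import torch

def cache_key(mask_mod_source: str) -> str:
    # Source-only key: identical for both tests.
    return mask_mod_source

base = torch.randn(4, 3)
sliced = base[:, 0]       # shape (4,), stride (3,): the view/slice case
contig = torch.randn(4)   # shape (4,), stride (1,): the contiguous case

assert sliced.shape == contig.shape
assert sliced.stride() != contig.stride()
# Same closure source => same key, so the second run reuses a kernel
# compiled for the other layout's strides.
assert cache_key("mask_mod_src") == cache_key("mask_mod_src")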

drisspg added commits that referenced this pull request Nov 1, 2025
drisspg (Author) commented Nov 1, 2025

Made a repro: #166789. Landing this, since that makes it easier for the CuteDSL folks to repro, and the issue is unrelated to this PR. I updated the tests to use different mask mods, so we don't hit the cache in FA.

drisspg (Author) commented Nov 1, 2025

@pytorchbot merge

pytorch-bot bot added the ciflow/trunk label Nov 1, 2025
pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: Check the merge workflow status here.

pytorchmergebot (Collaborator) commented:

Merge failed

Reason: 2 jobs have failed, first few of them are: trunk / win-vs2022-cpu-py3 / build, trunk / win-vs2022-cuda12.8-py3 / build

Details for Dev Infra team. Raised by workflow job.

drisspg added a commit that referenced this pull request Nov 1, 2025
drisspg (Author) commented Nov 1, 2025

@pytorchbot merge

pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: Check the merge workflow status here.

etaf pushed a commit to etaf/pytorch-inductor-xpu that referenced this pull request Nov 4, 2025
…ate space (pytorch#166657)

Benchmarks on Hopper:
Note the triton impl is not using max-autotune because I didn't feel like waiting for 90x plots.
[Image: combined_comparison benchmark plot]

Pull Request resolved: pytorch#166657
Approved by: https://github.com/v0i0, https://github.com/mlazos, https://github.com/eellison
ghstack dependencies: pytorch#166359
drizzlezyk pushed a commit to Ascend/pytorch that referenced this pull request Nov 17, 2025
…calars)

Co-authored-by: dilililiwhy <why.wuhuanyu@huawei.com>

TORCH MAIN SYNC: add update_wrapped_number (bugfix to ForwardADWithScalars)
Created-by: dilililiwhy, Merged-by: ascend-robot
References: pytorch/pytorch#160513, pytorch/pytorch#165784, pytorch/pytorch#166657
See merge request: Ascend/pytorch!26081
@github-actions github-actions bot deleted the gh/drisspg/217/head branch December 2, 2025 02:17

Labels

ciflow/inductor, ciflow/trunk, Merged, module: inductor, release notes: nn