[SDPA] Standardizes the return shape for dense tensor of SDPA regardless of fused kernel called #90776
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/90776
✅ No failures as of commit a6024e0. (Note: links to docs will display an error until the docs builds have completed.)
Reviewed code:
'nn.functional._scaled_dot_product_attention',
op=lambda *args, **kwargs:
    wrapper_set_seed(torch.nn.functional._scaled_dot_product_attention, *args, **kwargs),
When need_attn_weights is set we now (consistently) return an empty tensor. A lot of the tests just check the outputs of the ops, which causes failures. The problem is that test_meta pointed out that the second return was not consistent in its shape, so I'm not sure what else to do here.
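To make the shape-consistency point concrete, here is a minimal, hypothetical sketch (not the PR's actual code) of the convention being described: the attention-weights slot is an empty tensor on every path, so tests and meta checks see one shape regardless of which fused kernel was picked.

```python
import torch

def sdpa_like_return(query, key, value, need_attn_weights=False):
    # Hypothetical illustration of the return convention discussed in this PR:
    # the first output matches query's shape, and the second (attention weights)
    # is an empty tensor on every path, so its shape never depends on which
    # fused kernel was selected.
    output = torch.empty_like(query)
    attn_weights = query.new_empty(0)
    return output, attn_weights
```

With a convention like this, the math fallback and the fused paths all report the same (query-shaped, empty) pair.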
torch/_meta_registrations.py (outdated)
This was the first source of failure for compilation
torch/_meta_registrations.py (outdated)
This does not work for meta tensors: their storage is null, so query._storage().data_ptr() returns 0 for all of them.
The solution I implemented here is to pass a kwarg to efficient_attention_backward saying that we want to chunk the grads.
I created that function as a native_function. This is the part I don't like, but as you can see in the derivatives.yaml updates I need this func to be available there, which from my experience means it needs to be a native_function. @albanD, if there is another place I could write it, I think that would be cleaner.
I also tried adding this to the return value of efficient_attention_forward -> (Tensor, Tensor, bool), but torchgen did not like that. :(
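For illustration only, a rough Python sketch of the check being described, with a hypothetical name and written against the public untyped_storage() accessor of recent PyTorch rather than the internal _storage() mentioned above:

```python
import torch

def wants_chunked_grad_outputs(query, key, value):
    # Hypothetical version of the check described above: if query, key, and
    # value were produced by chunking one packed projection, they share a
    # single storage, and the backward can carve all three grads out of one
    # allocation instead of three.
    ptrs = {t.untyped_storage().data_ptr() for t in (query, key, value)}
    # For meta/fake tensors the storage carries no real data, so this
    # comparison is not meaningful; that is the failure mode described above.
    return len(ptrs) == 1
```

For example, splitting a packed projection with q, k, v = packed.chunk(3, dim=-1) leaves all three views on one storage, so the check returns True; independently allocated inputs return False.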
fun fact - you don't need a function to be a native_function to use it in derivatives.yaml!
The best example of that is all of the functions in torch/csrc/autograd/FunctionsManual.h. Those functions are all directly referenced in derivatives.yaml.
Another reason you probably don't want it as a (cuda-specific) native function is because (I think) you want to evaluate that function at compile time, and not bake it into the graph.
Since there are calls to storage here, I think you might still want to have a dispatching operation somewhere to hide that call to make it composite compliant.
> (cuda-specific) native function is because (I think) you want to evaluate that function at compile time, and not bake it into the graph.
Interesting, why is that specific to cuda-specific native functions?
> Since there are calls to storage here, I think you might still want to have a dispatching operation somewhere to hide that call to make it composite compliant.

Oh thanks @soulitzer! I think I'm wrong (but maybe Driss can confirm). It looks like we do have to, since we're mucking with storage data pointers directly here; I don't think we can trace through this info at compile time, so we're kinda forced to bake this in as an op that gets evaluated at runtime (which means we'd graph break on it, since whether that condition is true is data-dependent on the input tensors' data pointers).

> Interesting, why is that specific to cuda-specific native functions?

oop, I should have said this more clearly - I meant it as "operators that aren't registered as CompositeImplicitAutograd will be baked directly into the compiled graph". I just noticed that the new op has a function registered to the CUDA dispatch key, and so it isn't composite.
If avoiding the graph break is really important, we can probably brainstorm more on how to get that to work.
This totally needs more fleshing out, but one idea is:
(1) Have this function just check if the input's storages are aliased, and don't check the underlying data pointers. Note that doing that would require registering guards to dynamo on the aliasing relationship of the inputs. That isn't actually something we can do today, but I believe we need to do something similar in aot_autograd anyway - so maybe in the medium term we can get this to work.
(2) Have the kernel that gets invoked in the "inputs are aliased" situation do a runtime check that the data pointers actually match, and fail otherwise. (Is this ok? Or is it valid for users to want to pass query/key/value inputs that are views of each other but don't share a data pointer? This feels a bit sketchy.) A rough sketch of this runtime check follows below.
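A rough sketch of idea (2) under the assumptions above (hypothetical function and parameter names; the real kernel signature is not shown here):

```python
import torch

def backward_assuming_packed(grad_out, query, key, value, inputs_are_aliased):
    # Sketch of idea (2): the aliased/not-aliased decision was made earlier
    # (e.g. at trace time); the kernel re-validates it at call time and fails
    # loudly instead of silently producing mis-chunked gradients.
    if inputs_are_aliased:
        ptrs = {t.untyped_storage().data_ptr() for t in (query, key, value)}
        if len(ptrs) != 1:
            raise RuntimeError(
                "query/key/value were assumed to share one packed buffer, "
                "but their storages no longer alias"
            )
    # ... dispatch to the real efficient-attention backward from here ...
```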
not a noob question, and should have said that part more clearly too haha.
You're totally right - baking the op into the graph doesn't automatically make it a graph break. The problem is that in order to bake _chunk_grad_outputs_efficient_attention into the graph, we need to be able to trace through it and evaluate its output (a bool) at compile time, with fake tensors.
And in order to evaluate its output, we need to know if the input tensors point to the same underlying memory. But at the time we're tracing, we don't have the real tensor data - we just have fake tensors, which don't store any notion of "what's the memory address of my data". Because fake tensors are only meant to accurately represent metadata, and not actual tensor data.
The way that case is "normally" handled today is that certain aten ops (like aten.nonzero) are marked with a tag, Tag.data_dependent_output, and when fake tensor sees one of those ops, it raises an exception that dynamo knows to catch and create a graph break on.
But since this is a specific check about data pointers, and not the actual underlying data values, maybe there are other options that we can brainstorm.
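To make that concrete, a small illustration using today's torch.compile entry point (this PR predates it, but the dynamo behavior described above is the same mechanism):

```python
import torch

@torch.compile
def count_nonzero_rows(x):
    # torch.nonzero has a data-dependent output shape; with default settings
    # dynamo graph-breaks around it, runs it eagerly, and resumes compiling
    # afterwards (the Tag.data_dependent_output mechanism described above).
    idx = torch.nonzero(x)
    return idx.shape[0]

print(count_nonzero_rows(torch.tensor([0.0, 1.0, 0.0, 2.0])))  # prints 2
```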
Makes sense, thanks! I guess the current state is that if an operation returns a non-Tensor we will always try to know that value ahead of time, and potentially cause a graph break if we aren't able to. Would it be better if we could say that, even though this operation is returning a bool (or some other non-Tensor value) that is data dependent, we don't need to know that value at compile time because we don't depend on it for control flow?
hmmm
Oh hmm, maybe you mean running the code that evaluates the condition at compile time vs at runtime. (In this case, since you are passing that boolean to _scaled_dot_product_efficient_attention_backward_cuda, which eventually does control flow on that condition inside, and since that function is itself a non-composite native function, whatever control flow happens inside shouldn't cause a graph break?)
yeah, this was how I was thinking about it initially. But you are right, I found this problem because doing this check in python is always true -> AFAIK fake tensors' storage is null. If I were to make this function a non-aten op, do you think that would break things?
I think another alternative would be to not do this chunking/view behavior, which would potentially hurt eager-mode perf but could get recorded by compilation?
> If I were to make this function a non-aten op, do you think that would break things?
Feel free to try it - I... think so. It'll either break, or maybe when you try to call .data_ptr() on a FakeTensor it'll silently do the wrong thing (e.g. always return 0, so it assumes that the inputs always alias).
> I think another alternative would be to not do this chunking/view behavior, which would potentially hurt eager-mode perf but could get recorded by compilation?
Hmmm just another thought: Are the output shapes of _scaled_dot_product_efficient_backward() actually dependent on the chunking behavior? It looks like the outputs are the same size in both cases, and the only difference is whether or not they're separate tensors, or chunks of some underlying tensor.
Since this is a backward() function, then I don't think the grad_outputs are allowed to be mutated. So maybe it's ok if the meta function just unconditionally returns non-aliased grad_outputs. LMK what you think
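A minimal sketch of that suggestion, with a hypothetical function name: the meta function unconditionally reports three independent, query/key/value-shaped gradients, and whether the real CUDA kernel carves them out of one chunked buffer remains an eager-only detail.

```python
import torch

def efficient_attention_backward_meta_sketch(grad_out, query, key, value):
    # Hypothetical meta function: the reported metadata is identical in the
    # chunked and non-chunked cases (three gradients shaped like query, key,
    # and value), so it can unconditionally report non-aliased outputs.
    grad_q = torch.empty_like(query)
    grad_k = torch.empty_like(key)
    grad_v = torch.empty_like(value)
    return grad_q, grad_k, grad_v
```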
> Are the output shapes of _scaled_dot_product_efficient_backward() actually dependent on the chunking behavior?
If it is, sounds like we would need to tag it properly (and we'll always have to graph break)?
Also, another way to avoid having an extra native function while avoiding composite compliance issues could be to just call this helper inside the kernels instead of outside. Could we just do that instead?
@pytorchbot merge

Merge started: your change will be merged once all checks pass (ETA 0-4 hours).
Summary
Continues to fix up the meta output story of SDPA to be more correct