
Conversation

@drisspg (Contributor) commented Dec 13, 2022

Summary

Continues to fix up the meta output story of SDPA to make it more correct.

pytorch-bot bot commented Dec 13, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/90776

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit a6024e0:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

```python
'nn.functional._scaled_dot_product_attention',
op=lambda *args, **kwargs:
    wrapper_set_seed(torch.nn.functional._scaled_dot_product_attention, *args, **kwargs),
```
@drisspg (Contributor Author):

With need_attn_weights we now (consistently) return an empty tensor. A lot of the tests just check the outputs of the ops, which causes failures. The problem test_meta points at is that the second return was not consistent in its shape, so I'm not sure what the right fix is.

@drisspg (Contributor Author):

This was the first source of failure for compilation

@drisspg (Contributor Author):

This does not work for meta tensors. Their storage is null, so query._storage().data_ptr() returns 0 for all of them.

The solution I implemented here is to pass a kwarg into efficient_attention_backward indicating that we want to chunk the grads.

I created that function as a native_function. This is the part I don't like, but as you can see in the derivatives.yaml updates I need this function to be available there, which in my experience means it needs to be a native_function. @albanD, if there is another place I could write it, I think that would be cleaner.

I also tried adding this to the return value of efficient_attention_forward -> (Tensor, Tensor, bool), but torchgen didn't like that.
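
For illustration only, a rough sketch of the kind of aliasing check being described (hypothetical helper name, using today's untyped_storage() spelling rather than the _storage() call quoted above); on meta tensors the storage data pointer reports 0 for everything, so the comparison degenerates exactly as described:

```python
import torch

def grads_should_be_chunked(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> bool:
    # If q/k/v are chunks of one packed buffer, their storages share a data
    # pointer, and the backward can allocate a single buffer and chunk the
    # grads. On meta tensors storage is null, so every data_ptr() reports 0
    # and this check would claim "chunked" even for unrelated inputs.
    return (
        query.untyped_storage().data_ptr()
        == key.untyped_storage().data_ptr()
        == value.untyped_storage().data_ptr()
    )
```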

Contributor:

fun fact - you don't need a function to be a native_function to use it in derivatives.yaml!

The best example of that is all of the functions in torch/csrc/autograd/FunctionsManual.h. Those functions are all directly referenced in derivatives.yaml.

Another reason you probably don't want it as a (cuda-specific) native function is because (I think) you want to evaluate that function at compile time, and not bake it into the graph.

@soulitzer (Contributor):

Since there are calls to storage here, I think you might still want to have a dispatching operation somewhere to hide that call to make it composite compliant.

> Another reason you probably don't want it as a (cuda-specific) native function is because (I think) you want to evaluate that function at compile time, and not bake it into the graph.

Interesting, why is that specific to cuda-specific native functions?

Contributor:

> Since there are calls to storage here, I think you might still want to have a dispatching operation somewhere to hide that call to make it composite compliant.

Oh thanks @soulitzer! I think I'm wrong (but maybe Driss can confirm). It looks like we do need that, since we're mucking with storage data pointers directly here: I don't think we can trace through this info at compile time, so we're kind of forced to bake this in as an op that gets evaluated at runtime (which means we'd graph break on it, since whether that condition is true is data-dependent on the input tensors' data pointers).

> Interesting, why is that specific to cuda-specific native functions?

Oops, I should have said this more clearly - I meant it as "operators that aren't registered as CompositeImplicitAutograd will be baked directly into the compiled graph". I just noticed that the new op has a function registered to the CUDA dispatch key, and so it isn't composite.
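
As a side note, a toy sketch (hypothetical `mylib` namespace, not the op added in this PR) of what "hiding the storage check behind a dispatched, non-composite op" can look like from Python:

```python
import torch
from torch.library import Library

# Define a tiny custom op in a made-up namespace and register a single
# non-CompositeImplicitAutograd implementation for it. Because the op is
# opaque to the dispatcher, the storage access happens inside the kernel
# rather than in user-visible Python code.
lib = Library("mylib", "DEF")
lib.define("same_storage(Tensor a, Tensor b) -> bool")

def same_storage(a: torch.Tensor, b: torch.Tensor) -> bool:
    return a.untyped_storage().data_ptr() == b.untyped_storage().data_ptr()

lib.impl("same_storage", same_storage, "CompositeExplicitAutograd")

q, k = torch.randn(8).chunk(2)
print(torch.ops.mylib.same_storage(q, k))  # True: both are views of one buffer
```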

Contributor:

If avoiding the graph break is really important, we can probably brainstorm more on how to get that to work.

This totally needs more fleshing out, but one idea is:

(1) Have this function just check if the input's storages are aliased, and don't check the underlying data pointers. Note that doing that would require registering guards to dynamo on the aliasing relationship of the inputs. That isn't actually something we can do today, but I believe we need to do something similar in aot_autograd anyway - so maybe in the medium term we can get this to work.

(2) Have the kernel that gets invoked in the "inputs are aliased" situation do a runtime check that the data pointers actually match, and fail otherwise (see the rough sketch below). Is this ok? Or is it valid for users to want to pass query/key/value inputs that are views of each other but don't share a data pointer? This feels a bit sketchy.
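
A minimal sketch of what the runtime check in idea (2) could look like (hypothetical helper, assuming a contiguous packed buffer chunked along its leading dimension; not code from this PR):

```python
import torch

def assert_qkv_are_packed_chunks(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> None:
    # If we've already decided (e.g. from a storage-aliasing guard) that
    # q/k/v should be chunks of one packed buffer, verify that their data
    # pointers really are laid out back-to-back, and fail loudly otherwise.
    q_bytes = query.numel() * query.element_size()
    k_bytes = key.numel() * key.element_size()
    if key.data_ptr() != query.data_ptr() + q_bytes:
        raise RuntimeError("key is not laid out directly after query")
    if value.data_ptr() != key.data_ptr() + k_bytes:
        raise RuntimeError("value is not laid out directly after key")

qkv = torch.randn(3, 128, 64)           # toy packed qkv buffer
q, k, v = qkv.chunk(3, dim=0)
assert_qkv_are_packed_chunks(q, k, v)    # passes: chunks are contiguous slices
```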

Contributor:

not a noob question, and should have said that part more clearly too haha.

You're totally right - baking the op into the graph doesn't automatically make it a graph break. The problem is that in order to bake _chunk_grad_outputs_efficient_attention into the graph, we need to be able to trace through it and evaluate its output (a bool) at compile time, with fake tensors.

And in order to evaluate its output, we need to know if the input tensors point to the same underlying memory. But at the time we're tracing, we don't have the real tensor data - we just have fake tensors, which don't store any notion of "what's the memory address of my data". Because fake tensors are only meant to accurately represent metadata, and not actual tensor data.

The way that case is "normally" handled today is that certain aten ops (like aten.nonzero) are marked with a tag, Tag.data_dependent_output, and when fake tensor sees one of those ops, it raises an exception that dynamo knows to catch and create a graph break on.

But since this is a specific check about data pointers, and not the actual underlying data values, maybe there are other options that we can brainstorm.
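
To make the "normal" handling concrete, a small, hedged illustration using .item() (aten._local_scalar_dense, one of the data-dependent-output ops) rather than the op from this PR; the exact exception type may vary across PyTorch versions:

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# Ops whose results depend on actual tensor *data* cannot be evaluated on
# fake tensors, so fake tensor mode raises, and dynamo turns that into a
# graph break instead of baking the op into the compiled graph.
with FakeTensorMode():
    x = torch.zeros(4)
    try:
        x.item()
    except Exception as exc:
        print(f"{type(exc).__name__}: {exc}")
```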

@soulitzer (Contributor) commented Dec 14, 2022:

Makes sense, thanks! I guess the current state is that if an operation returns a non-Tensor we will always try to know that value ahead of time, and potentially cause a graph break if we aren't able to. Would it be better if we could say that, even though this operation returns a bool (or some other non-Tensor value) that is data-dependent, we don't need to know that value at compile time because we don't depend on it for control flow?

@drisspg (Contributor Author):

hmmm

> Oh hmm maybe you mean running code evaluating the condition at compile time vs at runtime. (In this case, I think since you are passing that boolean to _scaled_dot_product_efficient_attention_backward_cuda and it eventually does control flow on that condition inside, and since that function is itself a non-composite native function, whatever control flow happens inside shouldn't cause a graph break?)

Yeah, this was how I was thinking about it initially. But you're right: I found this problem because doing this check in Python is always true, since AFAIK fake tensors' storage is null. If I were to make this function a non-aten op, do you think that would break things?

I think another alternative would be to not do this chunking/view behavior, which would potentially hurt eager-mode perf but could get recorded by compilation?

Contributor:

> If I were to make this function a non-aten op, do you think that would break things?

Feel free to try it - I... think so. It'll either break, or maybe when you try to call .data_ptr() on a FakeTensor it'll silently do the wrong thing (e.g. always return 0, so it assumes that the inputs always alias).

> I think another alternative would be to not do this chunking/view behavior, which would potentially hurt eager-mode perf but could get recorded by compilation?

Hmmm, just another thought: are the output shapes of _scaled_dot_product_efficient_backward() actually dependent on the chunking behavior? It looks like the outputs are the same size in both cases, and the only difference is whether they're separate tensors or chunks of some underlying tensor.

Since this is a backward() function, I don't think the grad_outputs are allowed to be mutated. So maybe it's ok if the meta function just unconditionally returns non-aliased grad_outputs. LMK what you think.
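
A minimal sketch of that suggestion (illustrative names, not the actual meta registration in this PR): the meta function could always allocate independent grad outputs of the right shapes, since the chunked and non-chunked cases only differ in aliasing, not in shape:

```python
import torch

def meta_efficient_attention_backward(grad_out, query, key, value):
    # Unconditionally return non-aliased grads shaped like the inputs;
    # whether the real CUDA kernel materializes them as chunks of one packed
    # buffer is an aliasing detail the meta function would not need to model.
    grad_q = query.new_empty(query.shape)
    grad_k = key.new_empty(key.shape)
    grad_v = value.new_empty(value.shape)
    return grad_q, grad_k, grad_v
```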

@soulitzer (Contributor) commented Dec 15, 2022:

> Are the output shapes of _scaled_dot_product_efficient_backward() actually dependent on the chunking behavior?

If it is, sounds like we would need to tag it properly (and we'll always have to graph break)?

Also, another way to avoid an extra native function while avoiding composite compliance issues could be to just call this helper inside the kernels instead of outside. Could we just do that instead?

@drisspg added the ciflow/trunk label (Trigger trunk jobs on your pull request) on Dec 14, 2022
@drisspg (Contributor Author) commented Dec 14, 2022

@pytorchbot merge

@pytorchmergebot (Collaborator):
Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.
