[symbolic-shapes][dynamo] Pass fake_mode and fake_inputs to aot_autograd - Part 2/4 porting symbolic shapes #89371
Conversation
Part 3 and 4: #89373

I'm not going to review parts 3/4 yet, as this is still failing tests.

There's no need to duplicate the long description on each PR, and doing so makes the descriptions less useful: people will assume they are identical and skip the part that actually matters for each PR.
| """copy while preserving strides""" | ||
| if isinstance(x, torch._subclasses.FakeTensor): | ||
| # No need to clone fake tensors? | ||
| return x |
We should understand why we are cloning fake tensors in the first place
In particular, even if fake tensors have no data, you may still be obligated to clone them in case metadata mutation happens
This was causing issues on the branch, but we can pull it, maybe. Let me check.
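A minimal sketch of the concern raised above (using the standard FakeTensorMode from torch._subclasses; the helper function is invented for illustration, not the code in this diff): even though a fake tensor carries no data, skipping the clone means in-place metadata mutation leaks back to the caller's tensor.

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

fake_mode = FakeTensorMode()
with fake_mode:
    x = torch.empty(2, 3)  # a fake tensor: no real storage, but real metadata

def mutates_metadata(t):
    # An in-place metadata mutation touches no data at all, so "fake tensors
    # have no data to copy" is not by itself a reason to skip the clone.
    t.transpose_(0, 1)
    return t

mutates_metadata(x)
print(x.shape)  # torch.Size([3, 2]) -- the caller's fake tensor changed shape
```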
```diff
 triggered = True

-def compiler(gm, input):
+def compiler(gm, input, **kwargs):
```
Why isn't this in the other PR?
Missed, will move
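For reference, a toy backend in the shape this diff moves toward (illustrative only, not the test's actual body): accepting **kwargs lets dynamo pass extras such as fake_mode without breaking backends that do not care about them.

```python
import torch

def compiler(gm: torch.fx.GraphModule, example_inputs, **kwargs):
    # Extra keyword arguments (e.g. fake_mode / fake_inputs from this stack)
    # are accepted and simply ignored by backends that do not need them.
    return gm.forward  # run the captured graph unchanged
```

A backend like this can be handed to torch._dynamo.optimize as usual.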
```python
example: Any
is_unspecialized: bool
# Must only be a fake tensor
fake_tensor = None
```
What's the difference between this and example?
You refer to this as fake_example below; make the naming consistent.
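A simplified, hypothetical rendering of the dataclass this hunk touches (the real GraphArg in dynamo has more fields); it just shows how example (the real value) relates to the new field, which the review suggests calling fake_example:

```python
from dataclasses import dataclass
from typing import Any, Optional

import torch

@dataclass
class GraphArg:
    example: Any               # the real value dynamo saw while tracing
    is_unspecialized: bool
    # Must only be a fake tensor: the symintified stand-in used for tracing
    fake_tensor: Optional["torch._subclasses.FakeTensor"] = None
```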
```python
    should_specialize=self.tensor_should_specialize(),
)
if graph_arg and config.fake_tensor_propagation:
    example = tensor_variable.proxy.node.meta["example_value"]
```
Does it matter that you're pulling out example_value here and not val?
No; val reads better, though.
```python
assert isinstance(
    self.fake_tensor, torch._subclasses.fake_tensor.FakeTensor
)
return [self.fake_tensor]
```
This is not an example, right? The fake tensor is canonical for all possible tensors which would pass the guards, because its sizes/strides have been symintified.
Yes
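A hedged illustration of "canonical" (exact behavior depends on the version's dynamic-shape policy, so whether sizes come back symbolic by default may vary): with a ShapeEnv attached, the fake tensor dynamo traces with can carry SymInt sizes/strides, so a single fake tensor stands in for every concrete input that would pass the guards.

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv

fake_mode = FakeTensorMode(shape_env=ShapeEnv())
real = torch.randn(4, 8)
fake = fake_mode.from_tensor(real)

# Depending on the dynamic-shape policy, fake.shape may show symbolic sizes
# (e.g. [s0, s1]) rather than the concrete [4, 8] of this particular input.
print(fake.shape)
```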
```python
hasher_type=None,
static_argnums=None,
fake_mode=None,
fake_inputs=None,
```
Unused?
Used in next PR
```python
gm: torch.fx.GraphModule,
example_inputs,
fake_mode=None,
fake_inputs=None,
```
Why do I need both example inputs and fake inputs?
I'm not sure you should do this, but fake_mode memoizes conversions of the same real tensor into a fake tensor, so if you already called from_real_tensor on a tensor before, it will consistently give you back the same fake tensor (with the correct symints) in that case.
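A small sketch of the memoization being described (assuming fake_mode.from_tensor as the entry point, which wraps the converter's from_real_tensor; exact names may differ by version):

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

fake_mode = FakeTensorMode()
real = torch.randn(3)

fake_a = fake_mode.from_tensor(real)
fake_b = fake_mode.from_tensor(real)  # same real tensor, same mode

# Conversions are memoized per real tensor, so the mode hands back the same
# fake tensor (with the same symints, when a ShapeEnv is attached).
assert fake_a is fake_b
```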
And then, I guess, you could also technically just enable fake mode before calling the compilation function, and then you don't even need to do any of the kwargs business.
You still need a way to pass in fake mode, so kwargs needs to be there.
There are also modes of operation for aot_autograd that take real tensors.
So, thinking about the calling convention holistically here, I think there's a decent case to be made that we should not pass real tensors to backends in the dynamic shapes world, for the simple reason that the backend may unsoundly specialize on the given tensors' sizes even when dynamo didn't actually guard on them. So no real inputs; they should be fakeified. The exception to the rule is parameters, which are assumed to be static and can be passed in as-is; this also seems fine and can be modeled as passing in a mix of fake and non-fake tensors (for non-fake tensors you are even allowed to burn in the data pointer, which is what cudagraphs will do).
And if this means we need to make dynamic-shapes-specific changes to the calling convention, so be it. We can tag the functions so that dynamo understands what the backend supports.
I agree, this is reasonable. However, dynamo can be in fake tensor mode without being in dynamic shape mode, while dynamic shape mode always requires fake tensors. So, at the point where we invoke the backend, we can do something along these lines: keep a (non-dynamic) mode wherein backends need real inputs, and treat the two as equivalent in the dynamic case. We could even change their names from real/fake to something like user_provided_inputs/tracing_inputs.
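A hypothetical sketch of such a convention (user_provided_inputs / tracing_inputs are just the names floated above; call_backend and its signature are invented for illustration, not code from the PR):

```python
from typing import Callable, List

import torch

def call_backend(
    backend: Callable,
    gm: torch.fx.GraphModule,
    real_inputs: List[torch.Tensor],
    fake_inputs: List[torch.Tensor],
    dynamic: bool,
):
    if dynamic:
        # Dynamic shapes: never hand real tensors to the backend, so it cannot
        # unsoundly bake in sizes that dynamo never guarded on.
        return backend(gm, user_provided_inputs=fake_inputs, tracing_inputs=fake_inputs)
    # Non-dynamic mode: backends that burn in data pointers (e.g. cudagraphs)
    # still get the real tensors, alongside the fakes used for tracing.
    return backend(gm, user_provided_inputs=real_inputs, tracing_inputs=fake_inputs)
```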
FWIW, the original PR (https://github.com/pytorch/pytorch/pull/88546/files) had the fake tensor inputs in the example_inputs field, and there was no fake_inputs.
This is part 2 of the split plan for the bigger migration of critical parts of the symbolic-shapes branch to master.
The whole plan can be found here: #89363
and, in parts, here: https://github.com/pytorch/pytorch/pull/89313/files
The meta-description for this stack:
Step 2 of https://docs.google.com/document/u/1/d/1QJ-M4zfMkD-fjHIqW089RptjLl9EgozZGCceUbvmgfY/edit
Step 1 can be found here: #87570
The problem this PR solves is that today, we have a world wherein dynamo creates guards and installs them, but aot_autograd creates guards... and does not install them. This means that code executed in make_fx can produce new symbolic shape guards that then do not get used anywhere.
The order today, before this PR, is: dynamo traces the frame and installs its guards, and only later, lazily at runtime, is create_aot_dispatcher_function compiled, with any symbolic shape guards it produces being dropped.
Our solution to this is to ensure that make_fx's guards bubble up to dynamo, in a "unified cache", and we do this by piping a fake_mode down to aot_autograd from dynamo.
The unified cache architecture works by changing the lifecycle of when we compile the aot_autograd function from runtime to lowering time. We go from lazily compiling create_aot_dispatcher_function to always invoking it at lowering time. This is sound because the compiled_fn is protected by dynamo's guards. The order, therefore, is now: dynamo traces the frame; create_aot_dispatcher_function is invoked at lowering time and its guards bubble up into dynamo's unified cache; dynamo installs the combined guards; at runtime, the compiled_fn runs under their protection.
As an added bonus, this allows us to remove the duplicate fake tensor conversion code that used to always be invoked in create_aot_dispatcher_function through process_inputs, reusing the fake tensors from dynamo's conversion step. Outside of dynamo's invocation path, we still need this layer to exist, though, as aot_function is a public entry point to aot_autograd, so the process_inputs/fake tensor conversion code gets lifted up to there.
This also changes the story of create_aot_dispatcher_function: specifically, it now MUST be called with fake tensors if we are in fake tensor mode, as it will do no fake tensor conversion of its own. For the dynamo path, we rely on dynamo passing them down; for the public entry point, we rely on aot_function.
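A hedged sketch of the resulting contract (the function names follow the prose above, but the signatures and bodies are illustrative, not the real aot_autograd code):

```python
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode

def create_aot_dispatcher_function(gm, inputs, fake_mode=None):
    # After this PR: no fake tensor conversion happens here. When a fake_mode
    # is in play, the inputs must already be fake (dynamo's own fake tensors).
    if fake_mode is not None:
        assert all(isinstance(t, FakeTensor) for t in inputs)
    ...  # build and return the compiled_fn, protected by dynamo's guards

def aot_function_entry(gm, real_inputs, fake_mode=None):
    # The public entry point keeps the process_inputs-style conversion layer:
    # real user tensors are fakeified here before reaching the dispatcher.
    fake_mode = fake_mode or FakeTensorMode()
    fake_inputs = [fake_mode.from_tensor(t) for t in real_inputs]
    return create_aot_dispatcher_function(gm, fake_inputs, fake_mode=fake_mode)
```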
One annoying sort of stopgap around this is the parameter story. Dynamo fake-ifies only its inputs, and params get treated a little differently. This is made more confusing by the fact that while create_aot_dispatcher_function takes all fake tensors (params and inputs both), aot_function_simplified still needs real parameters, as it uses them as inputs to the compiled aot_function. This means we need to either fake-ify parameters at dynamo time and pass them alongside the real ones, or keep the param-only fake tensor conversion where it is in this PR.
— Split Strategy —
cc @mlazos @soumith @yanboliang @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @chunyuan-w @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @desertfire