[symbolic-shapes][dynamo] Use dynamo fake_mode, fake tensors in aot_autograd - Part 3/4 and 4/4 porting symbolic shapes #89373
This is parts 3 and 4 of the split plan for the bigger migration of critical parts of the symbolic-shapes branch to master.
The whole plan can be found here: #89363
It is broken into parts here: https://github.com/pytorch/pytorch/pull/89313/files
The meta-description for this stack: this is Step 2 of https://docs.google.com/document/u/1/d/1QJ-M4zfMkD-fjHIqW089RptjLl9EgozZGCceUbvmgfY/edit. Step 1 can be found here: #87570
The problem this PR solves is that today, dynamo creates guards and installs them, but aot_autograd creates guards... and does not install them. This means that code executed in make_fx can produce new symbolic shape guards that then do not get used anywhere.
The order today, before this PR, is:
1. Dynamo traces the frame, creates its guards, and installs them at lowering time.
2. At runtime, the first call lazily invokes aot_autograd's compilation (create_aot_dispatcher_function), whose make_fx tracing can create new symbolic shape guards.
3. Those guards are never installed, so nothing ever checks them again.
Our solution is to ensure that make_fx's guards bubble up to dynamo into a "unified cache", and we do this by piping a fake_mode down from dynamo to aot_autograd, as sketched below.
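A minimal sketch of that unified-cache wiring, assuming hypothetical `compile_frame` and `backend_compile` names (the real dynamo/aot_autograd entry points differ); `FakeTensorMode`, `ShapeEnv`, `from_tensor`, and `shape_env.guards` are real PyTorch APIs:

```python
import torch
from torch._subclasses import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv

def compile_frame(real_inputs, backend_compile):
    # Dynamo owns a single ShapeEnv/FakeTensorMode pair for the frame.
    shape_env = ShapeEnv()
    fake_mode = FakeTensorMode(shape_env=shape_env)

    # Dynamo's one and only fake-ification of the inputs.
    fake_inputs = [fake_mode.from_tensor(t) for t in real_inputs]

    # Pipe the same fake_mode (and the fake inputs) down to the backend / aot_autograd.
    # Anything make_fx traces under this mode records its symbolic shape guards
    # into the same shape_env instead of a private one.
    compiled_fn = backend_compile(fake_inputs, fake_mode=fake_mode)

    # After lowering, shape_env.guards holds both dynamo's and aot_autograd's
    # guards, so they can be installed together.
    return compiled_fn, shape_env.guards
```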
The unified cache architecture works by changing the lifecycle of when we compile the aot_autograd function from runtime to lowering time. We go from lazily compiling create_aot_dispatcher_function to always invoking it at lowering time. This is sound because the compiled_fn is protected by dynamo's guards. The order, therefore, is now:
1. Dynamo traces the frame and fake-ifies its inputs under its fake_mode.
2. Still at lowering time, dynamo invokes create_aot_dispatcher_function with those fake tensors, so make_fx's symbolic shape guards accumulate alongside dynamo's own (the unified cache).
3. Dynamo installs the combined guards.
4. At runtime, the compiled_fn runs, protected by those guards.
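Here is a rough before/after sketch of that lifecycle change, using the at-the-time signature create_aot_dispatcher_function(flat_fn, flat_args, aot_config) and an assumed import path; the wrapper names are illustrative, not the actual code from this PR:

```python
# Assumed import path; at the time of this PR the function lived in functorch's
# aot_autograd module, and it has moved between modules since.
from torch._functorch.aot_autograd import create_aot_dispatcher_function

# Before: compile lazily at runtime. Guards created inside make_fx during this
# call have nowhere to go, because dynamo finished installing guards long ago.
def aot_entry_lazy(flat_fn, aot_config):
    compiled_fn = None

    def runtime_wrapper(*args):
        nonlocal compiled_fn
        if compiled_fn is None:  # first call, i.e. runtime
            compiled_fn = create_aot_dispatcher_function(flat_fn, args, aot_config)
        return compiled_fn(*args)

    return runtime_wrapper

# After: compile eagerly at lowering time, while dynamo's fake_mode is live, so
# symbolic shape guards bubble up into the unified cache before installation.
# This is sound because compiled_fn is only ever reached through dynamo's guards.
def aot_entry_eager(flat_fn, fake_flat_args, aot_config):
    return create_aot_dispatcher_function(flat_fn, fake_flat_args, aot_config)
```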
As an added bonus, this allows us to remove the duplicate fake tensor conversion code that used to always be invoked in create_aot_dispatcher_function through process_inputs, and instead reuse the fake tensors from dynamo's conversion step. Outside of dynamo's invocation path, this layer still needs to exist, though, because aot_function is a public entry point to aot_autograd; the process_inputs/fake tensor conversion code therefore gets lifted up to aot_function.
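A sketch of where that conversion now lives, with a hypothetical process_inputs helper (the real helper in aot_autograd may differ): the public aot_function entry point still fake-ifies real arguments, while the dynamo path passes already-fake tensors straight through.

```python
import torch
from torch._subclasses import FakeTensor, FakeTensorMode

def process_inputs(flat_args, fake_mode: FakeTensorMode):
    # Convert real tensors to fake ones; leave already-fake tensors (e.g. those
    # coming from dynamo) and non-tensor arguments untouched.
    return [
        fake_mode.from_tensor(a)
        if isinstance(a, torch.Tensor) and not isinstance(a, FakeTensor)
        else a
        for a in flat_args
    ]
```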
This also changes the contract of create_aot_dispatcher_function: specifically, it now MUST be called with fake tensors if we are in fake tensor mode, as it will do no fake tensor conversion of its own. For the Dynamo path, we rely on dynamo passing them down; for the public entry point, we rely on aot_function.
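That contract could be enforced with a check along these lines (a hypothetical helper for illustration, not code from this PR):

```python
import torch
from torch._subclasses import FakeTensor

def assert_args_are_fake(flat_args, fake_mode):
    # When a fake mode is in play, callers (dynamo or aot_function) are
    # responsible for having converted every tensor argument already.
    if fake_mode is not None:
        for a in flat_args:
            if isinstance(a, torch.Tensor):
                assert isinstance(a, FakeTensor), (
                    "create_aot_dispatcher_function no longer fake-ifies inputs; "
                    "pass fake tensors when running under a fake mode"
                )
```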
One annoying stopgap here is the parameter story. Dynamo fake-ifies only its inputs, and params get treated a little differently. This is made more confusing by the fact that while create_aot_dispatcher_function takes all fake tensors (params and inputs both), aot_function_simplified still needs real parameters because it uses them as inputs to the compiled aot_function. This means we need to, later on, either fake-ify parameters at dynamo time and pass them alongside the real ones, or keep the param-only fake tensor conversion where it is in this PR.
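A sketch of the first option, with a hypothetical collect_params helper: fake-ify parameters at dynamo time and hand both versions down, since aot_function_simplified still needs the real ones as runtime inputs to the compiled function.

```python
import torch
from torch._subclasses import FakeTensorMode

def collect_params(mod: torch.nn.Module, fake_mode: FakeTensorMode):
    # Real params are still needed as runtime inputs to the compiled aot_function;
    # the fake versions are what create_aot_dispatcher_function would trace with.
    real_params = list(mod.parameters())
    fake_params = [fake_mode.from_tensor(p) for p in real_params]
    return real_params, fake_params
```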
— Split Strategy —
cc @mlazos @soumith @yanboliang @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @chunyuan-w @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @desertfire