Pass dynamo's fake_mode down to aot_autograd, remove duplicate fake tensor conversion, install aot guards in dynamo #88546
Conversation
```diff
 assert fake_mode, "Fake mode must be passed in"
 new_gm = deepcopy_to_fake_tensor(gm, fake_mode)
-with fake_mode, enable_python_dispatcher():
+with enable_python_dispatcher():
```
So FakeTensorMode is no longer enabled when this pass runs? The bad situation I'm imagining is:
(1) The graph contains a factory function, so interpreting the graph ends up creating a real tensor with real sizes (bad because we do unnecessary compute, etc.).
(2) The graph contains a SymInt-related op, like s1.__floordiv__(s2), and we grab the real sizes off of our real tensor, so we break in this pass when we run that op.
The original idea behind me adding that temporary extra FakeTensorMode and using it here was that:
(1) We only want to create fake tensors and symbolic shapes while running this pass.
(2) Using the existing ShapeEnv in dynamo might be bad, because we could end up installing a bunch of redundant guards in it when we run this pass. It sounds like this was the wrong thing to do, though, since mixing tensors across multiple FakeTensorModes causes issues.
That was just my brain-dump understanding. If we're going to try to remove this pass soon anyway, though, and this change fixes existing issues in the meantime, then I'm for landing it.
We can keep fake_mode here, we need it for the deepcopy anyway, I just thought it was spurious.
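A minimal standalone sketch of the concern in this exchange, not the code under review: interpreting a graph that contains a factory op outside of a FakeTensorMode allocates a real tensor and does real compute, while running the same code under the mode only propagates shapes and dtypes.

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

def run_graph_body():
    # Stand-in for interpreting a captured graph that contains a factory call.
    t = torch.zeros(1024, 1024)   # factory op: allocates for real outside a fake mode
    return t * 2                  # real compute outside a fake mode

real_out = run_graph_body()       # real storage, real FLOPs

fake_mode = FakeTensorMode()
with fake_mode:
    fake_out = run_graph_body()   # FakeTensors only: metadata propagation, no real storage

print(type(real_out), type(fake_out))  # torch.Tensor vs FakeTensor
```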
Only carefully reviewed the AOT parts, but mostly LGTM.
```diff
 def _create_aot_dispatcher_function(
-    flat_fn, flat_args: List[Tensor], aot_config: AOTConfig
+    flat_fn, fake_flat_tensor_args: List[Tensor], aot_config: AOTConfig, fake_mode,
 ):
```
Probably worth an assertion here
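A hedged sketch of what such an assertion could look like (`_assert_args_are_fake` is a hypothetical helper, not part of the PR): since this function no longer does its own fake tensor conversion, every tensor argument should already be a FakeTensor belonging to the fake_mode that was passed in.

```python
import torch
from torch._subclasses.fake_tensor import FakeTensor

def _assert_args_are_fake(fake_flat_tensor_args, fake_mode):
    # Hypothetical helper: callers (dynamo, or aot_function's process_inputs path)
    # are responsible for fake-ifying inputs before reaching the dispatcher.
    if fake_mode is None:
        return
    for arg in fake_flat_tensor_args:
        if isinstance(arg, torch.Tensor):
            assert isinstance(arg, FakeTensor), (
                f"expected FakeTensor inputs, got {type(arg)}"
            )
            assert arg.fake_mode is fake_mode, (
                "input was created under a different FakeTensorMode"
            )
```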
functorch/_src/aot_autograd.py (outdated)
```python
fake_mode = None
if "fake_mode" in top_kwargs and config.use_fake_tensor:
```
At this point it's worth duplicating the necessary args/kwargs to this function.
Ah yeah, another PR though, I think?
Hmm, nah, I'll do it now.
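A sketch of the suggestion under discussion, with illustrative names rather than functorch's actual signatures: instead of probing a forwarded `**top_kwargs` dict for `"fake_mode"`, spell the parameter out in the signature so the dependency is visible to callers and readers.

```python
from typing import Callable

# Illustrative signatures only; the real functorch entry points differ.
def entry_point_before(fn: Callable, **top_kwargs):
    fake_mode = None
    if "fake_mode" in top_kwargs:          # hidden dependency on a kwargs dict
        fake_mode = top_kwargs["fake_mode"]
    ...

def entry_point_after(fn: Callable, fake_mode=None, use_fake_tensor: bool = True):
    # fake_mode is now an explicit, documented parameter; no dict probing needed.
    if fake_mode is None and use_fake_tensor:
        ...  # fall back to constructing a fresh FakeTensorMode, as before
```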
functorch/_src/aot_autograd.py (outdated)
```python
    fake_mode,
)

compiled_fn = compile(fn, *fake_flat_tensor_args, *inputs)
```
Why even have a wrapper?
functorch/_src/aot_autograd.py (outdated)
```python
    return out


def aot_function_simplified(
```
imo just delete aot_function_simplified and directly call _create_aot_dispatcher_function. The more layers of indirection we can remove the better :)
Sure
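A hedged sketch of what the suggested call site could look like; the import path, the `lower_directly` name, and the idea that `aot_config` arrives pre-built are assumptions for illustration, not the PR's final code.

```python
# Assumed import path for the names referenced in this thread; treat as illustrative.
from functorch._src.aot_autograd import AOTConfig, _create_aot_dispatcher_function

def lower_directly(flat_fn, fake_flat_tensor_args, aot_config: AOTConfig, fake_mode):
    # No aot_function_simplified wrapper in between: the dispatcher function is
    # built once, at lowering time, and handed straight back to the caller.
    return _create_aot_dispatcher_function(
        flat_fn, fake_flat_tensor_args, aot_config, fake_mode
    )
```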
This contains #88546 plus a big pile of inductor hacks. Signed-off-by: Edward Z. Yang <ezyang@fb.com> [ghstack-poisoned]
…ductor" This contains #88546 plus a big pile of inductor hacks. ``` $ TORCHDYNAMO_DYNAMIC_SHAPES=1 AOT_DYNAMIC_SHAPES=1 python benchmarks/dynamo/torchbench.py --accuracy --backend inductor --training --only BERT_pytorch cuda train BERT_pytorch PASS ``` Signed-off-by: Edward Z. Yang <ezyangfb.com> cc mlazos soumith voznesenskym yanboliang penguinwu anijain2305 EikanWang jgong5 Guobing-Chen chunyuan-w XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire [ghstack-poisoned]
…ductor" This contains #88546 plus a big pile of inductor hacks. ``` $ TORCHDYNAMO_DYNAMIC_SHAPES=1 AOT_DYNAMIC_SHAPES=1 python benchmarks/dynamo/torchbench.py --accuracy --backend inductor --training --only BERT_pytorch cuda train BERT_pytorch PASS ``` I don't know if we're actually generating generic kernels though. Maybe Chillee can check. Signed-off-by: Edward Z. Yang <ezyangfb.com> cc mlazos soumith voznesenskym yanboliang penguinwu anijain2305 EikanWang jgong5 Guobing-Chen chunyuan-w XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire [ghstack-poisoned]
…ductor" This contains #88546 plus a big pile of inductor hacks. ``` $ TORCHDYNAMO_DYNAMIC_SHAPES=1 AOT_DYNAMIC_SHAPES=1 python benchmarks/dynamo/torchbench.py --accuracy --backend inductor --training --only BERT_pytorch cuda train BERT_pytorch PASS ``` I don't know if we're actually generating generic kernels though. Maybe Chillee can check. Signed-off-by: Edward Z. Yang <ezyangfb.com> cc mlazos soumith voznesenskym yanboliang penguinwu anijain2305 EikanWang jgong5 Guobing-Chen chunyuan-w XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire [ghstack-poisoned]
…ensor conversion, install aot guards in dynamo (#88546)
Step 2 of https://docs.google.com/document/u/1/d/1QJ-M4zfMkD-fjHIqW089RptjLl9EgozZGCceUbvmgfY/edit
Step 1 can be found here: #87570
The problem this PR solves is that today, we have a world wherein dynamo creates guards and installs them, but aot_autograd creates guards... and does not install them. This means that code executed in make_fx can produce new symbolic shape guards, that then do not get used anywhere.
The order today, before this PR, is:
1. dynamo traces the frame, creates its guards, and installs them.
2. dynamo compiles the frame with its backend.
3. At runtime, aot_autograd lazily compiles the function; any symbolic shape guards make_fx produces at that point are never installed anywhere.
Our solution to this is to ensure that make_fx's guards bubble up to dynamo, in a "unified cache", and we do this by piping a fake_mode down to aot_autograd from dynamo.
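An illustrative sketch (not the PR's code) of why sharing the mode matters: a FakeTensorMode carries a ShapeEnv, and tracing under dynamo's own mode means any symbolic-shape guards created during that tracing accumulate in the ShapeEnv dynamo installs guards from, rather than being stranded in a private one.

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv

shape_env = ShapeEnv()                             # stands in for dynamo's ShapeEnv
dynamo_fake_mode = FakeTensorMode(shape_env=shape_env)

with dynamo_fake_mode:
    x = torch.empty(8, 3)                          # fake tensor tied to dynamo's ShapeEnv
    # The backend's make_fx / aot_autograd tracing would run here, under the same
    # mode, so any guards it records land in `shape_env` instead of in a second,
    # throwaway ShapeEnv that nothing ever reads.
```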
The unified cache architecture works by changing the lifecycle of when we compile the aot_autograd function from runtime to lowering time: we go from lazily compiling create_aot_dispatcher_function to always invoking it at lowering time. This is sound because the compiled_fn is protected by dynamo's guards. The order, therefore, is now:
1. dynamo traces the frame and creates its guards.
2. At lowering time, dynamo invokes create_aot_dispatcher_function with its own fake_mode, so any symbolic shape guards produced by make_fx land in the same ShapeEnv.
3. dynamo installs the combined guards, and the compiled_fn runs under their protection.
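A minimal, hypothetical before/after sketch of that lifecycle change (the function names are made up): lazy lowering defers the build to the first call at runtime, after dynamo has already installed its guards, while eager lowering runs the build while guards can still be collected.

```python
from typing import Callable

def lower_lazily(build_compiled_fn: Callable[[], Callable]) -> Callable:
    compiled = None
    def runtime_fn(*args):
        nonlocal compiled
        if compiled is None:
            # First call, at runtime: guards created here have nowhere to go.
            compiled = build_compiled_fn()
        return compiled(*args)
    return runtime_fn

def lower_eagerly(build_compiled_fn: Callable[[], Callable]) -> Callable:
    # Lowering time: dynamo is still around to pick up any guards produced here,
    # and the returned compiled_fn is protected by the guards it installs.
    return build_compiled_fn()
```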
As an added bonus, this allows us to remove the duplicate fake tensor conversion code that used to always be invoked in `create_aot_dispatcher_function` through `process_inputs`, reusing the fake tensors from dynamo's conversion step instead. Outside of dynamo's invocation path we still need this layer to exist, though, since `aot_function` is a public entry point to aot_autograd, so the `process_inputs`/fake tensor conversion code gets lifted up to there.

This also changes the story of `create_aot_dispatcher_function`: specifically, it now MUST be called with fake tensors if we are in fake tensor mode, as it does no fake tensor conversion of its own. For the dynamo path we rely on dynamo passing fake tensors down; for the public entry point we rely on `aot_function`.

One annoying sort of stopgap around this is the parameter story. Dynamo fake-ifies only its inputs, and params get treated a little differently. This is made more confusing by the fact that while `create_aot_dispatcher_function` takes all fake tensors (params and inputs both), `aot_function_simplified` still needs real parameters, as it uses them as inputs to the compiled aot_function. This means we need to, later on, either fake-ify parameters at dynamo time and pass them alongside the real ones, or keep the param-only fake tensor conversion where it is in this PR.

cc @mlazos @soumith @yanboliang @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @chunyuan-w @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @desertfire