AOT Autograd refactor + cleanup, handle intermediate views of bases, use view replay, fix non-tensor input handling #89532
Conversation
CLA bot seems broken, new PR for landing at #92076
AOT Autograd refactor + cleanup, handle intermediate views of bases, use view replay, fix non-tensor input handling (#92076). Original PR: #89532. Pull Request resolved: #92076. Approved by: https://github.com/janeyx99, https://github.com/albanD
This PR is a pretty large refactor of the AOT Autograd logic, to clean things up + fix a few more broken edge cases. The changes are roughly:
(1) (largest change) For outputs of the forward that alias in some way, we used to *not* return them in the forward graph, and instead return a long tuple of ints corresponding to the metadata of the outputs’ sizes/strides/storage_offsets. The wrapper around `CompiledFunction.forward()` would then figure out how to regenerate the alias with a big `.as_strided()` call, indexing into the giant tuple of ints from the forward graph to get the size/stride metadata.
Now, the compiled forward graph instead returns the actual aliased tensor outputs along with every other output, and the wrapper uses those outputs to regenerate the aliases.
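Roughly, the difference between the two schemes looks like the sketch below. The helper names are made up for illustration and this is not the actual AOT Autograd code; it only shows how the wrapper can rebuild an aliased output purely from metadata (old scheme) versus from the tensor the graph now returns (new scheme).

```python
import torch

# Old scheme (illustrative): the fw graph returned only ints, and the wrapper
# rebuilt the alias purely from that metadata.
def regenerate_alias_old(base: torch.Tensor, size, stride, storage_offset):
    return base.as_strided(size, stride, storage_offset)

# New scheme (illustrative): the fw graph returns the aliased tensor itself,
# and the wrapper uses it to regenerate a proper view of the base.
def regenerate_alias_new(base: torch.Tensor, aliased_out: torch.Tensor):
    return base.as_strided(
        aliased_out.size(), aliased_out.stride(), aliased_out.storage_offset()
    )

x = torch.randn(4, 4, requires_grad=True)
# Pretend the graph produced a view corresponding to x[0]:
alias_from_graph = x.detach()[0]
out = regenerate_alias_new(x, alias_from_graph)  # a real view of x again
```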
Even though the aliased outputs are now returned in the compiled graph (and in `CompiledFunction.forward()`), I explicitly removed them from the backward graph. That felt like the right call (the aliases shouldn’t participate in the compiled backward, because we don’t actually care about them - we just use them to regenerate the “real” aliases in the forward), but I’m open to other ideas. Doing this required the following:
(a) I updated `CompiledFunction.forward()` to wrap all aliased outputs in an opaque `TensorAlias` wrapper object, so that the `autograd.Function` knows not to assign gradients to them.
(b) I updated `CompiledFunction.backward()` to filter out the grad_outputs that correspond to aliased outputs (which I assert are all None).
(c) Before tracing the `joint_forward_backward()`, I update `tangents` to filter out the tangents corresponding to aliased outputs.
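A minimal sketch of the `TensorAlias` idea, using a toy `autograd.Function` (this is not the real `CompiledFunction`; the class name and the toy compute are placeholders). The point is just that wrapping an aliased output hides it from autograd, so its incoming gradient slot is `None` and can be dropped in `backward()`.

```python
import torch

class TensorAlias:
    """Opaque wrapper: autograd.Function won't treat the wrapped tensor as a
    differentiable output, so no gradient is ever assigned to it."""
    def __init__(self, alias: torch.Tensor):
        self.alias = alias

class ToyCompiledFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        out = x * 2                    # "real" output, participates in backward
        aliased = out.detach()[0]      # stand-in for an output that aliases something
        return out, TensorAlias(aliased)

    @staticmethod
    def backward(ctx, grad_out, grad_alias):
        # The grad slot for the TensorAlias output is always None; drop it.
        assert grad_alias is None
        return grad_out * 2

x = torch.randn(3, 3, requires_grad=True)
out, alias = ToyCompiledFunction.apply(x)
out.sum().backward()                   # only `out` contributes gradients
```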
(2) Cleaned up the metadata and removed some redundant info - take a look at the `ViewAndMutationMetadata` class.
(3) Precomputed more things so that the hot-path code should be faster. For example, when applying mutations back to mutated inputs, we used to loop through all inputs. Now, we precompute the indices of inputs that need to be mutated and only loop through those. This should be a meaningful speedup, since many models get graphs with 200+ inputs, and only a handful need mutations.
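A sketch of the precompute, assuming a simplified metadata layout where each entry of `input_info` has a `mutates_data` flag (field and function names here are illustrative, not the exact `ViewAndMutationMetadata` fields):

```python
from typing import List

def compute_mutated_input_indices(input_info) -> List[int]:
    # Done once, at compile time: record which inputs the graph mutates.
    return [i for i, info in enumerate(input_info) if info.mutates_data]

def apply_input_mutations(mutated_indices: List[int], all_inputs, updated_inputs):
    # Hot path: touch only the handful of mutated inputs instead of scanning
    # all 200+ graph inputs on every call.
    for i in mutated_indices:
        all_inputs[i].copy_(updated_inputs[i])
```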
(4) Added support for graphs with outputs that alias intermediates. This should fix a bug that has shown up on multiple models in the benchmark suite, where a graph returns an output that aliases an intermediate, and later tries to mutate that output (there are a few GitHub issues for this that I tried to find but couldn’t).
The way I handled this: I check the `._base` attribute of every output of the forward. Any `._base` that doesn’t already exist as another output is added as an extra output of the graph. These bases also get their own metadata slots in `ViewAndMutationMetadata.output_info` (which is not strictly necessary, but made handling them easier). I then tag every output with a `._base` as having `OutputType.alias_of_intermediate`. In the wrapper around `CompiledFunction.forward()`, for every output that is an alias of an intermediate, I discard that output and regenerate it off its intermediate.
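A sketch of the `._base` scan, with illustrative names (the real logic also has to account for graph inputs and a few other cases):

```python
import torch

def collect_intermediate_bases(fw_outs):
    # Find bases of view outputs that aren't themselves returned by the graph;
    # these get appended as extra graph outputs (and tagged
    # OutputType.alias_of_intermediate in the metadata).
    extra_outputs = []
    for out in fw_outs:
        base = out._base  # non-None iff `out` is a view of another tensor
        if base is None:
            continue
        already_returned = any(base is o for o in fw_outs)
        already_collected = any(base is o for o in extra_outputs)
        if not already_returned and not already_collected:
            extra_outputs.append(base)
    return extra_outputs

intermediate = torch.randn(4, 4, requires_grad=True)
outs = [intermediate[0], intermediate[1:]]   # both alias `intermediate`
extra = collect_intermediate_bases(outs)     # -> [intermediate]
```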
(5) We now use Alban’s view-replay logic, instead of always doing `.as_strided()`. This is notably best effort, and still falls back to `.as_strided()` in many cases. In particular: in the synthetic-base tests, the aliased inputs are created in eager mode, so they are forced to always replay with `.as_strided()`.
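Roughly, the regeneration helper behaves like the sketch below. `_view_func` is a private PyTorch hook used for view replay, so the sketch guards for it and falls back to `.as_strided()`; the names and structure are illustrative, not the exact implementation.

```python
import torch

def regenerate_view(new_base: torch.Tensor, target_view: torch.Tensor):
    # Best effort: replay the original view op(s) on top of `new_base`.
    view_func = getattr(target_view, "_view_func", None)
    if view_func is not None:
        replayed = view_func(new_base)
        if replayed is not None and replayed.shape == target_view.shape:
            return replayed
    # Fallback (e.g. when the alias was created in eager mode and we only
    # have metadata): reconstruct the view with as_strided.
    return new_base.as_strided(
        target_view.size(), target_view.stride(), target_view.storage_offset()
    )
```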
(6) Fixed non-tensor input handling (see the sketch below). This was also breaking an internal test (cc @albanD @soumith @voznesenskym @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @desertfire @mlazos @yanboliang @chunyuan-w @aazzolini). I confirmed that the invariant inside `aot_dispatch_deduplicated_autograd` is that we are given a flattened list of inputs (flattened by pytrees), but we are *not* guaranteed that the inputs are tensor-only.
(7) I think I responded to and fixed any other relevant PR feedback from the original PR.
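For the non-tensor input handling mentioned in (6), a small sketch of the kind of partitioning involved (function and variable names are made up; `tree_flatten` is from `torch.utils._pytree`):

```python
import torch
from torch.utils._pytree import tree_flatten

def split_tensor_args(args):
    # The flattened arg list can mix tensors with plain Python values.
    flat_args, spec = tree_flatten(args)
    tensor_idx = [i for i, a in enumerate(flat_args) if isinstance(a, torch.Tensor)]
    tensor_args = [flat_args[i] for i in tensor_idx]
    return flat_args, tensor_idx, tensor_args, spec

# e.g. ({"x": torch.randn(2), "scale": 3},) flattens into a list mixing a
# tensor with the int 3; only the tensor entry is treated as a differentiable
# graph input.
flat, idx, tensors, spec = split_tensor_args(({"x": torch.randn(2), "scale": 3},))
```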
Stack from ghstack (oldest at bottom):
cc @mlazos @soumith @voznesenskym @yanboliang @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @chunyuan-w @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @desertfire