
Conversation

@bdhirsh (Contributor) commented Nov 22, 2022

This PR is a pretty large refactor of the AOT Autograd logic, to clean things up + fix a few more broken edge cases. The changes are roughly:



(1) (largest change): for outputs of the forward that alias in some way, we used to *not* return them in the forward graph, and instead return a long tuple of ints corresponding to the metadata of the outputs’ sizes/strides/storage_offsets. The wrapper around `CompiledFunction.forward()` would then figure out how to regenerate each alias with a big `.as_strided()` call, indexing into the giant tuple of ints from the forward graph to get the size/stride metadata.

Now, the compiled forward graph returns the actual aliased tensor outputs along with every other output, and the wrapper uses them to regenerate the aliases.
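For illustration, here is a minimal sketch of the before/after behavior of the runtime wrapper. The function names and the metadata being passed around are hypothetical; only `Tensor.as_strided`, `size()`, `stride()`, and `storage_offset()` are real PyTorch APIs.

```python
import torch

def regenerate_alias_old(base, size, stride, storage_offset):
    # Old scheme: the forward graph only returned (size, stride, storage_offset)
    # ints for an aliased output, and the wrapper rebuilt the view off its base
    # with one big as_strided call.
    return base.as_strided(size, stride, storage_offset)

def regenerate_alias_new(base, aliased_out):
    # New scheme: the forward graph returns the aliased tensor itself, so the
    # wrapper can regenerate the view directly from that output's metadata
    # (view replay, described further below, can be tried before as_strided).
    return base.as_strided(
        aliased_out.size(), aliased_out.stride(), aliased_out.storage_offset()
    )
```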

Even though the aliased outputs are now returned in the compiled graph (and in `CompiledFunction.forward()`), I explicitly removed them from the backward graph. That felt like the right call (the aliases shouldn’t participate in the compiled backward, because we don’t actually care about them; we just use them to regenerate the “real” aliases in the forward), but I’m open to other ideas. Doing this required the following (a rough sketch follows the list):

(a) I updated `CompiledFunction.forward()` to wrap all aliased outputs in an opaque `TensorAlias` wrapper object, so that the `autograd.Function` knows not to assign gradients to them.
(b) I updated `CompiledFunction.backward()` to filter out the grad_outputs that correspond to aliased outputs (which I assert are all None).
(c) Before tracing the `joint_forward_backward()`, I update the tangents to filter out those corresponding to aliased outputs.
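A rough sketch of (a)-(c) under assumed names; the real `TensorAlias` and `CompiledFunction` live in the AOT Autograd sources, and this only shows the shape of the filtering logic:

```python
import torch

class TensorAlias:
    # Opaque wrapper around an aliased output: autograd.Function does not treat
    # the wrapped tensor as a differentiable output, so it never receives a grad.
    def __init__(self, alias: torch.Tensor):
        self.alias = alias

# (a) in forward(): wrap every output that is an alias.
def wrap_aliased_outputs(outs, output_is_alias):
    return tuple(
        TensorAlias(o) if is_alias else o
        for o, is_alias in zip(outs, output_is_alias)
    )

# (b) in backward(): drop grad_outputs for aliased outputs (they should all be None).
def filter_grad_outputs(grad_outputs, output_is_alias):
    for g, is_alias in zip(grad_outputs, output_is_alias):
        assert g is None or not is_alias
    return [g for g, is_alias in zip(grad_outputs, output_is_alias) if not is_alias]

# (c) before tracing joint_forward_backward(): drop the corresponding tangents too,
# so the aliased outputs never participate in the compiled backward.
def filter_tangents(tangents, output_is_alias):
    return [t for t, is_alias in zip(tangents, output_is_alias) if not is_alias]
```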

(2) Cleaned up the metadata and removed some redundant info. Take a look at the `ViewAndMutationMetadata` class.

(3) Precomputed more things so that the hot-path code should be faster. For example, when applying mutations back to mutated inputs, we used to loop through all inputs. Now we precompute the indices of the inputs that need to be mutated and only loop through those. This should be a meaningful speedup, since many models get graphs with 200+ inputs while only a handful need mutations.
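As a sketch of that precomputation (field names like `mutates_data` / `mutates_metadata` are assumptions here, not necessarily the real metadata fields):

```python
def compute_mutated_indices(input_info):
    # Done once at compile time: which input positions does the graph mutate?
    return [
        i for i, info in enumerate(input_info)
        if info.mutates_data or info.mutates_metadata
    ]

def apply_input_mutations(args, updated_inputs, mutated_indices):
    # Hot path: only visit the handful of mutated inputs rather than scanning
    # all 200+ graph inputs on every call.
    for new_val, idx in zip(updated_inputs, mutated_indices):
        args[idx].copy_(new_val)  # write the post-graph value back into the caller's tensor
```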

(4) Added support for graphs with outputs that alias intermediates. This should fix a bug that has shown up in multiple models in the benchmark suite, where a graph returns an output that aliases an intermediate and later tries to mutate that output (there are a few GitHub issues for this that I tried to find but couldn’t).

The way I handled this: I check the `._base` attribute of every output of the forward. Any `._base`s that don’t already exist as other outputs are added as extra outputs to the graph. They also get their own metadata slots in `ViewAndMutationMetadata.output_info` (not strictly necessary, but it made handling them easier). I then tag every output with a `._base` as having `OutputType.alias_of_intermediate`. In the wrapper around `CompiledFunction.forward()`, for every output that is an alias of an intermediate, I discard that output and regenerate it off its intermediate.
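A hedged sketch of that bookkeeping (illustrative only; the actual implementation also records the `OutputType` tags and metadata slots):

```python
import torch

def collect_intermediate_bases(fw_outs):
    # For each forward output that is a view (has a ._base), make sure its base
    # is returned from the graph; the wrapper then regenerates the view off the
    # returned base at runtime.
    extra_graph_outs = []
    is_alias_of_intermediate = []
    for out in fw_outs:
        base = out._base if isinstance(out, torch.Tensor) else None
        is_alias_of_intermediate.append(base is not None)
        already_returned = any(base is t for t in list(fw_outs) + extra_graph_outs)
        if base is not None and not already_returned:
            extra_graph_outs.append(base)
    return extra_graph_outs, is_alias_of_intermediate
```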

(5) We now use Alban’s view-replay logic instead of always doing `.as_strided()`. This is notably best-effort and still falls back to `.as_strided()` in many cases. In particular, in the synthetic-base tests, the aliased inputs are created in eager mode, so they are forced to always replay with `.as_strided()`.
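As a rough illustration of the best-effort fallback; `try_view_replay` below is a hypothetical stand-in for the real view-replay hook, stubbed out so the sketch is self-contained:

```python
import torch
from typing import Optional

def try_view_replay(base: torch.Tensor, target: torch.Tensor) -> Optional[torch.Tensor]:
    # Stand-in stub for the real view-replay machinery: return a replayed view of
    # `base` when the view history is available, else None. The actual logic
    # lives in autograd and is not reproduced here.
    return None

def regenerate_view(base: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # Best-effort: prefer view replay, and fall back to as_strided when the alias
    # was created in eager mode and there is nothing to replay
    # (e.g. the synthetic-base tests).
    replayed = try_view_replay(base, target)
    if replayed is not None:
        return replayed
    return base.as_strided(target.size(), target.stride(), target.storage_offset())
```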

(6) Fixed non-tensor input handling. This was also breaking an internal test (cc @albanD @soumith @voznesenskym @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @desertfire @mlazos @yanboliang @chunyuan-w @aazzolini). I confirmed that the invariant inside `aot_dispatch_deduplicated_autograd` is that we are given a flattened list of inputs (flattened by pytrees), but we are *not* guaranteed that the inputs are tensor-only.
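A small sketch of what respecting that invariant looks like; `split_tensor_args` is a hypothetical helper, not the actual AOT Autograd code:

```python
import torch

def split_tensor_args(flat_args):
    # `flat_args` is already pytree-flattened, but its leaves are not guaranteed
    # to all be tensors (ints, floats, None, ... can appear). Record which
    # positions hold tensors so the compiled graph only ever sees tensor inputs,
    # and the non-tensor leaves can be spliced back in when calling it.
    tensor_idxs = [i for i, a in enumerate(flat_args) if isinstance(a, torch.Tensor)]
    tensor_args = [flat_args[i] for i in tensor_idxs]
    return tensor_args, tensor_idxs
```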

(7) I think I responded to and addressed all other relevant feedback from the original PR.

Stack from ghstack (oldest at bottom):

cc @mlazos @soumith @voznesenskym @yanboliang @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @chunyuan-w @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @desertfire

@pytorch-bot bot commented Nov 22, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/89532

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 7ba150f:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

bdhirsh added a commit that referenced this pull request Nov 22, 2022
ghstack-source-id: a83fedf
Pull Request resolved: #89532
bdhirsh added a commit that referenced this pull request Nov 23, 2022
ghstack-source-id: b05f6fe
Pull Request resolved: #89532
bdhirsh added a commit that referenced this pull request Nov 23, 2022
ghstack-source-id: 750a341
Pull Request resolved: #89532
@anjali411 removed their request for review November 28, 2022 15:06
ezyang added a commit that referenced this pull request Jan 11, 2023
ghstack-source-id: bf9b7c2
Pull Request resolved: #89532
@ezyang added the keep-going label (don't stop on first failure, keep running tests until the end) Jan 11, 2023
ezyang added a commit that referenced this pull request Jan 11, 2023
ghstack-source-id: 8bf5d21
Pull Request resolved: #89532
@ezyang (Contributor) commented Jan 11, 2023

@pytorchbot merge

@pytorchmergebot (Collaborator) replied:
Merge failed

Reason: The following mandatory check(s) failed (Rule superuser):

Dig deeper by viewing the failures on hud


ezyang added a commit that referenced this pull request Jan 12, 2023
ghstack-source-id: df1b152
Pull Request resolved: #89532
@ezyang (Contributor) commented Jan 12, 2023

@pytorchbot merge

@pytorchmergebot (Collaborator) replied:
Merge failed

Reason: The following mandatory check(s) failed (Rule superuser):

Dig deeper by viewing the failures on hud


@ezyang retried @pytorchbot merge four more times on Jan 12, 2023; each attempt failed with the same mandatory-check error (Rule superuser).

ezyang added a commit that referenced this pull request Jan 12, 2023
… bases, use view replay, fix non-tensor input handling"

Original PR: #89532

[ghstack-poisoned]
ezyang added a commit that referenced this pull request Jan 12, 2023
… bases, use view replay, fix non-tensor input handling"

Original PR: #89532

ghstack-source-id: df1b152
Pull Request resolved: #92076
@ezyang (Contributor) commented Jan 12, 2023

CLA bot seems broken, new PR for landing at #92076

@ezyang (Contributor) commented Jan 12, 2023

/easycla

pytorchmergebot pushed a commit that referenced this pull request Jan 12, 2023
… bases, use view replay, fix non-tensor input handling" (#92076)

Original PR: #89532

Pull Request resolved: #92076
Approved by: https://github.com/janeyx99, https://github.com/albanD
@github-actions bot commented:
Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions bot added the Stale label Mar 13, 2023
@github-actions bot closed this Apr 12, 2023
@facebook-github-bot deleted the gh/bdhirsh/346/head branch June 8, 2023 15:44