
Conversation

@voznesenskym
Collaborator

This covers parts 3 and 4 of the split plan for the larger migration of critical parts of the symbolic-shapes branch to master.

The whole plan can be found here: #89363
and, broken into parts, here: https://github.com/pytorch/pytorch/pull/89313/files

The meta-description for this stack:

Step 2 of https://docs.google.com/document/u/1/d/1QJ-M4zfMkD-fjHIqW089RptjLl9EgozZGCceUbvmgfY/edit
Step 1 can be found here: #87570

The problem this PR solves is that today, dynamo creates guards and installs them, but aot_autograd creates guards and does not install them. This means that code executed in make_fx can produce new symbolic shape guards that are then never used anywhere.

The order today, before this PR, is:

  1. Dynamo starts interpreting a frame
  2. Dynamo calls call_user_compiler, which sets up a function that will lazily invoke create_aot_dispatcher_function at runtime
  3. Dynamo installs its own guards, pulls in shape_env guards only from dynamo's shape_env, and installs those as well
  4. Dynamo finishes the frame
  5. The function made in (2) is invoked, and create_aot_dispatcher_function calls make_fx, which can create shape env guards; these guards are not used anywhere (see the sketch after this list)
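
A minimal sketch of the pre-PR lazy pattern described above (the names here are illustrative, not the actual dynamo/aot_autograd code): compilation is deferred to the first call of the wrapper, so any shape-env guards created during that compile only appear after dynamo has already installed its guard set for the frame in step 3.

```python
# Hypothetical sketch of the pre-PR flow. compile_fn stands in for the code
# path that ends in create_aot_dispatcher_function.
def make_lazy_backend(compile_fn):
    def lazy_backend(gm, example_inputs):
        compiled = None

        def runtime_wrapper(*args):
            nonlocal compiled
            if compiled is None:
                # Compilation happens here, at runtime. Guards produced by
                # make_fx at this point have nowhere to go: dynamo installed
                # its guards for the frame back at step 3.
                compiled = compile_fn(gm, example_inputs)
            return compiled(*args)

        return runtime_wrapper

    return lazy_backend
```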

Our solution is to ensure that make_fx's guards bubble up to dynamo, into a "unified cache", and we do this by piping a fake_mode down from dynamo to aot_autograd.
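
One way to picture that plumbing (FakeTensorMode and ShapeEnv are real classes; the exact module paths and keyword may differ by version, and how dynamo threads the mode through is simplified here):

```python
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv

# Dynamo owns a single ShapeEnv, wrapped in a FakeTensorMode. Handing the
# same fake_mode to aot_autograd means guards produced under make_fx land in
# the ShapeEnv that dynamo reads when it installs guards for the frame.
shape_env = ShapeEnv()
fake_mode = FakeTensorMode(shape_env=shape_env)
```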

The unified cache architecture works by changing when we compile the aot_autograd function: from runtime to lowering time. We go from lazily compiling create_aot_dispatcher_function to always invoking it at lowering time. This is sound because the compiled_fn is protected by dynamo's guards. The order, therefore, is now:

  1. Dynamo starts interpreting a frame
  2. Dynamo calls call_user_compiler, which invokes compile_fn, which invokes create_aot_dispatcher_function (sketched after this list)
  3. create_aot_dispatcher_function calls make_fx, which can create shape env guards
  4. Dynamo installs its own guards, pulls in shape_env guards, and installs those as well
  5. Dynamo finishes the frame
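
A minimal sketch of the post-PR flow (hypothetical names, simplified signature; create_aot_dispatcher_function is passed in only to keep the example self-contained):

```python
# The backend now invokes create_aot_dispatcher_function eagerly at lowering
# time, so shape-env guards created by make_fx already exist when dynamo
# installs guards in step 4. Caching compiled_fn is sound because it sits
# behind dynamo's guards.
def eager_backend(gm, example_inputs, fake_mode, create_aot_dispatcher_function):
    compiled_fn = create_aot_dispatcher_function(gm, example_inputs, fake_mode)
    return compiled_fn
```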

As an added bonus, this lets us remove the duplicate fake tensor conversion code that used to always be invoked in create_aot_dispatcher_function through process_inputs, reusing the fake tensors from dynamo's conversion step instead. Outside of dynamo's invocation path, this layer still needs to exist, though, since aot_function is a public entry point to aot_autograd, so the process_inputs/fake tensor conversion code gets lifted up to there.

This also changes the contract of create_aot_dispatcher_function: it now MUST be called with fake tensors if we are in fake tensor mode, as it does no fake tensor conversion of its own. For the dynamo path, we rely on dynamo passing fake tensors down; for the public entry point, we rely on aot_function.
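
Roughly what the lifted conversion step does at the public aot_function entry point (a simplified stand-in, not the exact helper from the PR): real user tensors are converted to fake tensors up front, since create_aot_dispatcher_function no longer converts anything itself.

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# Simplified process_inputs-style helper: fake-ify every tensor argument,
# leave everything else untouched.
def process_inputs(fake_mode: FakeTensorMode, flat_args):
    return [
        fake_mode.from_tensor(a) if isinstance(a, torch.Tensor) else a
        for a in flat_args
    ]
```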

One annoying stopgap around this is the parameter story. Dynamo fake-ifies only its inputs, and params get treated a little differently. This is made more confusing by the fact that while create_aot_dispatcher_function takes all fake tensors (params and inputs alike), aot_function_simplified still needs real parameters because it uses them as inputs to the compiled aot_function. This means we later need to either fake-ify parameters at dynamo time and pass them alongside the real ones, or keep the param-only fake tensor conversion where it is in this PR.
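
An illustration of the first option above (names are illustrative only, not the PR's actual helpers): fake-ify parameters at dynamo time and keep both views around, since create_aot_dispatcher_function wants fake params for tracing while aot_function_simplified still needs the real ones as runtime inputs to the compiled function.

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

def split_params(mod: torch.nn.Module, fake_mode: FakeTensorMode):
    # Real parameters stay around as runtime inputs; a parallel fake-ified
    # copy is used for tracing under fake mode.
    real_params = dict(mod.named_parameters())
    fake_params = {n: fake_mode.from_tensor(p) for n, p in real_params.items()}
    return real_params, fake_params
```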

— Split Strategy —

  1. **kwargs to all backends
  2. Pass fake_mode, fake_tensor inputs to aot_autograd
  3. Refactor aot_autograd as per #87570 pt1 - replace fake_tensor conversion in aot_autograd w/ fake_tensor from dynamo
  4. Refactor aot_autograd as per #87570 pt2 - move create_fn to lowering time, introduce a cache

cc @mlazos @soumith @yanboliang @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @chunyuan-w @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @desertfire

@pytorch-bot

pytorch-bot bot commented Nov 20, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/89373

Note: Links to docs will display an error until the docs builds have been completed.

❌ 20 Failures

As of commit 6e16df2:

The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.
