
Conversation

@voznesenskym (Collaborator) commented Nov 20, 2022

This is part 2 of the split plan for the bigger migration of critical parts of symbolic-shapes branch to master.

The whole can be found in #89363, and in parts at https://github.com/pytorch/pytorch/pull/89313/files.

The meta-description for this stack:

Step 2 of https://docs.google.com/document/u/1/d/1QJ-M4zfMkD-fjHIqW089RptjLl9EgozZGCceUbvmgfY/edit

Step 1 can be found here: #87570

The problem this PR solves is that today, we have a world wherein dynamo creates guards and installs them, but aot_autograd creates guards... and does not install them. This means that code executed in make_fx can produce new symbolic shape guards that then do not get used anywhere.

The order today, before this PR, is:

  1. Dynamo starts interpreting a frame
  2. Dynamo calls call_user_compiler, which sets up a function that will lazily invoke create_aot_dispatcher_function at runtime
  3. Dynamo installs its own guards, pulls in shape_env guards only from dynamo's shape_env, and installs those as well
  4. Dynamo finishes the frame
  5. The function made in (2) is invoked, and create_aot_dispatcher_function calls make_fx, which can create shape_env guards; these guards are not used anywhere

Our solution to this is to ensure that make_fx's guards bubble up to dynamo, in a "unified cache", and we do this by piping a fake_mode down to aot_autograd from dynamo.

The unified cache architecture works by changing the lifecycle of when we compile the aot_autograd function from runtime to lowering time. We go from lazily compiling create_aot_dispatcher_function to always invoking it at lowering time. This is sound because the compiled_fn is protected by dynamo's guards. The order, therefore, is now (a small sketch of this flow follows the list):

  1. Dynamo starts interpreting a frame
  2. Dynamo calls call_user_compiler, which invokes compile_fn, which invokes create_aot_dispatcher_function
  3. create_aot_dispatcher_function calls make_fx which can create shape env guards
  4. Dynamo installs its own guards, pulls in shape_env guards, installs those as well
  5. Dynamo finishes the frame
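
Below is a minimal, self-contained sketch of the new ordering (ToyShapeEnv, toy_lower_frame, and friends are made-up names for illustration, not the real dynamo/aot_autograd internals). The point is only that compiling eagerly at lowering time makes guards created during tracing visible when dynamo installs its guards:

```python
# Toy model of the post-PR flow; none of these names are the real APIs.

class ToyShapeEnv:
    def __init__(self):
        self.guards = []

def toy_make_fx(shape_env):
    # Stand-in for make_fx: tracing may add symbolic shape guards.
    shape_env.guards.append("s0 > 1")

def toy_create_aot_dispatcher_function(shape_env):
    toy_make_fx(shape_env)
    return lambda *args: None  # the "compiled" function

def toy_lower_frame(shape_env):
    # (2)/(3): compile now, at lowering time, not lazily at runtime ...
    compiled_fn = toy_create_aot_dispatcher_function(shape_env)
    # (4): ... so dynamo installs its own guards *and* the tracing guards.
    installed_guards = list(shape_env.guards)
    return compiled_fn, installed_guards

compiled_fn, installed = toy_lower_frame(ToyShapeEnv())
assert "s0 > 1" in installed  # the make_fx guard is no longer dropped
```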

As an added bonus, this allows us to remove the duplicate fake tensor conversion code that used to always be invoked in create_aot_dispatcher_function through process_inputs, reusing the fake tensors from dynamo's conversion step. Outside of dynamo's invocation path, we still need this layer to exist, though, as aot_function is a public entry point to aot_autograd, so the process_inputs/fake tensor conversion code gets lifted up to there.

This also changes the contract of create_aot_dispatcher_function: it now MUST be called with fake tensors if we are in fake tensor mode, as it does no fake tensor conversion of its own. For the Dynamo path, we rely on dynamo passing fake tensors down; for the public entry point, we rely on aot_function.
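
As a hedged sketch of that contract (the *_sketch helpers below are hypothetical stand-ins; the real create_aot_dispatcher_function and aot_function have different signatures):

```python
import torch
from torch._subclasses import FakeTensor, FakeTensorMode

def create_aot_dispatcher_function_sketch(flat_args, fake_mode=None):
    # Post-PR contract: no fake tensor conversion happens in here. If we are
    # in fake tensor mode, the caller must already have converted the args.
    if fake_mode is not None:
        assert all(
            isinstance(a, FakeTensor)
            for a in flat_args
            if isinstance(a, torch.Tensor)
        )
    # ... partitioning / make_fx / compilation would happen here ...
    return flat_args

def aot_function_sketch(flat_args, fake_mode=None):
    # Public entry point: the old process_inputs-style fakeification is
    # lifted up to here, since outside callers pass real tensors.
    if fake_mode is not None:
        flat_args = [
            fake_mode.from_tensor(a) if isinstance(a, torch.Tensor) else a
            for a in flat_args
        ]
    return create_aot_dispatcher_function_sketch(flat_args, fake_mode)

# Dynamo path: args arrive already fake; public path: real args get converted.
mode = FakeTensorMode()
aot_function_sketch([torch.randn(2, 2)], fake_mode=mode)
```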

One annoying stopgap around this is the parameter story. Dynamo fake-ifies only its inputs, and params get treated a little differently. This is made more confusing by the fact that while create_aot_dispatcher_function takes all fake tensors (params and inputs both), aot_function_simplified still needs real parameters, as it uses them as inputs to the compiled aot_function. This means we need to, later on, either fake-ify parameters at dynamo time and pass them alongside the real ones, or keep the param-only fake tensor conversion where it is in this PR.

— Split Strategy —

  1. **kwargs to all backends (a small sketch of this signature change follows the list)
  2. Pass fake_mode, fake_tensor inputs to aot_autograd
  3. Refactor aot_autograd as per #87570 pt1 - replace fake_tensor conversion in aot_autograd w/ fake_tensor from dynamo
  4. Refactor aot_autograd as per #87570 pt2 - move create_fn to lowering time, introduce a cache
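
A small sketch of step 1 above (my_toy_backend is a made-up backend; the point is only that accepting **kwargs lets dynamo thread fake_mode/fake_inputs through without breaking backends that ignore them):

```python
import torch

def my_toy_backend(gm: torch.fx.GraphModule, example_inputs, **kwargs):
    fake_mode = kwargs.get("fake_mode")      # new in this stack
    fake_inputs = kwargs.get("fake_inputs")  # new in this stack
    # A backend that doesn't care about dynamic shapes can ignore both
    # and compile (or simply run) the graph exactly as before.
    return gm.forward
```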

cc @mlazos @soumith @yanboliang @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @chunyuan-w @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @desertfire

@voznesenskym (Collaborator, Author):

Part 3 and 4 #89373

@ezyang (Contributor) commented Nov 20, 2022

I'm not going to review part 3 / 4 yet as this is still failing tests

@ezyang (Contributor) commented Nov 20, 2022

There's no need to duplicate the long description on each PR, and it makes it less useful (because people will assume it's the same and skip the important part for each PR)

"""copy while preserving strides"""
if isinstance(x, torch._subclasses.FakeTensor):
# No need to clone fake tensors?
return x
Contributor:

We should understand why we are cloning fake tensors in the first place

Contributor:

In particular, even if fake tensors have no data, you may still be obligated to clone them in case metadata mutation happens
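
A hedged illustration of this point, using a simplified stand-in for the copy helper above (run standalone, outside any real dynamo flow):

```python
import torch
from torch._subclasses import FakeTensor, FakeTensorMode

def copy_preserving_strides(x):
    if isinstance(x, FakeTensor):
        return x  # the shortcut under discussion: skip the clone
    return x.clone()

fake_mode = FakeTensorMode()
orig = fake_mode.from_tensor(torch.empty(2, 3))
copy = copy_preserving_strides(orig)
copy.t_()           # in-place *metadata* mutation on the "copy"
print(orig.shape)   # torch.Size([3, 2]): the original's metadata changed too
```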

Collaborator (Author):

This was causing issues on the branch, but we can pull it, maybe. Let me check.

triggered = True

- def compiler(gm, input):
+ def compiler(gm, input, **kwargs):
Contributor:

why isn't this in the other PR

Collaborator (Author):

Missed, will move

example: Any
is_unspecialized: bool
# Must only be a fake tensor
fake_tensor = None
Contributor:

What's the difference between this and example?

Contributor:

You refer to this as fake_example below, make the naming consistent

    should_specialize=self.tensor_should_specialize(),
)
if graph_arg and config.fake_tensor_propagation:
    example = tensor_variable.proxy.node.meta["example_value"]
Contributor:

Does it matter that you're pulling out example_value here and not val?

Collaborator (Author):

No, val reads better though.

assert isinstance(
    self.fake_tensor, torch._subclasses.fake_tensor.FakeTensor
)
return [self.fake_tensor]
Contributor:

This is not an example, right? The fake tensor is canonical for all possible tensors which would pass the guards, because its sizes/strides have been symintified.

Collaborator (Author):

Yes

hasher_type=None,
static_argnums=None,
fake_mode=None,
fake_inputs=None,
Contributor:

Unused?

Collaborator (Author):

Used in next PR

gm: torch.fx.GraphModule,
example_inputs,
fake_mode=None,
fake_inputs=None,
Contributor:

Why do I need both example inputs and fake inputs?

Contributor:

I'm not sure you should do this, but fake_mode memoizes conversions of the same real tensor into a fake tensor, so if you already called from_real_tensor on a tensor before, it will consistently give you back the same fake tensor (with the correct symints) in that case.
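
A hedged sketch of that memoization (assuming FakeTensorMode.from_tensor and its per-tensor memo behave as described; the exact API has shifted over time):

```python
import torch
from torch._subclasses import FakeTensorMode

fake_mode = FakeTensorMode()
real = torch.randn(4, 4)

fake_a = fake_mode.from_tensor(real)
fake_b = fake_mode.from_tensor(real)

# The converter memoizes per real tensor, so repeated conversions hand back
# the same fake tensor (and hence the same symbolic sizes/strides).
assert fake_a is fake_b
```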

Contributor:

And then, I guess, you could also technically just enable fake mode before calling the compilation function, and then you don't even need to do any of the kwargs business.
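
A hedged sketch of that alternative (toy_compiler_fn is made up; the real call_user_compiler plumbing is more involved). Entering the mode around the backend call means tensors materialized inside it come out fake without any extra kwargs:

```python
import torch
from torch._subclasses import FakeTensor, FakeTensorMode

def toy_compiler_fn(gm, example_inputs):
    # Any tensor created while the ambient fake mode is active is fake.
    scratch = torch.empty(8)
    assert isinstance(scratch, FakeTensor)
    return gm

def f(x):
    return x + 1

gm = torch.fx.symbolic_trace(f)

fake_mode = FakeTensorMode()
with fake_mode:
    compiled = toy_compiler_fn(gm, example_inputs=[])
```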

Collaborator (Author):

You still need a way to pass in fake mode, so kwargs needs to be there.

There are also modes of operation for aot_autograd that take real tensors.

@ezyang (Contributor) commented Nov 20, 2022

So, thinking about the calling convention holistically here, I think there's a decent case to be made that we should not pass real tensors to backends in the dynamic shapes world, for the simple reason that the backend may unsoundly assume the given tensors' sizes even when dynamo actually didn't guard on them. So no real inputs; they should be fakeified. The exception to the rule is parameters, which are assumed to be static and can be passed in as is. This also seems fine and can be modeled as passing in a mix of fake and non-fake (for non-fake you are even allowed to burn in the data ptr, which is what cudagraphs will do).
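
A hedged sketch of what that convention could look like from a backend's point of view (toy_backend and its behavior are made up for illustration):

```python
import torch
from torch._subclasses import FakeTensor

def toy_backend(gm, inputs):
    for i, t in enumerate(inputs):
        if isinstance(t, FakeTensor):
            # Dynamic input: only the (possibly symbolic) sizes/strides,
            # dtype, and device may be relied on; no data, no fixed sizes.
            print(f"arg {i}: fake, shape {tuple(t.shape)}")
        elif isinstance(t, torch.Tensor):
            # Static parameter: assumed not to change, so a backend like
            # cudagraphs could even burn in its data_ptr.
            print(f"arg {i}: real parameter at {t.data_ptr():#x}")
    return gm.forward
```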

@ezyang (Contributor) commented Nov 20, 2022

And if this means we need to make dynamic-shapes-specific changes to the calling convention, so be it. We can tag the functions so that dynamo understands what the backend supports.

@voznesenskym (Collaborator, Author):

> So, thinking about the calling convention holistically here, I think there's a decent case to be made that we should not pass real tensors to backends in the dynamic shapes world, for the simple reason that the backend may unsoundly assume the given tensors' sizes even when dynamo actually didn't guard on them. So no real inputs; they should be fakeified. The exception to the rule is parameters, which are assumed to be static and can be passed in as is. This also seems fine and can be modeled as passing in a mix of fake and non-fake (for non-fake you are even allowed to burn in the data ptr, which is what cudagraphs will do).

I agree - this is reasonable. However, dynamo can be in fake tensor mode without being in dynamic shape mode, whereas dynamic shapes always implies fake tensors.

So, the solution there could be: where we invoke

compiler_fn(gm, example_inputs, fake_mode, fake_inputs)

we can do something like:

example_inputs = fake_inputs if config.dynamic_shapes else example_inputs
compiler_fn(gm, example_inputs, fake_mode, fake_inputs)

With the idea that we still have a mode (non-dynamic) wherein backends need real inputs. And then, for the dynamic case, they are equivalent. We can even change their names from real/fake to something like user_provided_inputs/tracing_inputs.

@voznesenskym (Collaborator, Author):

FWIW, the original PR (https://github.com/pytorch/pytorch/pull/88546/files) had the fake tensor inputs in the example_inputs field, and there was no fake_inputs.

@albanD albanD removed their request for review November 22, 2022 22:02
@malfet malfet deleted the voz/symbolic-shapes-to-master-2-2 branch April 24, 2023 21:40