
Conversation

@zou3519
Contributor

@zou3519 zou3519 commented Mar 23, 2022

Stack from ghstack:

This PR adds tests for when only some inputs are Tensor Subclasses.

Why is this important to test?

Consider the following hypothetical out-of-place operation:
```
def my_add(x, y):
    result = x.clone()
    result.add_(y)
    return result
```

You may expect this to work the same as torch.add. But if x is not a
Tensor Subclass and y is a Tensor Subclass, then this returns a regular
Tensor, NOT a Tensor Subclass!

This is exactly the kind of in-place operation that causes vmap to
fail and that will be problematic for certain Tensor Subclasses in the
future, so we're adding tests to make sure composite PyTorch operations
don't do this.
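
To make the failure mode concrete, here is a pure-Python analogue (not real torch code; `Plain` and `Subclass` are stand-ins for torch.Tensor and a Tensor Subclass) showing how the subclass-ness of y is silently dropped:

```python
# Pure-Python analogue of the my_add problem. Plain/Subclass are
# stand-ins for torch.Tensor and a Tensor Subclass.
class Plain:
    def __init__(self, v):
        self.v = v

    def clone(self):
        # clone() preserves the type of self only
        return type(self)(self.v)

    def add_(self, other):
        # in-place add: mutates self, never promotes its type
        self.v += other.v
        return self


class Subclass(Plain):
    pass


def my_add(x, y):
    result = x.clone()   # result's type follows x, ignoring y
    result.add_(y)
    return result


out = my_add(Plain(1), Subclass(2))
# out is a Plain, even though y was a Subclass
```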

What exactly does this PR do?

Composite compliance testing now takes a sample input and produces a
test case where some of the sample inputs are Tensor Subclasses. It then
sends this through the original operation, once with Python Mode and
once without.

(Why once with Python Mode? Because we want to use it to detect the
pattern of "create a Tensor and call resize_ on it")

Finally, it repeats this process for all possibilities of which inputs
are Tensor Subclasses. For example, if the sample input is (x, y), then
we test all four of the following cases:

  • Subclass(x), y
  • x, Subclass(y)
  • Subclass(x), Subclass(y)
  • x, y
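
The enumeration itself can be sketched as follows (`subclass_cases` and `wrap` are hypothetical names; `wrap` stands in for whatever constructs the Tensor Subclass). Note the 2**len(args) blowup, which is where the exponential test cost comes from:

```python
from itertools import product

def subclass_cases(args, wrap):
    """Yield every way of wrapping each arg or leaving it as-is:
    2 ** len(args) cases in total."""
    for mask in product([False, True], repeat=len(args)):
        yield tuple(wrap(a) if m else a for a, m in zip(args, mask))

# Illustration with strings in place of real tensors:
cases = list(subclass_cases(("x", "y"), wrap=lambda a: f"Subclass({a})"))
# 4 cases, covering the list above
```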

Test Plan

  • run tests

Differential Revision: D35186862

@facebook-github-bot
Contributor

facebook-github-bot commented Mar 23, 2022

💊 CI failures summary and remediations

As of commit b8d76a1 (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚


This comment was automatically generated by Dr. CI.

@zou3519 zou3519 requested review from albanD, bdhirsh and ezyang March 24, 2022 13:56

# Introspection please save us
def is_inplace(func):
return func.overloadpacket.__name__[-1] == '_'
Collaborator


You need to exclude things like `__add__`, no?
Also add back `__iadd__`.

Contributor Author


Yes, I will add those cases. I don't know if this catches everything but we're trying to be best effort here until there is a better solution
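
A name-based heuristic along those lines might look like this (a sketch only; `is_inplace_name` takes the name string directly, whereas the real check operates on `func.overloadpacket.__name__`, and the dunder set is not exhaustive):

```python
# In-place dunders end in '__' rather than a bare '_', so they need
# special-casing. This set is illustrative, not exhaustive.
INPLACE_DUNDERS = {'__iadd__', '__isub__', '__imul__', '__itruediv__'}

def is_inplace_name(name):
    """Best-effort check: a trailing '_' usually means in-place (add_),
    but dunders like __add__ must be excluded, and the in-place dunders
    like __iadd__ added back."""
    if name.startswith('__') and name.endswith('__'):
        return name in INPLACE_DUNDERS
    return name.endswith('_')
```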

Collaborator


Yes, the tags should provide a better solution here when they're ready cc @anjali411

        return False
    if len(lst) == 0:
        return False
    return isinstance(lst[0], torch.Tensor)
Collaborator


Why not `all(isinstance(t, torch.Tensor) for t in lst)`?

Contributor Author


Sure
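
The suggested change, as a self-contained sketch (`is_list_of` is a hypothetical name; a generic `typ` parameter stands in for torch.Tensor so the example does not depend on torch):

```python
def is_list_of(lst, typ):
    """Return True iff lst is a non-empty list/tuple whose elements
    are ALL instances of typ (not just lst[0], per the review)."""
    if not isinstance(lst, (list, tuple)):
        return False
    if len(lst) == 0:
        return False
    return all(isinstance(t, typ) for t in lst)
```

For example, `is_list_of([1, "a"], int)` is False, even though checking only `lst[0]` would have passed.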


try:
    op(*new_args, **new_kwargs)
except RuntimeError as err:
Collaborator


Is the goal here to catch only the error that would raise within the torch_dispatch impl or any error that might happen when a function is not properly implemented?

Contributor Author


any error that might happen when a function is not properly implemented?

We don't really want to catch errors where a function has a bug in it, but we do because of the following.

There are two things we want to catch:

  • errors that would raise within the torch_dispatch impl
  • data_ptr accesses

The first is easy to filter for (we could make the error a different error class); the second is always going to be a RuntimeError due to how it is implemented (if you try to access the data_ptr of the wrapper Tensor, it raises an internal RuntimeError).

So the most general thing to catch here was RuntimeError.

Collaborator


Ok!
A small comment here specifying that we want both of these captured would be nice.

Contributor Author


will do!
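
Such a comment could look like the following sketch (`check_op` and `CompositeComplianceError` are hypothetical names for illustration; the actual PR re-raises via its own helper):

```python
class CompositeComplianceError(RuntimeError):
    pass

def check_op(op, *args, **kwargs):
    try:
        return op(*args, **kwargs)
    except RuntimeError as err:
        # We deliberately catch plain RuntimeError because we want both:
        #  1. errors raised inside the __torch_dispatch__ impl, and
        #  2. data_ptr() accesses on the wrapper Tensor, which surface
        #     as an internal RuntimeError.
        raise CompositeComplianceError(str(err)) from err
```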

"through the CompositeImplicitAutograd Compliance section in "
"aten/src/ATen/native/README.md for how to resolve this. "
) from err
raise_composite_compliance_error(err)
Contributor


How long do the new tests take to run?

Contributor Author


Total runtime is 1:14 on both CPU and CUDA.
CPU-only is 32s.

Contributor

@ezyang ezyang left a comment


This makes me wonder if testing is the wrong way to check these properties. If these tests are quick, LGTM.

@zou3519
Contributor Author

zou3519 commented Mar 25, 2022

If these tests are quick LGTM.

@ezyang this is one of the reasons why I wanted symbolic shape stuff for OpInfos -- some of these OpInfos are too large and we really don't need them to be so large for these tests :).

This makes me wonder if testing is the wrong way to check these properties.

I don't think testing as done in this PR is the optimal way to check these properties, due to the exponential runtime. But I don't have any better ideas for how to test it with our current infrastructure (and this is exactly what we do for vmap testing -- we have to test the cases where some inputs are BatchedTensors and some aren't...). Here are the things I would want in an ideal world:

  • If we, for example, had a Python source of truth for our composite operations, then we could statically analyze the Python, and that should be faster.
  • Alternatively, if we enforced that out-of-place operations be written in functional form, but then used a backend or a "mutation pass" (the opposite of a functionalization pass) to generate fast kernels that do end up doing in-place operations, then we wouldn't even need to test this, because what we would be transforming over are the out-of-place operations.

@albanD
Collaborator

albanD commented Mar 25, 2022

If we, for example, had a Python source of truth for our composite operations then we could statically analyze the Python and that should be faster

I'm not sure about this one though. The same test would be needed to ensure that the source of truth matches the C++ impl (unless you generate the C++ from the Python, of course).

zou3519 added 2 commits March 25, 2022 13:26
…siteCompiance"

…siteCompiance"

@zou3519
Contributor Author

zou3519 commented Mar 28, 2022

@zou3519 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

facebook-github-bot pushed a commit that referenced this pull request Mar 28, 2022
…ce (#74645)

Summary:
Pull Request resolved: #74645


Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D35186862

Pulled By: zou3519

fbshipit-source-id: 102477507b56583463668db7523a6586d92b357d
@facebook-github-bot facebook-github-bot deleted the gh/zou3519/417/head branch April 1, 2022 14:17