Test case where some inputs are Tensor Subclasses in CompositeCompliance #74645
Conversation
This PR adds tests for when only some inputs are Tensor Subclasses.

Why is this important to test?
==============================

Consider the following hypothetical out-of-place operation:

```
def my_add(x, y):
    result = x.clone()
    result.add_(y)
    return result
```

You may expect this to work the same as torch.add. If x is not a Tensor Subclass, but y is a Tensor Subclass, then this returns a regular Tensor, NOT a Tensor Subclass!

This is exactly the kind of in-place operation that causes `vmap` to fail, and it will be problematic for certain Tensor Subclasses in the future, so we're adding tests to make sure composite PyTorch operations don't do this.

What exactly does this PR do?
=============================

Composite compliance now takes a sample input and produces a test case where some of the sample inputs are Tensor Subclasses. It then sends this through the original operation, once with Python Mode and once without. (Why once with Python Mode? Because we want to use it to detect the pattern of "create a Tensor and call resize_ on it".)

Finally, it repeats this process for all possibilities where the inputs are Tensor Subclasses. For example, if the sample input is (x, y), then we test all four of the following cases (see the sketch below):

- Subclass(x), y
- x, Subclass(y)
- Subclass(x), Subclass(y)
- x, y

Test Plan
=========

- run tests
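For intuition, here is a minimal sketch (not the PR's actual test harness) of how one might enumerate every plain/subclass combination of the tensor inputs; `Subclass` and `all_subclass_variants` are hypothetical names used only for illustration:

```
import itertools
import torch

class Subclass(torch.Tensor):
    pass

def all_subclass_variants(*tensors):
    # For n tensor inputs there are 2^n plain-vs-subclass combinations.
    for mask in itertools.product([False, True], repeat=len(tensors)):
        yield tuple(t.as_subclass(Subclass) if wrap else t
                    for t, wrap in zip(tensors, mask))

x, y = torch.randn(3), torch.randn(3)
for args in all_subclass_variants(x, y):
    # Prints all four plain/subclass combinations for (x, y).
    print(tuple(type(a).__name__ for a in args))
```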
💊 Dr. CI failures summary (as of commit b8d76a1): 💚 Looks good so far! There are no failures yet. 💚
```
# Introspection please save us
def is_inplace(func):
    return func.overloadpacket.__name__[-1] == '_'
```
You need to exclude things like `__add__`, no?
Also add back `__iadd__`.
Yes, I will add those cases. I don't know if this catches everything, but we're trying to be best-effort here until there is a better solution.
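One possible refinement along those lines (a sketch only, not the final PR code; the list of in-place dunders below is illustrative and likely incomplete):

```
def is_inplace(func):
    name = func.overloadpacket.__name__
    if name.startswith('__') and name.endswith('__'):
        # Dunders like __add__ end in '_' but are out-of-place; only the
        # in-place dunders such as __iadd__ should count.
        return name in ('__iadd__', '__isub__', '__imul__', '__itruediv__',
                        '__ifloordiv__', '__imod__', '__iand__', '__ior__',
                        '__ixor__', '__ilshift__', '__irshift__')
    return name[-1] == '_'
```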
Yes, the tags should provide a better solution here when they're ready. cc @anjali411
```
        return False
    if len(lst) == 0:
        return False
    return isinstance(lst[0], torch.Tensor)
```
Why not `all(isinstance(t, torch.Tensor) for t in lst)`?
Sure
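Applied to the helper above, the suggestion might look like this (a sketch; the function name `is_tensor_list` is assumed, not taken from the PR):

```
import torch

def is_tensor_list(lst):
    if not isinstance(lst, (list, tuple)):
        return False
    if len(lst) == 0:
        return False
    # Check every element rather than just the first one.
    return all(isinstance(t, torch.Tensor) for t in lst)
```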
```
try:
    op(*new_args, **new_kwargs)
except RuntimeError as err:
```
Is the goal here to catch only the errors that would be raised within the `torch_dispatch` impl, or any error that might happen when a function is not properly implemented?
> any error that might happen when a function is not properly implemented?
We don't really want to catch errors where a function has a bug in it, but we end up doing so because of the following.
There are two things we want to catch:
- errors that would be raised within the `torch_dispatch` impl
- data_ptr accesses

The first is easy to filter for (we could make the error a different error class); the second is always going to be a RuntimeError due to how it is implemented (if you try to access the data_ptr of the wrapper Tensor, it raises an internal RuntimeError).
So the most general thing to catch here was `RuntimeError`.
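For illustration, here is a minimal wrapper subclass showing the data_ptr behavior described above (assumed behavior of `torch.Tensor._make_wrapper_subclass`; not code from this PR):

```
import torch

class Wrapper(torch.Tensor):
    @staticmethod
    def __new__(cls, elem):
        # The wrapper created here holds no real storage of its own.
        return torch.Tensor._make_wrapper_subclass(cls, elem.size(), dtype=elem.dtype)

w = Wrapper(torch.randn(3))
w.data_ptr()  # raises an internal RuntimeError (no accessible storage)
```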
Ok!
A small comment here specifying that we want both of these captured would be nice.
will do!
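Something like this, perhaps (a sketch of the catch site with the requested comment added; `raise_composite_compliance_error` is the helper visible in the diff just below):

```
try:
    op(*new_args, **new_kwargs)
except RuntimeError as err:
    # We deliberately catch RuntimeError here because it covers both cases
    # we care about: (1) errors raised inside the subclass's torch_dispatch
    # impl, and (2) data_ptr accesses on the wrapper Tensor, which surface
    # as an internal RuntimeError.
    raise_composite_compliance_error(err)
```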
| "through the CompositeImplicitAutograd Compliance section in " | ||
| "aten/src/ATen/native/README.md for how to resolve this. " | ||
| ) from err | ||
| raise_composite_compliance_error(err) |
How long do the new tests take to run?
Total runtime is 1:14 with both CPU and CUDA; CPU-only is 32s.
ezyang left a comment:
This makes me wonder if testing is the wrong way to check these properties. If these tests are quick, LGTM.
@ezyang this is one of the reasons why I wanted symbolic shape stuff for OpInfos -- some of these OpInfos are too large and we really don't need them to be so large for these tests :).
I don't think testing as done in this PR is the optimal way to check these properties due to the exponential runtime. I don't have any better ideas for how to test it with our current infrastructure (and this is exactly what we do for vmap testing -- we have to test the cases where some inputs are BatchedTensors and some aren't...). Here are the things I would want in an ideal world:
I'm not sure about this one though. The same test would be needed to ensure that the source of truth does match the C++ impl (unless you generate the C++ from the Python, of course).
…siteCompiance" This PR adds tests for when only some inputs are Tensor Subclasses. Why is this important to test? ============================== Consider the following hypothetical out-of-place operation: ``` def my_add(x, y): result = x.clone() result.add_(y) return result ``` You may expect this to work the same as torch.add. If x is not a Tensor Subclass, but y is a Tensor subclass, then this returns us a regular Tensor, NOT a Tensor subclass! This is exactly the type of in-place operations that causes `vmap` to fail and will be problematic for certain Tensor Subclasses in the future so we're adding tests to make sure Composite pytorch operations don't do this. What exactly does this PR do? ============================= Composite compliance now takes a sample input and produces a test case where some of the sample inputs are Tensor Subclasses. It then sends this through the original operation, once with Python Mode and one without. (Why once with Python Mode? Because we want to use it to detect the pattern of "create a Tensor and call resize_ on it") Finally, it repeats this process for all possiblities where the inputs are Tensor subclasses. For example, if the sample input is (x, y), then we test all four of the following cases: - Subclass(x), y - x, Subclass(y) - Subclass(x), Subclass(y) - x, y Test Plan ========= - run tests [ghstack-poisoned]
…siteCompiance" This PR adds tests for when only some inputs are Tensor Subclasses. Why is this important to test? ============================== Consider the following hypothetical out-of-place operation: ``` def my_add(x, y): result = x.clone() result.add_(y) return result ``` You may expect this to work the same as torch.add. If x is not a Tensor Subclass, but y is a Tensor subclass, then this returns us a regular Tensor, NOT a Tensor subclass! This is exactly the type of in-place operations that causes `vmap` to fail and will be problematic for certain Tensor Subclasses in the future so we're adding tests to make sure Composite pytorch operations don't do this. What exactly does this PR do? ============================= Composite compliance now takes a sample input and produces a test case where some of the sample inputs are Tensor Subclasses. It then sends this through the original operation, once with Python Mode and one without. (Why once with Python Mode? Because we want to use it to detect the pattern of "create a Tensor and call resize_ on it") Finally, it repeats this process for all possiblities where the inputs are Tensor subclasses. For example, if the sample input is (x, y), then we test all four of the following cases: - Subclass(x), y - x, Subclass(y) - Subclass(x), Subclass(y) - x, y Test Plan ========= - run tests [ghstack-poisoned]
|
@zou3519 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Test case where some inputs are Tensor Subclasses in CompositeCompliance (#74645)
Summary: Pull Request resolved: #74645
Test Plan: Imported from OSS
Reviewed By: albanD
Differential Revision: D35186862
Pulled By: zou3519
fbshipit-source-id: 102477507b56583463668db7523a6586d92b357d
Stack from ghstack:
Differential Revision: D35186862