[JIT] Add JIT support for torch.no_grad #41371
Conversation
Okay, so let me explain my thinking on this one. After looking at the source code for `torch.no_grad`, I think the best path forward is to create a JITable version of it. Some limitations of this approach:
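For concreteness, here is a minimal sketch of what a JIT-friendly version of the context manager could look like (an illustration only, assuming the internal `torch._C._set_grad_enabled` binding; the class name is hypothetical, not the code from this PR):

```python
from typing import Any

import torch


class _NoGradForJit(object):  # hypothetical name, for illustration
    # __enter__/__exit__ have explicit, simple signatures so that the
    # TorchScript compiler can resolve and compile them.
    def __enter__(self) -> None:
        self.prev = torch.is_grad_enabled()
        torch._C._set_grad_enabled(False)  # flips the global grad-mode flag

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        torch._C._set_grad_enabled(self.prev)  # restores the previous mode
```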
eellison left a comment:
Looks like a good start; there are a couple of things we have to figure out:
- It would be nice if we could figure out a way to support `torch.no_grad()` without having a special resolving mechanism for it. We should at least scope out what that would entail before deciding what to do.
- How do we want to register operators that read from or write to global state? Ideally this would be expressible in the schema, see #39497 (comment). A workaround is to use hasSideEffects (a short illustration of the difficulty follows below).
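A quick plain-Python illustration of why such operators are awkward (a sketch, assuming the internal `torch._C._set_grad_enabled` binding): the call's only observable effect is a write to global state, so a schema that describes only inputs and outputs cannot capture it, and an optimizer that trusts the schema alone could reorder or eliminate the call.

```python
import torch

# From the signature alone this call looks dead: it takes a bool, returns
# nothing, and touches no tensors. Its only effect is on a global flag.
torch._C._set_grad_enabled(False)
print(torch.is_grad_enabled())  # False

torch._C._set_grad_enabled(True)
print(torch.is_grad_enabled())  # True
```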
Hey, sorry, I had this open in an old tab and didn't see your comment on the PR. One sec, re-reading.
Is the main difficulty here that we have to make …?
Most of the JIT use cases I've seen are in …
You can address this here: https://github.com/pytorch/pytorch/blob/master/torch/jit/_builtins.py#L118. It's how we resolve `torch.unique` and other ops bound in `torch/functional.py` (see `_is_special_functional_bound_op`).
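For context, the lookup side of that mechanism (a sketch using an internal helper from `torch/jit/_builtins.py`; exact return values may vary by version):

```python
import torch
from torch.jit._builtins import _find_builtin  # internal helper

# Ops bound in torch/functional.py, such as torch.unique, are resolved
# through the builtin table instead of being compiled from Python source.
print(_find_builtin(torch.unique))  # the builtin symbol the compiler uses
```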
Okay, I uploaded a new version after offline discussion with @eellison. This new version …

The one problem left to solve is how to allow the implementation of …
In the `no_grad` docstring and class body:

```python
        >>> z.requires_grad
        False
    """
    def __init__(self):
```
This might be useful to refactor out into a separate class.
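One shape such a refactor could take (a hedged sketch, not necessarily what the PR actually did): extract the enter/exit bookkeeping into a shared base class that concrete grad-mode managers parameterize.

```python
from typing import Any

import torch


class _GradMode(object):  # hypothetical base class, illustration only
    def __init__(self, mode: bool) -> None:
        self.mode = mode

    def __enter__(self) -> None:
        self.prev = torch.is_grad_enabled()
        torch._C._set_grad_enabled(self.mode)

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        torch._C._set_grad_enabled(self.prev)


class no_grad(_GradMode):
    def __init__(self) -> None:
        super().__init__(False)


class enable_grad(_GradMode):
    def __init__(self) -> None:
        super().__init__(True)
```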
test/jit/test_with.py (outdated diff):

```python
            return y, y.requires_grad
        ...
        test_input = torch.randn(5)
```
Doesn't this test pass even if you don't have a no_grad block?
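(The point being that `torch.randn(5)` produces a tensor with `requires_grad=False`, so the property being asserted holds with or without the block; a quick illustration:)

```python
import torch

test_input = torch.randn(5)  # requires_grad defaults to False
y = test_input * 2           # no no_grad block anywhere
print(y.requires_grad)       # False -- the check passes regardless
```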
I don't know what the scope of support for autograd generally is with TorchScript, but I think you'd want a subset of the no_grad tests in test_autograd, e.g.:
test/test_autograd.py, lines 912 to 973 in 34025eb:

```python
def test_no_grad(self):
    x = torch.ones(5, 5, requires_grad=True)
    y = torch.ones(5, 5) * 4
    with torch.no_grad():
        w = x + y

    @torch.no_grad()
    def adder(x, y):
        return x + y

    z = adder(x, y)

    self.assertFalse(w.requires_grad)
    self.assertRaises(RuntimeError, lambda: w.backward(torch.ones(5, 5)))
    self.assertIsNone(w.grad_fn)
    self.assertFalse(z.requires_grad)
    self.assertRaises(RuntimeError, lambda: z.backward(torch.ones(5, 5)))
    self.assertIsNone(z.grad_fn)

    # test nested decorator and with-statement on no_grad
    with torch.no_grad():
        self.assertFalse(torch.is_grad_enabled())
        w = adder(x, y)
        self.assertFalse(torch.is_grad_enabled())

def test_set_grad_generator_functions(self):
    @torch.no_grad()
    def gen_no_grad():
        for i in range(10):
            self.assertEqual(torch.is_grad_enabled(), False)
            yield i

    with torch.enable_grad():
        for _ in gen_no_grad():
            self.assertEqual(torch.is_grad_enabled(), True)

    @torch.enable_grad()
    def gen_enable_grad():
        for i in range(10):
            self.assertEqual(torch.is_grad_enabled(), True)
            yield i

    with torch.no_grad():
        for _ in gen_enable_grad():
            self.assertEqual(torch.is_grad_enabled(), False)

def test_no_grad_python_function(self):
    """Python Functions should respect grad mode."""
    x = torch.ones(5, 5, requires_grad=True)

    class MyOp(Function):
        @staticmethod
        def forward(self, x):
            return x + 1

        @staticmethod
        def backward(self, dy):
            return dy

    with torch.no_grad():
        y = MyOp.apply(x)
    self.assertFalse(y.requires_grad)
```
test/test_autograd.py, lines 2026 to 2059 in 34025eb:

```python
def test_no_grad_assignment(self):
    x = torch.randn(5, 5, requires_grad=True)
    y = torch.randn(5)
    with torch.no_grad():
        x[0] = y

    self.assertTrue(x.requires_grad)
    self.assertIsNone(x.grad_fn)

def test_no_grad_modifies_version(self):
    x = torch.randn(5, requires_grad=True)
    y = torch.randn(5, requires_grad=True)
    z = (x * y).sum()
    with torch.no_grad():
        x *= 2
    self.assertRaisesRegex(RuntimeError, 'modified by an inplace operation',
                           lambda: z.backward())

def test_no_grad_input(self):
    class MyFunction(Function):
        @staticmethod
        def forward(self, x):
            return x

        @staticmethod
        def backward(self, grad_output):
            return grad_output

    x = torch.randn(5, requires_grad=True)
    with torch.no_grad():
        y = MyFunction.apply(x)

    self.assertTrue(x.requires_grad)
    self.assertIsNone(y.grad_fn)
```
test/test_autograd.py, lines 4210 to 4234 in 34025eb:

```python
def test_custom_function_return_view_in_nograd(self):
    class Alias(Function):
        @staticmethod
        def forward(ctx, x):
            return x[:]

        @staticmethod
        def backward(ctx, gx):
            return gx

    inp = torch.rand(2, requires_grad=True)

    with torch.no_grad():
        output = Alias.apply(inp)

    with torch.no_grad():
        expected_output = inp[:]

    # Calling the custom function should operate as if we called an equivalent op
    self.assertEqual(output.requires_grad, expected_output.requires_grad)

    # Check that in-place modification on view throws
    leaf_grad_err = "A view was created in no_grad mode and is being modified inplace"
    with self.assertRaisesRegex(RuntimeError, leaf_grad_err):
        output.zero_()
```
Cool, I adapted a handful of these for JIT. I'll consult my JIT reviewers on whether what I added is enough.
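For flavor, a hedged sketch of what one such adaptation might look like (names are illustrative; the actual tests live in test/jit/test_with.py):

```python
import torch

def test_with_no_grad(self):
    """torch.no_grad() as a with item should work under scripting."""
    def adder(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        with torch.no_grad():
            w = x + y
        return w

    scripted = torch.jit.script(adder)
    w = scripted(torch.ones(5, 5, requires_grad=True), torch.ones(5, 5) * 4)
    self.assertFalse(w.requires_grad)
```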
eellison left a comment:
Looks good, a few small comments.
albanD left a comment:
LGTM
You can rebase on top of master and it should fix the Windows CI.
@SplitInfinity merged this pull request in 87d7c36.
Stack from ghstack:

**Summary**
This commit enables the use of `torch.no_grad()` in a with item of a with statement within JIT. Note that the use of this context manager as a decorator is not supported.

**Test Plan**
This commit adds a test case to the existing with statements tests for `torch.no_grad()`.

**Fixes**
This commit fixes #40259.

Differential Revision: D22649519
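For reference, the merged feature enables roughly the following (a minimal sketch; the decorator form remains unsupported):

```python
import torch

@torch.jit.script
def double_no_grad(x: torch.Tensor) -> torch.Tensor:
    # Operations inside the with block are not recorded by autograd.
    with torch.no_grad():
        y = x * 2
    return y

out = double_no_grad(torch.ones(3, requires_grad=True))
print(out.requires_grad)  # False
```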