
Conversation


@SplitInfinity SplitInfinity commented Jul 13, 2020

Stack from ghstack:

**Summary**
This commit enables the use of `torch.no_grad()` in a with item of a
with statement within JIT. Note that using this context manager as a
decorator is not supported.

**Test Plan**
This commit adds a test case for `torch.no_grad()` to the existing
with statement tests.

**Fixes**
This commit fixes #40259.

Differential Revision: [D22649519](https://our.internmc.facebook.com/intern/diff/D22649519)

@SplitInfinity SplitInfinity requested a review from apaszke as a code owner July 13, 2020 23:08
SplitInfinity pushed a commit that referenced this pull request Jul 13, 2020
ghstack-source-id: faada1e
Pull Request resolved: #41371
@SplitInfinity SplitInfinity changed the title [JIT] Add JIT support for torch.no_grad [WIP][JIT] Add JIT support for torch.no_grad Jul 13, 2020
@facebook-github-bot facebook-github-bot added the oncall: jit Add this issue/PR to JIT oncall triage queue label Jul 13, 2020

dr-ci bot commented Jul 13, 2020

💊 CI failures summary and remediations

As of commit 7295ab9 (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚



@SplitInfinity (Author)

Okay, so let me explain my thinking on this one.

After looking at the source code for torch.no_grad, I concluded that it is not scriptable as-is: it uses inheritance to allow no_grad to be used as a decorator, plus some Python magic to correctly handle no_grad when it is applied to generators (which JIT doesn't support). The rest of the implementation, however, is quite simple:

    def __enter__(self):
        self.prev = torch.is_grad_enabled()
        torch._C.set_grad_enabled(False)

    def __exit__(self, *args):
        torch.set_grad_enabled(self.prev)

I think the best path forward is to create a JITable version of no_grad that has the same __enter__ and __exit__ implementations shown above, but without a dependency on Python. To this end, I added aten::is_grad_enabled and aten::set_grad_enabled, which call GradMode::is_enabled() and GradMode::set_enabled() respectively. Thanks to preexisting name resolution mechanisms, torch.is_grad_enabled and torch.set_grad_enabled get emitted as the aforementioned ops.

I wrote a new class called NoGrad in the scope of the tests, but I would like to hoist that out and put it somewhere globally discoverable and visible so that everyone can use it. Some additional magic will also be required to make sure that torch.no_grad resolves correctly and finds and uses that JITable no_grad implementation.
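
A minimal sketch of what such a class could look like (an illustration based on the description above, assuming only the two builtins just mentioned; not the PR's exact test code):

    import torch

    class NoGrad(object):
        # JIT-friendly context manager with no dependency on Python magic:
        # torch.is_grad_enabled / torch.set_grad_enabled are emitted as the
        # aten::is_grad_enabled / aten::set_grad_enabled ops described above.
        def __init__(self):
            self.prev = False  # attribute declared up front for TorchScript

        def __enter__(self):
            self.prev = torch.is_grad_enabled()
            torch.set_grad_enabled(False)

        def __exit__(self, type, value, traceback):
            torch.set_grad_enabled(self.prev)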

Some limitations of this approach:

  • no_grad cannot be used as a decorator.
  • As I was writing this, I learned that torch.set_grad_enabled can itself be used in a with statement (see the eager-mode snippet below). So emitting it as an aten::set_grad_enabled op that calls GradMode::set_enabled() might have unintended consequences if someone tries to JIT code that uses torch.set_grad_enabled directly.
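
For reference, the eager-mode behavior in question (standard PyTorch API, nothing PR-specific):

    import torch

    x = torch.randn(3, requires_grad=True)

    # torch.set_grad_enabled works as a context manager...
    with torch.set_grad_enabled(False):
        y = x * 2
    assert not y.requires_grad

    # ...and as a plain call that flips global grad mode.
    torch.set_grad_enabled(True)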

@SplitInfinity SplitInfinity requested review from eellison and suo July 13, 2020 23:50
@eellison eellison (Contributor) left a comment


Looks like a good start; there are a couple of things we have to figure out:

  • it would be nice if we could figure out a way to support torch.no_grad() without having a special resolving mechanism for it. We should at least scope out what that would entail before deciding what to do.

  • how do we want to register operators that read/write global state? Ideally this would be expressible in the schema (see #39497 (comment)); a workaround is to mark them with hasSideEffects.

@eellison (Contributor)

Hey, sorry, I had this open in an old tab and didn't see your comment on the PR; one sec, re-reading.

@eellison (Contributor)

> I think the best path forward is to create a JITable version of no_grad that has the same __enter__ and __exit__ implementations shown above, but without a dependency on Python

Is the main difficulty here that we have to make torch.no_grad work as a decorator as well? What about special-casing objects with _DecoratorContextManager as the one concrete supertype and ignoring the superclass?
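
For context, a simplified sketch of what _DecoratorContextManager provides (a paraphrase of the PyTorch source at the time, minus its generator-function handling):

    import functools

    class _DecoratorContextManager:
        """Allow a context manager to also be used as a decorator."""

        def __call__(self, func):
            @functools.wraps(func)
            def decorate_context(*args, **kwargs):
                with self:
                    return func(*args, **kwargs)
            return decorate_context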


SplitInfinity commented Jul 14, 2020

> I think the best path forward is to create a JITable version of no_grad that has the same __enter__ and __exit__ implementations shown above, but without a dependency on Python

> Is the main difficulty here that we have to make torch.no_grad work as a decorator as well? What about special-casing objects with _DecoratorContextManager as the one concrete supertype and ignoring the superclass?

Most of the JIT use cases I've seen are in with statements, not as decorators. That's also what I was focusing on in this PR. So I think the main problems are:

  • dealing with the fact that torch.no_grad uses inheritance (your suggestion addresses this)
  • making torch.jit.no_grad() resolve to a scripted version of the corresponding Python class in code that is JITed. In general, torch.x is assumed to be an aten::x op at the moment, and scripted classes have to be explicitly annotated with @torch.jit.script to be compiled, added to the compilation unit, and made visible during name resolution (see the sketch below).
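
A small illustration of that resolution rule (hypothetical function name, and it assumes the aten::is_grad_enabled builtin added in this PR):

    import torch

    @torch.jit.script
    def grad_enabled() -> bool:
        # Inside TorchScript, torch.is_grad_enabled resolves to the
        # aten::is_grad_enabled builtin rather than to a Python function.
        return torch.is_grad_enabled()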

@eellison (Contributor)

> making torch.jit.no_grad() resolve to a scripted version of the corresponding Python class in code that is JITed

You can address this here: https://github.com/pytorch/pytorch/blob/master/torch/jit/_builtins.py#L118. That's how we resolve torch.unique and other ops bound in torch/functional.py (see _is_special_functional_bound_op).

@SplitInfinity SplitInfinity requested a review from albanD as a code owner July 14, 2020 23:55
SplitInfinity pushed a commit that referenced this pull request Jul 14, 2020
ghstack-source-id: 7577a1c
Pull Request resolved: #41371
@SplitInfinity (Author)

Okay, uploaded a new version after offline discussion with @eellison. This new version:

  • gets rid of the new NoGrad class I added and uses torch.no_grad directly
  • makes some changes to torch.no_grad to make it scriptable; most of these seem harmless and I will add a reviewer for those changes once it is ready to review

The one problem left to solve is how to let the implementation of no_grad use inheritance while ignoring that inheritance for JIT purposes. Ignoring a base class is usually not a good idea, but in this specific case, no_grad extends _DecoratorContextManager, which allows it to be used as a decorator and on generator functions. These use cases are out of scope for JIT, so omitting the functionality of the base class in JIT would not be a problem.
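
The end state this enables looks roughly like the following (illustrative function, not the PR's test code):

    import torch

    @torch.jit.script
    def double_no_grad(x: torch.Tensor):
        with torch.no_grad():
            y = x * 2  # gradient recording is disabled in this block
        return y, y.requires_grad

    out, rg = double_no_grad(torch.randn(5, requires_grad=True))
    assert rg is False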

@SplitInfinity SplitInfinity changed the title [WIP][JIT] Add JIT support for torch.no_grad [JIT] Add JIT support for torch.no_grad Jul 15, 2020
SplitInfinity pushed a commit that referenced this pull request Jul 15, 2020
ghstack-source-id: ce0a272
Pull Request resolved: #41371
(diff context)

    >>> z.requires_grad
    False
    """
    def __init__(self):
Author


This might be useful to refactor out into a separate class.

@SplitInfinity SplitInfinity requested a review from colesbury August 3, 2020 18:01

(diff context)

        return y, y.requires_grad

    test_input = torch.randn(5)
Contributor


Doesn't this test pass even if you don't have a no_grad block?

Contributor


I don't know what the scope of autograd support in TorchScript generally is, but I think you'd want a subset of the no_grad tests from test_autograd, e.g.:

    def test_no_grad(self):
        x = torch.ones(5, 5, requires_grad=True)
        y = torch.ones(5, 5) * 4
        with torch.no_grad():
            w = x + y

        @torch.no_grad()
        def adder(x, y):
            return x + y

        z = adder(x, y)

        self.assertFalse(w.requires_grad)
        self.assertRaises(RuntimeError, lambda: w.backward(torch.ones(5, 5)))
        self.assertIsNone(w.grad_fn)
        self.assertFalse(z.requires_grad)
        self.assertRaises(RuntimeError, lambda: z.backward(torch.ones(5, 5)))
        self.assertIsNone(z.grad_fn)

        # test nested decorator and with-statement on no_grad
        with torch.no_grad():
            self.assertFalse(torch.is_grad_enabled())
            w = adder(x, y)
            self.assertFalse(torch.is_grad_enabled())
    def test_set_grad_generator_functions(self):
        @torch.no_grad()
        def gen_no_grad():
            for i in range(10):
                self.assertEqual(torch.is_grad_enabled(), False)
                yield i

        with torch.enable_grad():
            for _ in gen_no_grad():
                self.assertEqual(torch.is_grad_enabled(), True)

        @torch.enable_grad()
        def gen_enable_grad():
            for i in range(10):
                self.assertEqual(torch.is_grad_enabled(), True)
                yield i

        with torch.no_grad():
            for _ in gen_enable_grad():
                self.assertEqual(torch.is_grad_enabled(), False)
    def test_no_grad_python_function(self):
        """Python Functions should respect grad mode."""
        x = torch.ones(5, 5, requires_grad=True)

        class MyOp(Function):
            @staticmethod
            def forward(self, x):
                return x + 1

            @staticmethod
            def backward(self, dy):
                return dy

        with torch.no_grad():
            y = MyOp.apply(x)
        self.assertFalse(y.requires_grad)

pytorch/test/test_autograd.py, lines 2026 to 2059 in 34025eb:

    def test_no_grad_assignment(self):
        x = torch.randn(5, 5, requires_grad=True)
        y = torch.randn(5)
        with torch.no_grad():
            x[0] = y

        self.assertTrue(x.requires_grad)
        self.assertIsNone(x.grad_fn)

    def test_no_grad_modifies_version(self):
        x = torch.randn(5, requires_grad=True)
        y = torch.randn(5, requires_grad=True)
        z = (x * y).sum()
        with torch.no_grad():
            x *= 2
        self.assertRaisesRegex(RuntimeError, 'modified by an inplace operation',
                               lambda: z.backward())

    def test_no_grad_input(self):
        class MyFunction(Function):
            @staticmethod
            def forward(self, x):
                return x

            @staticmethod
            def backward(self, grad_output):
                return grad_output

        x = torch.randn(5, requires_grad=True)
        with torch.no_grad():
            y = MyFunction.apply(x)

        self.assertTrue(x.requires_grad)
        self.assertIsNone(y.grad_fn)

pytorch/test/test_autograd.py, lines 4210 to 4234 in 34025eb:

    def test_custom_function_return_view_in_nograd(self):
        class Alias(Function):
            @staticmethod
            def forward(ctx, x):
                return x[:]

            @staticmethod
            def backward(ctx, gx):
                return gx

        inp = torch.rand(2, requires_grad=True)

        with torch.no_grad():
            output = Alias.apply(inp)

        with torch.no_grad():
            expected_output = inp[:]

        # Calling the custom function should operate as if we called an equivalent op
        self.assertEqual(output.requires_grad, expected_output.requires_grad)

        # Check that in-place modification on view throws
        leaf_grad_err = "A view was created in no_grad mode and is being modified inplace"
        with self.assertRaisesRegex(RuntimeError, leaf_grad_err):
            output.zero_()

Author


Cool, I adapted a handful of these for JIT. I'll consult my JIT reviewers on whether what I added is enough.
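
For example, a JIT adaptation of the assignment test above might look like this (illustrative, not the PR's exact test code):

    import torch

    @torch.jit.script
    def no_grad_assignment(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        with torch.no_grad():
            x[0] = y  # in-place write to a leaf is allowed under no_grad
        return x

    x = torch.randn(5, 5, requires_grad=True)
    out = no_grad_assignment(x, torch.randn(5))
    assert out.requires_grad and out.grad_fn is None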

SplitInfinity pushed a commit that referenced this pull request Aug 3, 2020
ghstack-source-id: d2dc1d4
Pull Request resolved: #41371
@SplitInfinity (Author)

Ping @albanD @eellison

@eellison eellison (Contributor) left a comment


Looks good, a few small comments.

SplitInfinity pushed a commit that referenced this pull request Aug 18, 2020
ghstack-source-id: 8a88e0a
Pull Request resolved: #41371
@albanD albanD (Collaborator) left a comment


LGTM. You can rebase on top of master and it should fix the Windows CI.

SplitInfinity pushed a commit that referenced this pull request Aug 18, 2020
ghstack-source-id: 921ba90
Pull Request resolved: #41371
SplitInfinity pushed a commit that referenced this pull request Aug 27, 2020
ghstack-source-id: 73cf332
Pull Request resolved: #41371
@facebook-github-bot (Contributor)

@SplitInfinity merged this pull request in 87d7c36.
