122 changes: 75 additions & 47 deletions test/test_utils.py
@@ -51,8 +51,8 @@ def _check_checkpoint_sequential(
module_lists_to_compare,
num_chunks,
input,
use_reentrant,
):

# not checkpointed
out = model(input)
out_not_checkpointed = out.detach().clone()
@@ -69,7 +69,9 @@ def _check_checkpoint_sequential(
detached.requires_grad = True

# pass list of modules to checkpoint
out = checkpoint_sequential(model_to_compare, num_chunks, detached)
out = checkpoint_sequential(
model_to_compare, num_chunks, detached, use_reentrant=use_reentrant
)
out_checkpointed = out.detach().clone()
model.zero_grad()
out.sum().backward()
@@ -96,21 +98,26 @@ def __init__(self):

def forward(self, input_var):
self.counter += 1
return input_var
# For reentrant, need to have autograd actually
# pack a tensor to trigger recomp
ret = input_var * 2
return ret

# checkpointed
modules = [Net() for _ in range(10)]
for m in modules:
self.assertEqual(m.counter, 0)
input_var = torch.randn(3, 4, requires_grad=True)
out = checkpoint_sequential(modules, 2, input_var)
for m in modules:
self.assertEqual(m.counter, 1)
out.sum().backward()
for m in modules[:(len(modules) // 2)]:
self.assertEqual(m.counter, 2)
for m in modules[(len(modules) // 2):]:
self.assertEqual(m.counter, 1)
for use_reentrant in [True, False]:
with self.subTest(use_reentrant=use_reentrant):
modules = [Net() for _ in range(10)]
for m in modules:
self.assertEqual(m.counter, 0)
input_var = torch.randn(3, 4, requires_grad=True)
out = checkpoint_sequential(modules, 2, input_var, use_reentrant=use_reentrant)
for m in modules:
self.assertEqual(m.counter, 1)
out.sum().backward()
for m in modules[:(len(modules) // 2)]:
self.assertEqual(m.counter, 2)
for m in modules[(len(modules) // 2):]:
self.assertEqual(m.counter, 1)

def test_checkpoint_valid(self):
model = nn.Sequential(
@@ -132,27 +139,42 @@ def test_checkpoint_valid(self):
torch.autograd.grad(
outputs=[out], grad_outputs=[torch.ones(1, 5)], inputs=[input_var], create_graph=True
)
# works with use_reentrant=False, and grads are the same
out = model(input_var)
grads_no_checkpoint = torch.autograd.grad(
outputs=[out], grad_outputs=[torch.ones(1, 5)], inputs=[input_var], create_graph=True,
)
out_checkpoint = checkpoint_sequential(modules, chunks, input_var, use_reentrant=False)
# check outputs are the same
self.assertEqual(out_checkpoint, out)
grads_checkpoint = torch.autograd.grad(
outputs=[out_checkpoint], grad_outputs=[torch.ones(1, 5)], inputs=[input_var], create_graph=True,
)
self.assertEqual(grads_no_checkpoint, grads_checkpoint)

def test_checkpoint(self):
model = nn.Sequential(
nn.Linear(100, 50),
nn.ReLU(),
nn.Linear(50, 20),
nn.ReLU(),
nn.Linear(20, 5),
nn.ReLU()
)
for use_reentrant in [True, False]:
with self.subTest(use_reentrant=use_reentrant):
model = nn.Sequential(
nn.Linear(100, 50),
nn.ReLU(),
nn.Linear(50, 20),
nn.ReLU(),
nn.Linear(20, 5),
nn.ReLU()
)

# Compare uncheckpointed model with its checkpointed counterparts
# In addition to running checkpoint_sequential on the nn.Sequential
# instance, we also run the function on the list of functions within
# the module.
self._check_checkpoint_sequential(
model,
[list(model.children()), model],
2,
torch.randn(1, 100, requires_grad=True)
)
# Compare uncheckpointed model with its checkpointed counterparts
# In addition to running checkpoint_sequential on the nn.Sequential
# instance, we also run the function on the list of functions within
# the module.
self._check_checkpoint_sequential(
model,
[list(model.children()), model],
2,
torch.randn(1, 100, requires_grad=True),
use_reentrant=use_reentrant,
)

def test_checkpoint_module_list(self):
class ModuleListNet(nn.Module):
@@ -173,15 +195,18 @@ def forward(self, input):
input = layer(input)
return input

model = ModuleListNet()

# Compare uncheckpointed model with its checkpointed counterparts.
self._check_checkpoint_sequential(
model,
[list(model.module_list.children()), model.module_list],
2,
torch.randn(1, 100, requires_grad=True),
)
for use_reentrant in [True, False]:
with self.subTest(use_reentrant=use_reentrant):
model = ModuleListNet()

# Compare uncheckpointed model with its checkpointed counterparts.
self._check_checkpoint_sequential(
model,
[list(model.module_list.children()), model.module_list],
2,
torch.randn(1, 100, requires_grad=True),
use_reentrant=use_reentrant,
)

def test_checkpoint_sequential_deprecated_multiple_args(self):
class Two(nn.Module):
@@ -192,18 +217,21 @@ def forward(self, a, b):
a = torch.randn(1, 100, requires_grad=True)
b = torch.randn(1, 100, requires_grad=True)

with self.assertRaises(TypeError):
checkpoint_sequential(model, 1, a, b) # type: ignore[call-arg]
for use_reentrant in [True, False]:
with self.subTest(use_reentrant=use_reentrant):
with self.assertRaises(TypeError):
checkpoint_sequential(model, 1, a, b) # type: ignore[call-arg]

def test_checkpoint_sequential_deprecated_no_args(self):
class Noop(nn.Module):
def forward(self):
pass

model = nn.Sequential(Noop())

with self.assertRaises(TypeError):
checkpoint_sequential(model, 1) # type: ignore[call-arg]
for use_reentrant in [True, False]:
with self.subTest(use_reentrant=use_reentrant):
with self.assertRaises(TypeError):
checkpoint_sequential(model, 1) # type: ignore[call-arg]

def test_checkpoint_rng_cpu(self):
for _ in range(5):
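Taken together, the updated tests run each checkpointing scenario under both use_reentrant=True and use_reentrant=False. A condensed standalone sketch of that comparison follows; the model, shapes, and variable names are illustrative assumptions, not copied from the test file.

# Standalone sketch of the comparison the tests above perform (illustrative).
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

model = nn.Sequential(nn.Linear(100, 50), nn.ReLU(), nn.Linear(50, 20), nn.ReLU())
x = torch.randn(1, 100, requires_grad=True)

# Reference: ordinary forward/backward without checkpointing.
ref_out = model(x)
ref_out.sum().backward()
ref_grad = x.grad.detach().clone()

for use_reentrant in (True, False):
    x2 = x.detach().clone().requires_grad_(True)
    model.zero_grad()
    out = checkpoint_sequential(model, 2, x2, use_reentrant=use_reentrant)
    out.sum().backward()
    assert torch.allclose(out, ref_out)        # same forward result
    assert torch.allclose(x2.grad, ref_grad)   # same input gradient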
20 changes: 16 additions & 4 deletions torch/utils/checkpoint.py
@@ -226,7 +226,7 @@ def checkpoint(function, *args, use_reentrant: bool = True, **kwargs):
preserve_rng_state(bool, optional): Omit stashing and restoring
the RNG state during each checkpoint.
Default: ``True``
use_reentrant(bool, optional): Use checkpointing
use_reentrant(bool, optional): Use (the default) checkpointing
Collaborator: Not sure what that means? There is a ``Default: True`` in this section.

Contributor Author: I added this to indicate that reentrant is the default / original way of checkpointing. Usually the user would not be concerned with the implementation detail of checkpointing, so clarifying which flag enables the current "default" seems valuable to me, but I can take this out if you prefer.

Collaborator: I think this is already done by the ``Default:`` specification (note that we try to use this form consistently across the docs), so the user expects to find that information there already.

Contributor Author: Sounds good, will remove.

implementation that requires re-entrant autograd.
If ``use_reentrant=False`` is specified, ``checkpoint`` will use an
implementation that does not require re-entrant autograd. This
@@ -256,7 +256,7 @@ def checkpoint(function, *args, use_reentrant: bool = True, **kwargs):
)


def checkpoint_sequential(functions, segments, input, **kwargs):
def checkpoint_sequential(functions, segments, input, use_reentrant=True, **kwargs):
r"""A helper function for checkpointing sequential models.

Sequential models execute a list of modules/functions in order
@@ -290,6 +290,14 @@ def checkpoint_sequential(functions, segments, input, **kwargs):
preserve_rng_state(bool, optional): Omit stashing and restoring
the RNG state during each checkpoint.
Default: ``True``
use_reentrant(bool, optional): Use (the default) checkpointing
implementation that requires re-entrant autograd.
If ``use_reentrant=False`` is specified, ``checkpoint`` will use an
implementation that does not require re-entrant autograd. This
allows ``checkpoint`` to support additional functionality, such as
working as expected with ``torch.autograd.grad`` and support for
keyword arguments input into the checkpointed function.
Default: ``True``

Returns:
Output of running :attr:`functions` sequentially on :attr:`*inputs`
@@ -319,8 +327,12 @@ def forward(input):
end = -1
for start in range(0, segment_size * (segments - 1), segment_size):
end = start + segment_size - 1
input = checkpoint(run_function(start, end, functions), input,
preserve_rng_state=preserve)
input = checkpoint(
run_function(start, end, functions),
input,
use_reentrant=use_reentrant,
preserve_rng_state=preserve
)
return run_function(end + 1, len(functions) - 1, functions)(input)

def _checkpoint_without_reentrant(function, preserve_rng_state=True, *args, **kwargs):
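For reference, a minimal usage sketch (not part of the diff) of what the new parameter enables: with use_reentrant=False, checkpoint_sequential composes with torch.autograd.grad, as the updated docstring describes. The module list and shapes below are assumptions for illustration.

# Minimal usage sketch (illustrative, not from the PR): the non-reentrant
# path lets checkpoint_sequential work with torch.autograd.grad.
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

modules = [nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 10)]
x = torch.randn(2, 10, requires_grad=True)

out = checkpoint_sequential(modules, 2, x, use_reentrant=False)
(grad,) = torch.autograd.grad(out.sum(), x, create_graph=True)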