
Commit 73c7752

Update on "Test FSDP with submodule non-reentrant checkpointing"
When combining FSDP with reentrant checkpointing, the post-backward hook might run twice and then hit [this error](https://github.com/pytorch/pytorch/blob/e20ec44544c17d6d3d411f88b870e05043bda731/torch/distributed/fsdp/_runtime_utils.py#L487). This happens because reentrant backward uses nested autograd GraphTasks: the inner GraphTask is not aware of the outer one, so it flushes pending `AccumulateGrad` invocations on exit, which in turn triggers the post-backward hooks registered by FSDP. Later, the outer GraphTask triggers them again, leading to the error above. PR #89791 relaxes the FSDP training state check, but we still occasionally run into grad value check failures. Therefore, this PR only lands the non-reentrant test; the reentrant test can be enabled once the accuracy issues are addressed.

[ghstack-poisoned]
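For reference, here is a minimal single-process sketch of the gradient-parity check the new test performs, using non-reentrant activation checkpointing but without FSDP or a process group; the wrapper class, module names, and tensor shapes are illustrative, not taken from the PR:

```python
# Hypothetical standalone sketch: a checkpointed copy of a model must produce
# the same parameter gradients as the original model.
from copy import deepcopy

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class CheckpointedBlock(nn.Module):
    """Runs the wrapped module under activation checkpointing."""

    def __init__(self, module: nn.Module, use_reentrant: bool = False):
        super().__init__()
        self.module = module
        self.use_reentrant = use_reentrant

    def forward(self, x):
        # Activations inside self.module are recomputed during backward
        # instead of being stored during the forward pass.
        return checkpoint(self.module, x, use_reentrant=self.use_reentrant)


model = nn.Sequential(nn.Linear(100, 100), nn.ReLU(), nn.Linear(100, 100))
model_ac = deepcopy(model)
model_ac[2] = CheckpointedBlock(model_ac[2], use_reentrant=False)

x = torch.randn(2, 100)
model(x).sum().backward()
model_ac(x).sum().backward()

# Checkpointing changes memory behavior, not the math: gradients must match.
for p1, p2 in zip(model.parameters(), model_ac.parameters()):
    assert torch.allclose(p1.grad, p2.grad)
```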
1 parent 6319739 commit 73c7752

File tree

1 file changed: 26 additions, 53 deletions


test/distributed/fsdp/test_fsdp_checkpoint.py

Lines changed: 26 additions & 53 deletions
@@ -279,6 +279,9 @@ def test_basic_checkpoint_end_to_end(
         dist.barrier()
 
 
+instantiate_parametrized_tests(TestFSDPCheckpoint)
+
+
 class CheckpointModule(nn.Module):
     def __init__(self, checkpoint: bool = False, use_reentrant: bool = True):
         super().__init__()
@@ -300,64 +303,56 @@ def __init__(self, checkpoint: bool = False, use_reentrant: bool = True):
         self.l1 = nn.Linear(100, 100)
         self.s1 = CheckpointModule(checkpoint, use_reentrant)
         self.s2 = CheckpointModule(checkpoint, use_reentrant)
+        self.relu = nn.ReLU()
         self.l2 = nn.Linear(100, 100)
 
     def forward(self, x):
-        return self.l2(self.s2(self.s1(self.l1(x))))
+        return self.l2(self.relu(self.s2(self.s1(self.l1(x)))))
 
 
 class TestModel(nn.Module):
     def __init__(self, checkpoint: bool = False, use_reentrant: bool = True):
         super().__init__()
         self.l1 = nn.Linear(100, 100)
-        self.m1 = ModelWithCheckpointSubmodule(checkpoint, use_reentrant)
-        self.m2 = ModelWithCheckpointSubmodule(checkpoint, use_reentrant)
+        self.relu = nn.ReLU()
+        self.checkpoint1 = ModelWithCheckpointSubmodule(checkpoint, use_reentrant)
+        self.checkpoint2 = ModelWithCheckpointSubmodule(checkpoint, use_reentrant)
        self.l2 = nn.Linear(100, 100)
 
     def forward(self, x):
-        return self.l2(self.m2(self.m1(self.l1(x))))
+        return self.l2(self.relu(self.checkpoint2(self.checkpoint1(self.l1(x)))))
 
 
 class TestFSDPCheckpointSubmodule(FSDPTest):
 
+    # TODO: grad value checks occasionally fails when use_reentrant = True
     @skip_if_lt_x_gpu(2)
-    def test_checkpoint_submodule_nonreentrant(self):
-        model = TestModel().cuda()
+    @parametrize("use_reentrant", [False])
+    def test_checkpoint_submodule(self, use_reentrant: bool):
+        model = TestModel(use_reentrant=use_reentrant).cuda()
         model_ac = deepcopy(model)
 
         for _, m in model_ac.named_modules():
             if isinstance(m, CheckpointModule):
                 m.checkpoint = True
-                m.use_reentrant = False
 
-        self.assertTrue(model_ac.m1.s1.checkpoint)
-        self.assertTrue(model_ac.m2.s2.checkpoint)
+        self.assertTrue(model_ac.checkpoint1.s1.checkpoint)
+        self.assertTrue(model_ac.checkpoint2.s2.checkpoint)
+
+        fsdp_kwargs = {
+            "device_id": torch.cuda.current_device(),
+            "sharding_strategy": ShardingStrategy.NO_SHARD,
+        }
 
         # Wrap no checkpointing model submodules with FSDP
-        model.m1 = FSDP(
-            module=model.m1,
-            device_id=torch.cuda.current_device(),
-            sharding_strategy=ShardingStrategy.NO_SHARD,
-        )
-        model.m2 = FSDP(
-            module=model.m2,
-            device_id=torch.cuda.current_device(),
-            sharding_strategy=ShardingStrategy.NO_SHARD,
-        )
+        model.m1 = FSDP(module=model.checkpoint1, **fsdp_kwargs)
+        model.m2 = FSDP(module=model.checkpoint2, **fsdp_kwargs)
 
         # Wrap checkpointing model submodules with FSDP
-        model_ac.m1 = FSDP(
-            module=model_ac.m1,
-            device_id=torch.cuda.current_device(),
-            sharding_strategy=ShardingStrategy.NO_SHARD,
-        )
-        model_ac.m2 = FSDP(
-            module=model_ac.m2,
-            device_id=torch.cuda.current_device(),
-            sharding_strategy=ShardingStrategy.NO_SHARD,
-        )
+        model_ac.m1 = FSDP(module=model_ac.checkpoint1, **fsdp_kwargs)
+        model_ac.m2 = FSDP(module=model_ac.checkpoint2, **fsdp_kwargs)
 
-        x = torch.randn(2, 100).cuda()
+        x = torch.randn(2, 100, device="cuda")
 
         model(x).sum().backward()
         model_ac(x).sum().backward()
@@ -366,30 +361,8 @@ def test_checkpoint_submodule_nonreentrant(self):
             self.assertTrue(p1.grad.allclose(p2.grad))
 
 
-    @skip_if_lt_x_gpu(2)
-    def test_checkpoint_submodule_reentrant(self):
-        model = TestModel(checkpoint=True, use_reentrant=True).cuda()
+instantiate_parametrized_tests(TestFSDPCheckpointSubmodule)
 
-        model.m1 = FSDP(
-            module=model.m1,
-            device_id=torch.cuda.current_device(),
-            sharding_strategy=ShardingStrategy.NO_SHARD,
-        )
-        model.m2 = FSDP(
-            module=model.m2,
-            device_id=torch.cuda.current_device(),
-            sharding_strategy=ShardingStrategy.NO_SHARD,
-        )
-
-        x = torch.randn(2, 100).cuda()
-
-        with self.assertRaisesRegex(
-            AssertionError, "but got HandleTrainingState.BACKWARD_POST"
-        ):
-            model(x).sum().backward()
-
-
-instantiate_parametrized_tests(TestFSDPCheckpoint)
 
 if __name__ == "__main__":
     run_tests()
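The diff relies on PyTorch's internal parametrized-test helpers. Below is a small hypothetical sketch of how that machinery behaves, assuming the usual internal import path; the exact names of the generated test methods are an assumption:

```python
# Hypothetical minimal example of the parametrization helpers used above.
from torch.testing._internal.common_utils import (
    TestCase,
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)


class ToyTest(TestCase):
    @parametrize("use_reentrant", [False])
    def test_flag(self, use_reentrant: bool):
        # One concrete test method is generated per value in the list, so
        # re-enabling the reentrant variant later is a one-line change:
        # [False] -> [False, True].
        self.assertFalse(use_reentrant)


# Without this call the parametrized template is not expanded into
# runnable test methods.
instantiate_parametrized_tests(ToyTest)

if __name__ == "__main__":
    run_tests()
```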
