@@ -12,6 +12,7 @@
 from itertools import product
 from functools import partial
 from collections import OrderedDict
+from tempfile import NamedTemporaryFile

 import torch

@@ -37,7 +38,7 @@
     download_file, get_function_arglist, load_tests, skipIfMps,\
     TEST_WITH_UBSAN, IS_PPC, \
     parametrize as parametrize_test, subtest, instantiate_parametrized_tests, \
-    skipIfTorchDynamo
+    skipIfTorchDynamo, IS_WINDOWS
 from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU, TEST_CUDNN, TEST_CUDNN_VERSION
 from torch.testing._internal.common_nn import NNTestCase, NewModuleTest, CriterionTest, \
     module_tests, criterion_tests, loss_reference_fns, _create_basic_net, \
@@ -2450,6 +2451,60 @@ def hook_fn(module, state_dict, prefix, local_metadata, strict, missing_keys, un |
         model[0][0]._register_load_state_dict_pre_hook(hook_fn, with_module=True)
         model.load_state_dict(model.state_dict(), strict=True)

+    @unittest.skipIf(IS_WINDOWS, "Tempfile permission issue on windows")
+    def test_register_state_dict_pre_hook_backward_compat(self):
+        called = False
+
+        def my_state_dict_pre_hook(*args, **kwargs):
+            nonlocal called
+            called = True
+
+        m = nn.Linear(1, 1)
+        self.assertTrue(hasattr(m, '_state_dict_pre_hooks'))
+        delattr(m, '_state_dict_pre_hooks')
+        # Save and load, ensure we can still call state_dict
+        # without running into issues.
+        with NamedTemporaryFile() as f:
+            # Note that torch.save / torch.load is not recommended
+            # to save / load modules.
+            torch.save(m, f.name)
+            m = torch.load(f.name)
+
+        # Ensure we can run state_dict without issues
+        _ = m.state_dict()
+        self.assertFalse(called)
+        m.register_state_dict_pre_hook(my_state_dict_pre_hook)
+        _ = m.state_dict()
+        self.assertTrue(called)
+
+    def test_register_state_dict_pre_hook(self):
+        _state_dict_prefix = "foo."
+        state_dict_pre_hook_count = 0
+
+        class MyModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.a = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3), nn.Linear(3, 3))
+
+            def forward(self, x):
+                return self.a(x)
+
+        def my_state_dict_pre_hook(module, prefix, keep_vars):
+            nonlocal keep_var_setting
+            self.assertEqual(keep_vars, keep_var_setting)
+            nonlocal state_dict_pre_hook_count
+            state_dict_pre_hook_count += 1
+            self.assertTrue(prefix.startswith(_state_dict_prefix))
+
+        mod = MyModule()
+        mod.register_state_dict_pre_hook(my_state_dict_pre_hook)
+        # Test to ensure submodules run the hook as well.
+        mod.a.register_state_dict_pre_hook(my_state_dict_pre_hook)
+        for keep_var_setting in [True, False]:
+            _ = mod.state_dict(prefix=_state_dict_prefix, keep_vars=keep_var_setting)
+            self.assertEqual(2, state_dict_pre_hook_count)
+            state_dict_pre_hook_count = 0
+
     @skipIfTorchDynamo("TorchDynamo fails here for unknown reasons")
     def test_load_state_dict_ref_cycle(self):
         # load_state_dict shouldn't cause a reference cycle involving Tensors
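For context, a minimal usage sketch of the API these tests exercise, based only on what the diff shows: the hook is registered with `register_state_dict_pre_hook`, receives `(module, prefix, keep_vars)`, and runs before `state_dict()` collects entries. The hook name `log_state_dict_call` is hypothetical, chosen for illustration.

    import torch
    import torch.nn as nn

    def log_state_dict_call(module, prefix, keep_vars):
        # Runs just before state_dict() assembles this module's entries;
        # `prefix` is the key prefix that will be applied to them.
        print(f"state_dict requested for {type(module).__name__} with prefix '{prefix}'")

    m = nn.Linear(3, 3)
    m.register_state_dict_pre_hook(log_state_dict_call)
    sd = m.state_dict(prefix="foo.")  # hook fires once; prefix starts with "foo."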