-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Expand file tree
/
Copy pathrepro.py
More file actions
172 lines (144 loc) · 5.06 KB
/
repro.py
File metadata and controls
172 lines (144 loc) · 5.06 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
"""Reproducibility tools
Author:
* Artem Ploujnikov 2025
"""
import re
import torch
import speechbrain as sb
from speechbrain.utils.logger import get_logger
logger = get_logger(__name__)
@sb.utils.checkpoints.register_checkpoint_hooks
class SaveableGenerator:
    """A wrapper that can be used to store the state of
    the random number generator in a checkpoint. It helps
    with reproducibility in long-running experiments.

    Currently, this only supports CPU and CUDA devices
    natively. If you need training on other architectures,
    consider implementing a custom generator.

    Running it on an unsupported device not using the Torch
    generator interface will simply fail to restore the
    state but will not cause an error.

    Typical in hparams:
    ```yaml
    generator: !new:model.custom_model.SaveableGenerator # <-- Include the wrapper
    checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
        checkpoints_dir: !ref <save_folder>
        recoverables:
            model: !ref <model>
            lr_scheduler: !ref <lr_annealing>
            counter: !ref <epoch_counter>
            generator: !ref <generator>
    ```

    Arguments
    ---------
    generators : Mapping[str, Generator], optional
        A dictionary of named generator objects. If not provided,
        the default generators for CPU and CUDA will be used

    Examples
    --------
    >>> import torch
    >>> from speechbrain.utils.repro import SaveableGenerator
    >>> from speechbrain.utils.checkpoints import Checkpointer
    >>> gena, genb = [torch.Generator().manual_seed(x) for x in [42, 24]]
    >>> saveable_gen = SaveableGenerator(
    ...     generators={"a": gena, "b": genb}
    ... )
    >>> tempdir = getfixture('tmpdir')
    >>> checkpointer = Checkpointer(
    ...     tempdir,
    ...     recoverables={"generator": saveable_gen})
    >>> torch.randint(0, 10, (1,), generator=gena).item()
    2
    >>> torch.randint(0, 10, (1,), generator=genb).item()
    4
    >>> _ = checkpointer.save_checkpoint()
    >>> torch.randint(0, 10, (1,), generator=gena).item()
    7
    >>> torch.randint(0, 10, (1,), generator=genb).item()
    5
    >>> _ = checkpointer.recover_if_possible()
    >>> torch.randint(0, 10, (1,), generator=gena).item()
    7
    >>> torch.randint(0, 10, (1,), generator=genb).item()
    5
    """

    def __init__(self, generators=None):
        if generators is None:
            # No explicit generators: track the global CPU generator plus
            # one wrapper per visible CUDA device.
            generators = {"default": torch.default_generator}
            if torch.cuda.is_available():
                for idx in range(torch.cuda.device_count()):
                    generators[f"cuda:{idx}"] = _CudaDefaultGeneratorWrapper(
                        idx
                    )
        self.generators = generators

    @sb.utils.checkpoints.mark_as_saver
    def save(self, path):
        """Save the generator state for later recovery

        Arguments
        ---------
        path : str, Path
            Where to save. Will overwrite.
        """
        save_dict = {
            key: generator.get_state()
            for key, generator in self.generators.items()
        }
        torch.save(save_dict, path)

    @sb.utils.checkpoints.mark_as_loader
    def load(self, path, end_of_epoch):
        """
        Loads the generator state if the corresponding devices are
        present

        Arguments
        ---------
        path : str, Path
            Where to load from.
        end_of_epoch : bool
            Whether the checkpoint was end-of-epoch or not.
        """
        del end_of_epoch  # Unused; required by the checkpointer loader API.
        # RNG states are plain CPU byte tensors; map_location="cpu" makes the
        # load robust on machines without the devices the checkpoint came from.
        save_dict = torch.load(path, map_location="cpu")
        for key, state in save_dict.items():
            match = re.match(r"cuda:(\d+)", key)
            if match:
                # A CUDA state that cannot be restored on this machine is
                # skipped with a warning rather than raising, per the class
                # contract.
                if not torch.cuda.is_available():
                    logger.warning(
                        "Unable to restore RNG for %s, CUDA unavailable", key
                    )
                    continue
                idx = int(match.group(1))
                if idx >= torch.cuda.device_count():
                    logger.warning(
                        "Unable to restore RNG for %s, device not found", key
                    )
                    continue
            generator = self.generators.get(key)
            if generator is None:
                # Fallback for checkpoints whose "default" entry has no
                # registered generator; previously this branch clobbered
                # torch.default_generator even when a custom generator was
                # registered under "default", and any other unknown key
                # raised a KeyError.
                if key == "default":
                    torch.default_generator.set_state(state)
                else:
                    logger.warning(
                        "Unable to restore RNG for %s, no matching generator",
                        key,
                    )
                continue
            generator.set_state(state)
class _CudaDefaultGeneratorWrapper:
"""A generator wrapper for default generators - because torch no longer
exposes default_generators
This class should not be used outside of SaveableGenerator
Arguments
---------
device : int|str
The device index or identifier"""
def __init__(self, device):
self.device = device
def get_state(self):
"""Returns the generator state
Returns
-------
result : torch.Tensor
The generator state
"""
return torch.cuda.get_rng_state(self.device)
def set_state(self, new_state):
""" "Sets the generator state
Arguments
---------
new_state : dict
The new state
"""
torch.cuda.set_rng_state(new_state, self.device)