-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Expand file tree
/
Copy pathparameter_transfer.py
More file actions
350 lines (299 loc) · 11.7 KB
/
parameter_transfer.py
File metadata and controls
350 lines (299 loc) · 11.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
"""Convenience functions for the simplest parameter transfer cases.
Use `speechbrain.utils.checkpoints.Checkpointer` to find a checkpoint
and the path to the parameter file.
Authors
* Aku Rouhe 2020
* Andreas Nautsch 2023
* Adel Moumen 2023
"""
import pathlib
import platform
import warnings
from speechbrain.utils.checkpoints import (
DEFAULT_LOAD_HOOKS,
DEFAULT_TRANSFER_HOOKS,
PARAMFILE_EXT,
get_default_hook,
)
from speechbrain.utils.fetching import (
FetchConfig,
FetchSource,
LocalStrategy,
fetch,
)
from speechbrain.utils.logger import get_logger
# Module-level logger, created via ``get_logger`` and named after this module.
logger = get_logger(__name__)
class Pretrainer:
    """Orchestrates pretraining

    First optionally collects files from some source (local directory,
    HuggingFace repository, base URL), into the `collect_in` directory, if
    specified.

    Then, calls load hooks for each of those files.

    Arguments
    ---------
    collect_in : str or Path, optional
        Path to directory where the files are to be collected.
        If `None`, then files will be referred to from cache or directly, if
        possible (URLs will fail). There will not be a centralized target
        directory with all the files.
    loadables : mapping
        Mapping from loadable key to object. This connects the keys to
        the actual object instances.
    paths : mapping
        Mapping from loadable key to filepath. The last part
        of the path is treated as file name, the rest of it
        is treated as a "source" which can be either a directory
        path or a magic source like Huggingface hub ID.
        e.g. sb/asr-crdnn-libri/lm.ckpt
        -> source=sb/asr-crdnn-libri, file=lm.ckpt
        Note that when collecting, you can specify a default source,
        which is used for all loadables that don't have a path specified.
    custom_hooks : mapping
        Mapping from loadable key to parameter transfer hook function. If you
        want to use a custom loading function, specify it here.
    conditions: mapping
        An optional mapping from loadable keys to condition values,
        useful for loading certain elements only if a flag is turned on
    """

    def __init__(
        self,
        collect_in=None,
        loadables=None,
        paths=None,
        custom_hooks=None,
        conditions=None,
    ):
        self.loadables = {}
        self.set_collect_in(collect_in)
        if loadables is not None:
            self.add_loadables(loadables)
        self.paths = {}
        if paths is not None:
            self.add_paths(paths)
        self.custom_hooks = {}
        if custom_hooks is not None:
            self.add_custom_hooks(custom_hooks)
        self.conditions = {}
        if conditions is not None:
            self.add_conditions(conditions)
        # Loadable keys whose entry in `self.paths` has been rewritten by
        # `collect_files` to point at the fetched local file.
        self.is_local = []

    def set_collect_in(self, path):
        """Change the collecting path.

        Arguments
        ---------
        path : str or Path or None
            New collection directory; `None` disables the central
            collection directory.
        """
        self.collect_in = pathlib.Path(path) if path is not None else None

    def add_loadables(self, loadables):
        """Update the loadables dict from the given mapping.

        Arguments
        ---------
        loadables : mapping
            Mapping from loadable key to object
        """
        self.loadables.update(loadables)

    def add_paths(self, paths):
        """Update the paths for different loadables.

        When collecting parameters, paths here are preferred. Note that when
        collecting, you can specify a default source, which is used for all
        loadables that don't have a path specified.

        Arguments
        ---------
        paths : mapping
            Mapping from loadable key to filepath. The last part
            of the path is treated as file name, the rest of it
            is treated as a "source" which can be either a directory
            path or a magic source like Huggingface hub ID.
            e.g. sb/asr-crdnn-libri/lm.ckpt
            -> source=sb/asr-crdnn-libri, file=lm.ckpt
        """
        self.paths.update(paths)

    def add_custom_hooks(self, custom_hooks):
        """Update the custom hooks.

        When loading parameters, hooks here are preferred over class defaults.

        Arguments
        ---------
        custom_hooks : mapping
            Mapping from loadable key to parameter transfer hook function. If
            you want to use a custom loading function, specify it here.
        """
        self.custom_hooks.update(custom_hooks)

    def add_conditions(self, conditions):
        """Update the conditions.

        Arguments
        ---------
        conditions: mapping
            Mapping from loadable keys to condition values,
            useful for loading certain elements only if a flag is turned on
        """
        self.conditions.update(conditions)

    @staticmethod
    def split_path(path):
        """Splits a path to source and filename

        This also handles URLs and Huggingface hub paths, in addition to
        regular paths. Everything before the last "/" is the source; a path
        with no "/" is treated as a file in the current directory ("./").

        Arguments
        ---------
        path : str or pathlib.Path or FetchSource
            The path to split. For a ``FetchSource``, only its path component
            is split and its fetch origin is carried over to the result.

        Returns
        -------
        str or FetchSource
            Source (a ``FetchSource`` when ``path`` was one)
        str
            Filename
        """

        def split(src):
            """Core function to split path."""
            # Path objects don't support `"/" in src`; normalize them to a
            # forward-slash string first.
            if isinstance(src, pathlib.PurePath):
                src = src.as_posix()
            if "/" not in src:
                # Interpret as path to file in current directory.
                return "./", src
            source, filename = src.rsplit("/", maxsplit=1)
            # "/file.ckpt" would otherwise give an empty source, which
            # resolves to the current directory rather than the root.
            return source or "/", filename

        if isinstance(path, FetchSource):
            fetch_from, fetch_path = path
            source, filename = split(fetch_path)
            return FetchSource(fetch_from, source), filename
        return split(path)

    def collect_files(
        self,
        default_source=None,
        local_strategy=LocalStrategy.SYMLINK,
        fetch_config=None,
    ):
        """Fetches parameters from known paths with fallback default_source

        The actual parameter files may reside elsewhere, but this ensures an
        entry in the self.collect_in directory (a symlink by default, see
        `local_strategy`). That entry always uses the loadable key in the
        filename. This standardization makes it easier to orchestrate
        pretraining on e.g. distributed setups.

        Use the default_source if you have everything organized neatly into one
        location, like a Huggingface hub repo.

        After fetching, ``self.paths[name]`` is replaced by the local path and
        ``name`` is recorded in ``self.is_local``, so `load_collected` loads
        from that local path.

        Arguments
        ---------
        default_source : str or Path or FetchSource
            This is used for each loadable which doesn't have a path already
            specified.
            e.g. if the loadable has key `"asr"`, then the file to look for is
            `<default_source>/asr.ckpt`
        local_strategy : LocalStrategy
            How to perform caching on the file for local storage.
        fetch_config : FetchConfig, optional
            Configuration options like caching strategy for fetching files.
            Defaults to a fresh ``FetchConfig()`` for each call.

        Returns
        -------
        dict
            Mapping from loadable key to a local path from which loadable's
            parameters can be loaded. This is not used in this class, but
            can possibly be helpful.

        Raises
        ------
        ValueError
            If a loadable has no path and no default_source is given.
        """
        if fetch_config is None:
            # Build a new config per call instead of sharing one instance
            # created once when the method was defined.
            fetch_config = FetchConfig()

        if self.collect_in is not None:
            logger.debug(
                f"Collecting files (or symlinks) for pretraining in {self.collect_in}."
            )
            # `parents=True` so a nested collection directory can be created
            # even when its parent directories do not exist yet.
            self.collect_in.mkdir(parents=True, exist_ok=True)

            if (
                platform.system() == "Windows"
                and local_strategy == LocalStrategy.SYMLINK
            ):
                warnings.warn(
                    "Requested Pretrainer collection using symlinks on Windows. This might not work; see `LocalStrategy` documentation. Consider unsetting `collect_in` in Pretrainer to avoid symlinking altogether.",
                    stacklevel=2,
                )
        else:
            logger.debug(
                "Fetching files for pretraining (no collection directory set)"
            )

        loadable_paths = {}
        for name in self.loadables:
            if not self.is_loadable(name):
                continue
            save_filename = name + PARAMFILE_EXT
            if name in self.paths:
                source, filename = self.split_path(self.paths[name])
            elif default_source is not None:
                filename = save_filename
                source = default_source
            else:
                raise ValueError(
                    f"Path not specified for '{name}', "
                    "and no default_source given!"
                )

            # Fetch now handles multiprocessing!
            path = fetch(
                filename=filename,
                source=source,
                savedir=self.collect_in,
                save_filename=save_filename,
                local_strategy=local_strategy,
                fetch_config=fetch_config,
            )
            loadable_paths[name] = path

            # Redirect later loads to the fetched local file.
            logger.debug(f'Set local path in self.paths["{name}"] = {path}')
            self.paths[name] = str(path)
            # Keep `is_local` duplicate-free across repeated collections.
            if name not in self.is_local:
                self.is_local.append(name)
        return loadable_paths

    def is_loadable(self, name):
        """Returns True if no condition is defined for the specified
        loadable, or if its condition is true.

        A callable condition is called with no arguments each time this is
        invoked; any other condition value is interpreted by truthiness.

        Arguments
        ---------
        name: str
            the name of the loadable

        Returns
        -------
        is_loadable: bool
            whether the item should be loaded
        """
        if name not in self.conditions:
            return True
        condition = self.conditions[name]
        if callable(condition):
            return bool(condition())
        return bool(condition)

    def load_collected(self):
        """Loads the files that have been collected.

        Loadables redirected by `collect_files` are loaded from their local
        path; otherwise ``<collect_in>/<name><PARAMFILE_EXT>`` is used.

        Raises
        ------
        ValueError
            If a loadable was never collected and `collect_in` is None.
        """
        logger.info(
            f"Loading pretrained files for: {', '.join(self.loadables)}"
        )
        paramfiles = {}
        for name in self.loadables:
            if not self.is_loadable(name):
                continue
            filename = name + PARAMFILE_EXT

            if name in self.is_local:
                logger.debug(
                    f"Redirecting (loading from local path): {name} -> {self.paths[name]}"
                )
                paramfiles[name] = self.paths[name]
            elif self.collect_in is not None:
                paramfiles[name] = self.collect_in / filename
            else:
                raise ValueError(
                    f'Pretrainer has never collected `{name}`, did you forget a call to `collect_files`? Could not fall back to `collect_in`, as it was not specified (default is no longer "model_checkpoints").'
                )
        self._call_load_hooks(paramfiles)

    def _call_load_hooks(self, paramfiles):
        """Finds the correct hook for every loadable and calls it.

        Hook priority: a custom hook for the key, then a registered default
        transfer hook for the object's type, then a registered default load
        hook (called with ``end_of_epoch=False``).

        Arguments
        ---------
        paramfiles : mapping
            Mapping from loadable key to the path to load its parameters from.

        Raises
        ------
        RuntimeError
            If no hook is available for a loadable object.
        """
        for name, obj in self.loadables.items():
            if not self.is_loadable(name):
                continue
            loadpath = paramfiles[name]

            # First see if object has custom load hook:
            if name in self.custom_hooks:
                self.custom_hooks[name](obj, loadpath)
                continue
            # Try the default transfer hook:
            default_hook = get_default_hook(obj, DEFAULT_TRANSFER_HOOKS)
            if default_hook is not None:
                default_hook(obj, loadpath)
                continue
            # Otherwise find the default loader for that type:
            default_hook = get_default_hook(obj, DEFAULT_LOAD_HOOKS)
            if default_hook is not None:
                # Need to fake end-of-epoch:
                end_of_epoch = False
                default_hook(obj, loadpath, end_of_epoch)
                continue
            # If we got here, no custom hook or registered default hook exists.
            # Implicit literal concatenation keeps source indentation out of
            # the message (a backslash continuation would embed it).
            MSG = (
                f"Don't know how to load {type(obj)}. Register default hook "
                "or add custom hook for this object."
            )
            raise RuntimeError(MSG)