forked from speechbrain/speechbrain
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathparameter_transfer.py
More file actions
283 lines (245 loc) · 9.42 KB
/
parameter_transfer.py
File metadata and controls
283 lines (245 loc) · 9.42 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
"""Convenience functions for the simplest parameter transfer cases.
Use `speechbrain.utils.checkpoints.Checkpointer` to find a checkpoint
and the path to the parameter file.
Authors
* Aku Rouhe 2020
"""
import logging
import pathlib
from speechbrain.pretrained.fetching import fetch
from speechbrain.utils.checkpoints import (
DEFAULT_LOAD_HOOKS,
DEFAULT_TRANSFER_HOOKS,
PARAMFILE_EXT,
get_default_hook,
)
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
class Pretrainer:
    """Orchestrates pretraining

    First collects parameter file symlinks into the given directory. Then
    calls load hooks for each of those parameter files.

    Arguments
    ---------
    collect_in : str or Path
        Path to directory where the parameter file symlinks are collected.
    loadables : mapping
        Mapping from loadable key to object. This connects the keys to
        the actual object instances.
    paths : mapping
        Mapping from loadable key to filepath. The last part
        of the path is treated as file name, the rest of it
        is treated as a "source" which can be either a directory
        path or a magic source like Huggingface hub ID.
        e.g. sb/asr-crdnn-libri/lm.ckpt
        -> source=sb/asr-crdnn-libri, file=lm.ckpt
        Note that when collecting, you can specify a default source,
        which is used for all loadables that don't have a path specified.
    custom_hooks : mapping
        Mapping from loadable key to parameter transfer hook function. If you
        want to use a custom loading function, specify it here.
    conditions : mapping
        An optional mapping from loadable keys to condition values,
        useful for loading certain elements only if a flag is turned on.
        A condition may be a plain value (interpreted by truthiness) or a
        zero-argument callable that is evaluated each time it is checked.
    """

    def __init__(
        self,
        collect_in="./model_checkpoints",
        loadables=None,
        paths=None,
        custom_hooks=None,
        conditions=None,
    ):
        self.loadables = {}
        self.collect_in = pathlib.Path(collect_in)
        if loadables is not None:
            self.add_loadables(loadables)
        self.paths = {}
        if paths is not None:
            self.add_paths(paths)
        self.custom_hooks = {}
        if custom_hooks is not None:
            self.add_custom_hooks(custom_hooks)
        self.conditions = {}
        if conditions is not None:
            self.add_conditions(conditions)

    def set_collect_in(self, path):
        """Change the collecting path

        Arguments
        ---------
        path : str or Path
            New directory in which parameter files are collected.
        """
        self.collect_in = pathlib.Path(path)

    def add_loadables(self, loadables):
        """Update the loadables dict from the given mapping.

        Arguments
        ---------
        loadables : mapping
            Mapping from loadable key to object
        """
        self.loadables.update(loadables)

    def add_paths(self, paths):
        """Update the paths for different loadables.

        When collecting parameters, paths here are preferred. Note that when
        collecting, you can specify a default source, which is used for all
        loadables that don't have a path specified.

        Arguments
        ---------
        paths : mapping
            Mapping from loadable key to filepath. The last part
            of the path is treated as file name, the rest of it
            is treated as a "source" which can be either a directory
            path or a magic source like Huggingface hub ID.
            e.g. sb/asr-crdnn-libri/lm.ckpt
            -> source=sb/asr-crdnn-libri, file=lm.ckpt
        """
        self.paths.update(paths)

    def add_custom_hooks(self, custom_hooks):
        """Update the custom hooks.

        When loading parameters, hooks here are preferred over class defaults.

        Arguments
        ---------
        custom_hooks : mapping
            Mapping from loadable key to parameter transfer hook function. If
            you want to use a custom loading function, specify it here.
        """
        self.custom_hooks.update(custom_hooks)

    def add_conditions(self, conditions):
        """Update the conditions.

        Arguments
        ---------
        conditions : mapping
            Mapping from loadable keys to condition values,
            useful for loading certain elements only if a flag is turned on
        """
        self.conditions.update(conditions)

    @staticmethod
    def split_path(path):
        """Splits a path to source and filename

        This also handles URLs and Huggingface hub paths, in addition to
        regular paths. The split happens at the last "/" in the path.

        Arguments
        ---------
        path : str
            The path to split.

        Returns
        -------
        str
            Source
        str
            Filename
        """
        if "/" not in path:
            # Interpret as path to file in current directory.
            return "./", path
        source, filename = path.rsplit("/", maxsplit=1)
        if not source:
            # The only "/" was the leading one (e.g. "/lm.ckpt"): the file
            # sits in the filesystem root, not in the current directory.
            source = "/"
        # Always hand back a tuple so both branches have the same type.
        return source, filename

    def collect_files(self, default_source=None):
        """Fetches parameters from known paths with fallback default_source

        The actual parameter files may reside elsewhere, but this ensures a
        symlink in the self.collect_in directory. The symlink always uses the
        loadable key in the filename. This standardization makes it easier to
        orchestrate pretraining on e.g. distributed setups.

        Use the default_source if you have everything organized neatly into one
        location, like a Huggingface hub repo.

        Loadables whose condition evaluates to false are skipped.

        Arguments
        ---------
        default_source : str or Path
            This is used for each loadable which doesn't have a path already
            specified. If the loadable has key "asr", then the file to look for is
            default_source/asr.ckpt

        Returns
        -------
        dict
            Mapping from loadable key to a local path from which loadable's
            parameters can be loaded. This is not used in this class, but
            can possibly be helpful.

        Raises
        ------
        ValueError
            If a loadable has no entry in self.paths and no default_source
            was given.
        """
        logger.debug(
            f"Collecting files (or symlinks) for pretraining in {self.collect_in}."
        )
        # parents=True so that a nested collect_in (e.g. results/run1/ckpts)
        # works even when the intermediate directories do not exist yet.
        self.collect_in.mkdir(parents=True, exist_ok=True)
        loadable_paths = {}
        for name in self.loadables:
            if not self.is_loadable(name):
                continue
            # The collected file is always named after the loadable key.
            save_filename = name + PARAMFILE_EXT
            if name in self.paths:
                source, filename = self.split_path(self.paths[name])
            elif default_source is not None:
                filename = save_filename
                source = default_source
            else:
                raise ValueError(
                    f"Path not specified for '{name}', "
                    "and no default_source given!"
                )
            path = fetch(
                filename=filename,
                source=source,
                savedir=self.collect_in,
                overwrite=False,
                save_filename=save_filename,
                use_auth_token=False,
                revision=None,
            )
            loadable_paths[name] = path
        return loadable_paths

    def is_loadable(self, name):
        """Returns True if no condition is defined for the specified
        loadable, or if its condition is true

        A callable condition is called with no arguments and its result is
        interpreted by truthiness; any other condition value is interpreted
        by truthiness directly.

        Arguments
        ---------
        name : str
            the name of the loadable

        Returns
        -------
        is_loadable : bool
            whether the item should be loaded
        """
        if name not in self.conditions:
            return True
        condition = self.conditions[name]
        if callable(condition):
            condition = condition()
        # Coerce so the documented bool return holds for callables as well.
        return bool(condition)

    def load_collected(self, device=None):
        """Loads the files that have been collected.

        Only loadables whose condition holds are loaded; for each of them
        the parameter file is expected at collect_in/<key><PARAMFILE_EXT>.

        Arguments
        ---------
        device : str
            Device on which to load, if you want to load to a specific device
            directly ( otherwise just leave it to None ).
        """
        paramfiles = {}
        for name in self.loadables:
            if not self.is_loadable(name):
                continue
            filename = name + PARAMFILE_EXT
            paramfiles[name] = self.collect_in / filename
        # Report only what is actually going to be loaded, not the loadables
        # that were skipped because of their condition.
        logger.info(
            f"Loading pretrained files for: {', '.join(paramfiles)}"
        )
        self._call_load_hooks(paramfiles, device)

    def _call_load_hooks(self, paramfiles, device=None):
        """Finds the correct hook for every loadable and calls it.

        The lookup order is: custom hook registered for the key, then the
        default transfer hook for the object, then the default load hook
        for the object.

        Arguments
        ---------
        paramfiles : mapping
            Mapping from loadable key to the parameter file path to load.
        device : str
            Passed through to the hook that is called.

        Raises
        ------
        RuntimeError
            If no custom hook, default transfer hook or default load hook
            is available for a loadable.
        """
        for name, obj in self.loadables.items():
            if not self.is_loadable(name):
                continue
            loadpath = paramfiles[name]

            # First see if object has custom load hook:
            if name in self.custom_hooks:
                self.custom_hooks[name](obj, loadpath, device=device)
                continue
            # Try the default transfer hook:
            default_hook = get_default_hook(obj, DEFAULT_TRANSFER_HOOKS)
            if default_hook is not None:
                default_hook(obj, loadpath, device=device)
                continue
            # Otherwise find the default loader for that type:
            default_hook = get_default_hook(obj, DEFAULT_LOAD_HOOKS)
            if default_hook is not None:
                # Need to fake end-of-epoch:
                end_of_epoch = False
                default_hook(obj, loadpath, end_of_epoch, device)
                continue
            # If we got here, no custom hook or registered default hook exists.
            # Implicit concatenation (not a backslash continuation inside the
            # literal) keeps source indentation out of the message.
            msg = (
                f"Don't know how to load {type(obj)}. Register default hook "
                "or add custom hook for this object."
            )
            raise RuntimeError(msg)