-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Expand file tree
/
Copy pathlogger.py
More file actions
320 lines (278 loc) · 9.9 KB
/
logger.py
File metadata and controls
320 lines (278 loc) · 9.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
"""Managing the logger, utilities
Author
* Fang-Pen Lin 2012 https://fangpenlin.com/posts/2012/08/26/good-logging-practice-in-python/
* Peter Plantinga 2020
* Aku Rouhe 2020
"""
import functools
import logging
import logging.config
import math
import os
import sys
import torch
import tqdm
import yaml
from speechbrain.utils.data_utils import recursive_update
from speechbrain.utils.distributed import if_main_process
from speechbrain.utils.superpowers import run_shell
# SI-style prefixes for each power-of-ten order (in steps of 3), used by
# ``format_order_of_magnitude`` when ``abbreviate=True``.
# Keys are the exponent of 10; values are the prefix appended to the number.
ORDERS_ABBREV = {
    -24: "y",
    -21: "z",
    -18: "a",
    -15: "f",
    -12: "p",
    -9: "n",
    -6: "µ",
    -3: "m",
    0: "",
    3: "k",
    6: "M",
    9: "G",
    12: "T",
    15: "P",
    18: "E",
    21: "Z",
    24: "Y",
}
# Spelled-out order names, used by ``format_order_of_magnitude`` when
# ``abbreviate=False``.
# Short scale
# Negative powers of ten in lowercase, positive in uppercase
ORDERS_WORDS = {
    -24: "septillionths",
    -21: "sextillionths",
    -18: "quintillionths",
    -15: "quadrillionths",
    -12: "trillionths",
    -9: "billionths",
    -6: "millionths",
    -3: "thousandths",
    0: "",
    3: "Thousand",
    6: "Million",
    9: "Billion",
    12: "Trillion",
    15: "Quadrillion",
    18: "Quintillion",
    21: "Sextillion",
    24: "Septillion",
}
class MultiProcessLoggerAdapter(logging.LoggerAdapter):
    r"""
    Logger adapter that handles multi-process logging, ensuring logs are written
    only on the main process if specified. This class extends `logging.LoggerAdapter`
    and provides additional functionality for controlling logging in multi-process
    environments, with the option to limit logs to the main process only.
    This class is heavily inspired by HuggingFace Accelerate toolkit:
    https://github.com/huggingface/accelerate/blob/85b1a03552cf8d58e036634e004220c189bfb247/src/accelerate/logging.py#L22
    """

    @staticmethod
    def _should_log(main_process_only: bool) -> bool:
        r"""
        Determines if logging should occur based on whether the code is running
        on the main process or not.
        Arguments
        ---------
        main_process_only : bool
            A flag indicating if logging should be restricted to the main process.
        Returns
        -------
        bool
            True if logging should be performed (based on the process and the flag),
            False otherwise.
        """
        # When the restriction is off, every process may log.
        if not main_process_only:
            return True
        # Otherwise only the main process may log.
        return if_main_process()

    def log(self, level: int, msg: str, *args, **kwargs):
        r"""
        Logs a message with the specified log level, respecting the `main_process_only`
        flag to decide whether to log based on the current process.
        Arguments
        ---------
        level : int
            Logging level (e.g., logging.INFO, logging.WARNING).
        msg : str
            The message to log.
        *args : tuple
            Additional positional arguments passed to the logger.
        **kwargs : dict
            Additional keyword arguments passed to the logger, including:
            - main_process_only (bool): If True, log only from the main process (default: True).
            - stacklevel (int): The stack level to use when logging (default: 2).
        Notes
        -----
        If `main_process_only` is True, the log will only be written if the current process
        is the main process, as determined by `if_main_process()`.
        """
        only_main = kwargs.pop("main_process_only", True)
        # stacklevel=2 makes the record point at the caller, not this adapter.
        kwargs.setdefault("stacklevel", 2)
        if not self.isEnabledFor(level):
            return
        if not self._should_log(only_main):
            return
        msg, kwargs = self.process(msg, kwargs)
        self.logger.log(level, msg, *args, **kwargs)

    @functools.lru_cache(None)
    def warning_once(self, *args, **kwargs):
        r"""
        Logs a warning message only once by using caching to prevent duplicate warnings.
        Arguments
        ---------
        *args : tuple
            Positional arguments passed to the warning log.
        **kwargs : dict
            Keyword arguments passed to the warning log.
        Notes
        -----
        This method is decorated with `functools.lru_cache(None)`, ensuring that the warning
        message is logged only once regardless of how many times the method is called.
        NOTE(review): `lru_cache` on a method keys on `self` and keeps the adapter
        alive for the cache's lifetime — acceptable here since loggers typically
        live for the whole run.
        """
        self.warning(*args, **kwargs)
def get_logger(name: str) -> MultiProcessLoggerAdapter:
    """
    Retrieves a logger with the specified name, applying a log level from the environment variable
    `SB_LOG_LEVEL` if set, or defaults to `INFO` level.
    If the environment variable `SB_LOG_LEVEL` is not defined, it defaults to `INFO` level and sets
    this level in the environment for future use. The environment variable can be set manually or
    automatically in `Brain` class following `setup_logging`.
    Arguments
    ---------
    name : str
        The name of the logger to retrieve.
    Returns
    -------
    MultiProcessLoggerAdapter
        An instance of `MultiProcessLoggerAdapter` wrapping the logger with the specified name.
    """
    logger = logging.getLogger(name)
    log_level = os.environ.get("SB_LOG_LEVEL", None)
    if log_level is None:
        # BUGFIX: the code previously defaulted to "DEBUG" while the docstring
        # (twice) promises an INFO default; align the code with the documented
        # contract. Persist the choice so later calls see the same level.
        log_level = "INFO"
        os.environ["SB_LOG_LEVEL"] = log_level
    # .upper() tolerates a lowercase value in the environment variable.
    logger.setLevel(log_level.upper())
    return MultiProcessLoggerAdapter(logger, {})
def setup_logging(
    config_path="log-config.yaml",
    overrides=None,
    default_level="DEBUG",
):
    """Setup logging configuration.
    Arguments
    ---------
    config_path : str
        The path to a logging config file.
    overrides : dict, optional
        A dictionary of the same structure as the config dict
        with any updated values that need to be applied.
        Defaults to no overrides (empty dict).
    default_level : str
        The log level to use if the config file is not found.
        Python logging allows ints or strings:
        https://docs.python.org/3/library/logging.html#logging.Logger.setLevel
        but strings are used here as environment variables have to be
        strings. The available levels are listed here:
        https://docs.python.org/3/library/logging.html#levels
    """
    # BUGFIX: avoid a mutable default argument ({}); use a None sentinel
    # instead. Passing an explicit dict (or nothing) behaves exactly as before.
    if overrides is None:
        overrides = {}
    if os.path.exists(config_path):
        with open(config_path, encoding="utf-8") as f:
            config = yaml.safe_load(f)
        recursive_update(config, overrides)
        logging.config.dictConfig(config)
    else:
        # No config file: fall back to basicConfig and record the level so
        # that later get_logger() calls pick it up from the environment.
        logging.basicConfig(level=default_level)
        os.environ["SB_LOG_LEVEL"] = default_level
class TqdmCompatibleStreamHandler(logging.StreamHandler):
    """TQDM compatible StreamHandler.
    Writes and prints should be passed through tqdm.tqdm.write
    so that the tqdm progressbar doesn't get messed up.
    """

    def emit(self, record):
        """TQDM compatible StreamHandler."""
        try:
            formatted = self.format(record)
            # Route output through tqdm so an active progress bar is
            # redrawn cleanly instead of being interleaved with log text.
            tqdm.tqdm.write(
                formatted, end=self.terminator, file=self.stream
            )
            self.flush()
        except RecursionError:
            # Matches logging.StreamHandler: recursion errors must propagate.
            raise
        except Exception:
            self.handleError(record)
def format_order_of_magnitude(number, abbreviate=True):
    """Formats number to the appropriate order of magnitude for printing.
    Arguments
    ---------
    number : int, float
        The number to format.
    abbreviate : bool
        Whether to use abbreviations (k,M,G) or words (Thousand, Million,
        Billion). Numbers will be either like: "123.5k" or "123.5 Thousand".
    Returns
    -------
    str
        The formatted number. Note that the order of magnitude token is part
        of the string.
    Example
    -------
    >>> print(format_order_of_magnitude(123456))
    123.5k
    >>> print(format_order_of_magnitude(0.00000123, abbreviate=False))
    1.2 millionths
    >>> print(format_order_of_magnitude(5, abbreviate=False))
    5
    """
    style = ORDERS_ABBREV if abbreviate else ORDERS_WORDS
    precision = "{num:3.1f}"
    if number == 0:
        # BUGFIX: math.log is undefined at zero and raised ValueError here;
        # zero has no order of magnitude, so format it plainly.
        order = 0
    else:
        order = 3 * int(math.floor(math.log(math.fabs(number), 1000)))
        # Fallback for very large numbers:
        while order not in style and order != 0:
            order = order - int(math.copysign(3, order))  # Bring 3 units towards 0
    order_token = style[order]
    if order != 0:
        formatted_number = precision.format(num=number / 10**order)
    else:
        # Integers print exactly; floats keep the one-decimal precision.
        if isinstance(number, int):
            formatted_number = str(number)
        else:
            formatted_number = precision.format(num=number)
    if abbreviate or not order_token:
        return formatted_number + order_token
    else:
        return formatted_number + " " + order_token
def get_environment_description():
    """Returns a string describing the current Python / SpeechBrain environment.
    Useful for making experiments as replicable as possible.
    Returns
    -------
    str
        The string is formatted ready to be written to a file.
    Example
    -------
    >>> get_environment_description().splitlines()[0]
    'SpeechBrain system description'
    """
    python_version_str = "Python version:\n" + sys.version + "\n"
    # Best effort: external commands may be missing, so degrade gracefully.
    try:
        freezed, _, _ = run_shell("pip freeze")
        python_packages_str = (
            "Installed Python packages:\n" + freezed.decode(errors="replace")
        )
    except OSError:
        python_packages_str = "Could not list python packages with pip freeze"
    try:
        git_hash, _, _ = run_shell("git rev-parse --short HEAD")
        git_str = "Git revision:\n" + git_hash.decode(errors="replace")
    except OSError:
        git_str = "Could not get git revision"
    if not torch.cuda.is_available():
        cuda_str = "CUDA not available"
    elif torch.version.cuda is not None:
        cuda_str = "CUDA version:\n" + torch.version.cuda
    else:
        # A CUDA device with no CUDA version means a ROCm build of torch.
        cuda_str = "ROCm version:\n" + torch.version.hip
    separator = "==============================\n"
    parts = [
        "SpeechBrain system description\n",
        separator,
        python_version_str,
        separator,
        python_packages_str,
        separator,
        git_str,
        separator,
        cuda_str,
    ]
    return "".join(parts)