-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Expand file tree
/
Copy pathpreprocess.py
More file actions
82 lines (70 loc) · 2.64 KB
/
preprocess.py
File metadata and controls
82 lines (70 loc) · 2.64 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
"""Preprocessors for audio"""
import torch
from speechbrain.augment.time_domain import Resample
class AudioNormalizer:
    """Normalizes audio into a standard format

    Arguments
    ---------
    sample_rate : int
        The sampling rate to which the incoming signals should be converted.
    mix : {"avg-to-mono", "keep"}
        "avg-to-mono" - add all channels together and normalize by number of
        channels. This also removes the channel dimension, resulting in [time]
        format tensor.
        "keep" - don't normalize channel information

    Example
    -------
    >>> from speechbrain.dataio import audio_io
    >>> example_file = (
    ...     "tests/samples/multi-mic/speech_-0.82918_0.55279_-0.082918.flac"
    ... )
    >>> signal, sr = audio_io.load(example_file, channels_first=False)
    >>> normalizer = AudioNormalizer(sample_rate=8000)
    >>> normalized = normalizer(signal, sr)
    >>> signal.shape
    torch.Size([160000, 4])
    >>> normalized.shape
    torch.Size([80000])

    NOTE
    ----
    This will also upsample audio. However, upsampling cannot produce meaningful
    information in the bandwidth which it adds. Generally models will not work
    well for upsampled data if they have not specifically been trained to do so.
    """

    # Single source of truth for the accepted ``mix`` values, shared by the
    # constructor check and by ``_mix`` so the two cannot drift apart.
    _SUPPORTED_MIXES = ("avg-to-mono", "keep")

    def __init__(self, sample_rate=16000, mix="avg-to-mono"):
        self.sample_rate = sample_rate
        if mix not in self._SUPPORTED_MIXES:
            raise ValueError(f"Unexpected mixing configuration {mix}")
        self.mix = mix
        # One Resample instance per distinct input sample rate, created lazily
        # the first time that rate is seen and reused on later calls.
        self._cached_resamplers = {}

    def __call__(self, audio, sample_rate):
        """Perform normalization

        Arguments
        ---------
        audio : torch.Tensor
            The input waveform torch tensor. Assuming [time, channels],
            or [time].
        sample_rate : int
            Rate the audio was sampled at.

        Returns
        -------
        audio : torch.Tensor
            Channel- and sample-rate-normalized audio.

        Raises
        ------
        ValueError
            If ``self.mix`` was changed to an unsupported value after
            construction.
        """
        if sample_rate not in self._cached_resamplers:
            # Create a Resample instance from this newly seen SR to internal SR
            self._cached_resamplers[sample_rate] = Resample(
                sample_rate, self.sample_rate
            )
        resampler = self._cached_resamplers[sample_rate]
        # A singleton leading dimension is added for the resampler and removed
        # again afterwards (presumably Resample expects a batch axis - confirm
        # against the Resample documentation).
        resampled = resampler(audio.unsqueeze(0)).squeeze(0)
        return self._mix(resampled)

    def _mix(self, audio):
        """Handle channel mixing

        Arguments
        ---------
        audio : torch.Tensor
            Audio shaped [time] or [time, channels].

        Returns
        -------
        torch.Tensor
            With "avg-to-mono", the mean over dim 1 (channels), giving a
            [time] tensor; 1-D input is returned unchanged. With "keep",
            ``audio`` is returned unchanged.

        Raises
        ------
        ValueError
            If ``self.mix`` is not one of the supported modes. Without this
            check, an invalid mode would fall through and return ``None``.
        """
        if self.mix == "avg-to-mono":
            # A 1-D signal has no channel axis, so it is already mono.
            if audio.dim() == 1:
                return audio
            return torch.mean(audio, 1)
        if self.mix == "keep":
            return audio
        raise ValueError(f"Unexpected mixing configuration {self.mix}")