Skip to content

Commit fe9178c

Browse files
fix bug in reverberate with rescale_amp (speechbrain#2871)
Co-authored-by: Parcollet Titouan <parcollet.titouan@gmail.com>
1 parent 0ff5c0c commit fe9178c

File tree

1 file changed

+15
-12
lines changed

1 file changed

+15
-12
lines changed

speechbrain/processing/signal_processing.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -330,8 +330,8 @@ def reverberate(waveforms, rir_waveform, rescale_amp="avg"):
330330
Shape should be `[batch, time]` or `[batch, time, channels]`.
331331
rir_waveform : tensor
332332
RIR tensor, shape should be [time, channels].
333-
rescale_amp : str
334-
Whether reverberated signal is rescaled (None) and with respect either
333+
rescale_amp : str or None
334+
Whether reverberated signal is rescaled (None to avoid) and with respect either
335335
to original signal "peak" amplitude or "avg" average amplitude.
336336
Choose between [None, "avg", "peak"].
337337
@@ -356,10 +356,11 @@ def reverberate(waveforms, rir_waveform, rescale_amp="avg"):
356356
elif len(rir_waveform.shape) == 2:
357357
rir_waveform = rir_waveform.unsqueeze(-1)
358358

359-
# Compute the average amplitude of the clean
360-
orig_amplitude = compute_amplitude(
361-
waveforms, waveforms.size(1), rescale_amp
362-
)
359+
if rescale_amp is not None:
360+
# Compute the average amplitude of the clean
361+
orig_amplitude = compute_amplitude(
362+
waveforms, waveforms.size(1), rescale_amp
363+
)
363364

364365
# Compute index of the direct signal, so we can preserve alignment
365366
value_max, direct_index = rir_waveform.abs().max(axis=1, keepdim=True)
@@ -376,10 +377,11 @@ def reverberate(waveforms, rir_waveform, rescale_amp="avg"):
376377
rotation_index=direct_index,
377378
)
378379

379-
# Rescale to the peak amplitude of the clean waveform
380-
waveforms = rescale(
381-
waveforms, waveforms.size(1), orig_amplitude, rescale_amp
382-
)
380+
if rescale_amp is not None:
381+
# Rescale to the peak amplitude of the clean waveform
382+
waveforms = rescale(
383+
waveforms, waveforms.size(1), orig_amplitude, rescale_amp
384+
)
383385

384386
if len(orig_shape) == 1:
385387
waveforms = waveforms.squeeze(0).squeeze(-1)
@@ -457,7 +459,7 @@ def _sinc(x):
457459
return torch.sin(x) / x
458460

459461
# The zero is at the middle index
460-
return torch.cat([_sinc(x[:pad]), torch.ones(1), _sinc(x[pad + 1 :])])
462+
return torch.cat([_sinc(x[:pad]), torch.ones(1), _sinc(x[pad + 1:])])
461463

462464
# Compute a low-pass filter with cutoff frequency notch_freq.
463465
hlpf = sinc(3 * (notch_freq - notch_width) * inputs)
@@ -494,7 +496,8 @@ def overlap_and_add(signal, frame_step):
494496
-------
495497
A Tensor with shape [..., output_size] containing the overlap-added frames of signal's inner-most two dimensions.
496498
output_size = (frames - 1) * frame_step + frame_length
497-
Based on https://github.com/tensorflow/tensorflow/blob/r1.12/tensorflow/contrib/signal/python/ops/reconstruction_ops.py
499+
Based on
500+
https://github.com/tensorflow/tensorflow/blob/r1.12/tensorflow/contrib/signal/python/ops/reconstruction_ops.py
498501
499502
Example
500503
-------

0 commit comments

Comments
 (0)