@@ -330,8 +330,8 @@ def reverberate(waveforms, rir_waveform, rescale_amp="avg"):
330330 Shape should be `[batch, time]` or `[batch, time, channels]`.
331331 rir_waveform : tensor
332332 RIR tensor, shape should be [time, channels].
333- rescale_amp : str
334- Whether reverberated signal is rescaled (None) and with respect either
333+ rescale_amp : str or None
334+ Whether reverberated signal is rescaled (None to avoid ) and with respect either
335335 to original signal "peak" amplitude or "avg" average amplitude.
336336 Choose between [None, "avg", "peak"].
337337
@@ -356,10 +356,11 @@ def reverberate(waveforms, rir_waveform, rescale_amp="avg"):
356356 elif len (rir_waveform .shape ) == 2 :
357357 rir_waveform = rir_waveform .unsqueeze (- 1 )
358358
359- # Compute the average amplitude of the clean
360- orig_amplitude = compute_amplitude (
361- waveforms , waveforms .size (1 ), rescale_amp
362- )
359+ if rescale_amp is not None :
360+ # Compute the average amplitude of the clean
361+ orig_amplitude = compute_amplitude (
362+ waveforms , waveforms .size (1 ), rescale_amp
363+ )
363364
364365 # Compute index of the direct signal, so we can preserve alignment
365366 value_max , direct_index = rir_waveform .abs ().max (axis = 1 , keepdim = True )
@@ -376,10 +377,11 @@ def reverberate(waveforms, rir_waveform, rescale_amp="avg"):
376377 rotation_index = direct_index ,
377378 )
378379
379- # Rescale to the peak amplitude of the clean waveform
380- waveforms = rescale (
381- waveforms , waveforms .size (1 ), orig_amplitude , rescale_amp
382- )
380+ if rescale_amp is not None :
381+ # Rescale to the peak amplitude of the clean waveform
382+ waveforms = rescale (
383+ waveforms , waveforms .size (1 ), orig_amplitude , rescale_amp
384+ )
383385
384386 if len (orig_shape ) == 1 :
385387 waveforms = waveforms .squeeze (0 ).squeeze (- 1 )
@@ -457,7 +459,7 @@ def _sinc(x):
457459 return torch .sin (x ) / x
458460
459461 # The zero is at the middle index
460- return torch .cat ([_sinc (x [:pad ]), torch .ones (1 ), _sinc (x [pad + 1 :])])
462+ return torch .cat ([_sinc (x [:pad ]), torch .ones (1 ), _sinc (x [pad + 1 :])])
461463
462464 # Compute a low-pass filter with cutoff frequency notch_freq.
463465 hlpf = sinc (3 * (notch_freq - notch_width ) * inputs )
@@ -494,7 +496,8 @@ def overlap_and_add(signal, frame_step):
494496 -------
495497 A Tensor with shape [..., output_size] containing the overlap-added frames of signal's inner-most two dimensions.
496498 output_size = (frames - 1) * frame_step + frame_length
497- Based on https://github.com/tensorflow/tensorflow/blob/r1.12/tensorflow/contrib/signal/python/ops/reconstruction_ops.py
499+ Based on
500+ https://github.com/tensorflow/tensorflow/blob/r1.12/tensorflow/contrib/signal/python/ops/reconstruction_ops.py
498501
499502 Example
500503 -------
0 commit comments