Skip to content

Commit 85da7fe

Browse files
committed
Address review comments
1 parent b588cb8 commit 85da7fe

File tree

3 files changed

+38
-8
lines changed

3 files changed

+38
-8
lines changed

aten/src/ATen/native/SpectralOps.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,14 @@ static Stream& write_opt(Stream& SS, const optional<T>& value) {
193193
return SS;
194194
}
195195

196+
/* Short-time Fourier Transform, for signal analysis.
197+
*
198+
* This is modelled after librosa but with support for complex time-domain
199+
* signals and complex windows.
200+
*
201+
* NOTE: librosa's center and pad_mode arguments are currently only implemented
202+
* in python because it uses torch.nn.functional.pad which is python-only.
203+
*/
196204
Tensor stft(const Tensor& self, const int64_t n_fft, const optional<int64_t> hop_lengthOpt,
197205
const optional<int64_t> win_lengthOpt, const Tensor& window,
198206
const bool normalized, const optional<bool> onesidedOpt,
@@ -306,6 +314,11 @@ Tensor stft(const Tensor& self, const int64_t n_fft, const optional<int64_t> hop
306314
}
307315
}
308316

317+
/* Inverse Short-time Fourier Transform
318+
*
319+
* This is modelled after librosa but with support for complex time-domain
320+
* signals and complex windows.
321+
*/
309322
Tensor istft(const Tensor& self, const int64_t n_fft, const optional<int64_t> hop_lengthOpt,
310323
const optional<int64_t> win_lengthOpt, const Tensor& window,
311324
const bool center, const bool normalized, const c10::optional<bool> onesidedOpt,

test/test_spectral_ops.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@ def _complex_stft(x, *args, **kwargs):
3030

3131

3232
def _hermitian_conj(x, dim):
33+
"""Returns the hermitian conjugate along a single dimension
34+
35+
H(x)[i] = conj(x[-i])
36+
"""
3337
out = torch.empty_like(x)
3438
mid = (x.size(dim) - 1) // 2
3539
idx = [slice(None)] * out.dim()
@@ -60,10 +64,17 @@ def _complex_istft(x, *args, **kwargs):
6064
x_antihermitian = (x - hconj) / 2
6165
istft_real = torch.istft(x_hermitian[slc], *args, **kwargs, onesided=True)
6266
istft_imag = torch.istft(-1j * x_antihermitian[slc], *args, **kwargs, onesided=True)
63-
return istft_real + 1j * istft_imag
67+
return torch.complex(istft_real, istft_imag)
68+
69+
70+
def _stft_reference(x, hop_length, window):
71+
r"""Reference stft implementation
72+
73+
This doesn't implement all of torch.stft, only the STFT definition:
6474
75+
.. math:: X(m, \omega) = \sum_n x[n]w[n - m] e^{-jn\omega}
6576
66-
def stft_definition(x, hop_length, window):
77+
"""
6778
n_fft = window.numel()
6879
X = torch.empty((n_fft, (x.numel() - n_fft + hop_length) // hop_length),
6980
device=x.device, dtype=torch.cdouble)
@@ -482,7 +493,7 @@ def test_complex_stft_definition(self, device, dtype):
482493

483494
for args in test_args:
484495
window = torch.randn(args[1], device=device, dtype=dtype)
485-
expected = stft_definition(args[0], args[2], window)
496+
expected = _stft_reference(args[0], args[2], window)
486497
actual = torch.stft(*args, window=window, center=False)
487498
self.assertEqual(actual, expected)
488499

torch/functional.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -399,6 +399,12 @@ def stft(input: Tensor, n_fft: int, hop_length: Optional[int] = None,
399399
return_complex: Optional[bool] = None) -> Tensor:
400400
r"""Short-time Fourier transform (STFT).
401401
402+
The STFT computes the Fourier transform of short overlapping windows of the
403+
input. This gives the frequency components of the signal as they change over
404+
time. The interface of this function is modelled after librosa_.
405+
406+
.. _librosa: https://librosa.org/doc/latest/generated/librosa.stft.html
407+
402408
Ignoring the optional batch dimension, this method computes the following
403409
expression:
404410
@@ -452,10 +458,10 @@ def stft(input: Tensor, n_fft: int, hop_length: Optional[int] = None,
452458
dimension represents the real and imaginary components.
453459
454460
Returns either a complex tensor of size :math:`(* \times N \times T)` if
455-
:attr:`return_complex`, or a real tensor of size :math:`(* \times N \times
456-
T \times 2)`. :math:`*` is the optional batch size of :attr:`input`,
457-
:math:`N` is the number of frequencies where STFT is applied and :math:`T`
458-
is the total number of frames used.
461+
:attr:`return_complex` is true, or a real tensor of size :math:`(* \times N
462+
\times T \times 2)`. Here, :math:`*` is the optional batch size of
463+
:attr:`input`, :math:`N` is the number of frequencies where STFT is applied
464+
and :math:`T` is the total number of frames used.
459465
460466
.. warning::
461467
This function changed signature at version 0.4.1. Calling with the
@@ -479,7 +485,7 @@ def stft(input: Tensor, n_fft: int, hop_length: Optional[int] = None,
479485
Default: ``False``
480486
onesided (bool, optional): controls whether to return half of results to
481487
avoid redundancy for real inputs.
482-
Default: ``True`` for real input and window, ``False`` otherwise.
488+
Default: ``True`` for real :attr:`input` and :attr:`window`, ``False`` otherwise.
483489
return_complex (bool, optional): whether to return a complex tensor, or
484490
a real tensor with an extra last dimension for the real and
485491
imaginary components.

0 commit comments

Comments
 (0)