@@ -183,10 +183,28 @@ Tensor irfft(const Tensor& self, const int64_t signal_ndim, const bool normalize
183183 normalized, onesided);
184184}
185185
186+ template <typename Stream, typename T>
187+ static Stream& write_opt (Stream& SS, const optional<T>& value) {
188+ if (value) {
189+ SS << *value;
190+ } else {
191+ SS << " None" ;
192+ }
193+ return SS;
194+ }
186195
196+ /* Short-time Fourier Transform, for signal analysis.
197+ *
198+ * This is modeled after librosa but with support for complex time-domain
199+ * signals and complex windows.
200+ *
 * NOTE: librosa's center and pad_mode arguments are currently implemented only
 * in Python, because they rely on torch.nn.functional.pad, which is Python-only.
203+ */
187204Tensor stft (const Tensor& self, const int64_t n_fft, const optional<int64_t > hop_lengthOpt,
188205 const optional<int64_t > win_lengthOpt, const Tensor& window,
189- const bool normalized, const bool onesided) {
206+ const bool normalized, const optional<bool > onesidedOpt,
207+ const optional<bool > return_complexOpt) {
190208 #define REPR (SS ) \
191209 SS << " stft(" << self.toString () << self.sizes () << " , n_fft=" << n_fft \
192210 << " , hop_length=" << hop_length << " , win_length=" << win_length \
@@ -196,15 +214,28 @@ Tensor stft(const Tensor& self, const int64_t n_fft, const optional<int64_t> hop
196214 } else { \
197215 SS << " None" ; \
198216 } \
199- SS << " , normalized=" << normalized << " , onesided=" << onesided << " )"
217+ SS << " , normalized=" << normalized << " , onesided=" ; \
218+ write_opt (SS, onesidedOpt) << " , return_complex=" ; \
219+ write_opt (SS, return_complexOpt) << " ) "
200220
201221 // default_init hop_length and win_length
202222 auto hop_length = hop_lengthOpt.value_or (n_fft >> 2 );
203223 auto win_length = win_lengthOpt.value_or (n_fft);
224+ const bool return_complex = return_complexOpt.value_or (
225+ self.is_complex () || (window.defined () && window.is_complex ()));
226+ if (!return_complexOpt && !return_complex) {
227+ TORCH_WARN (" stft will return complex tensors by default in future, use"
228+ " return_complex=False to preserve the current output format." );
229+ }
204230
205- if (!at::isFloatingType (self.scalar_type ()) || self. dim () > 2 || self.dim () < 1 ) {
231+ if (!at::isFloatingType (self.scalar_type ()) && ! at::isComplexType ( self.scalar_type ()) ) {
206232 std::ostringstream ss;
207- REPR (ss) << " : expected a 1D or 2D tensor of floating types" ;
233+ REPR (ss) << " : expected a tensor of floating point or complex values" ;
234+ AT_ERROR (ss.str ());
235+ }
236+ if (self.dim () > 2 || self.dim () < 1 ) {
237+ std::ostringstream ss;
238+ REPR (ss) << " : expected a 1D or 2D tensor" ;
208239 AT_ERROR (ss.str ());
209240 }
210241 Tensor input = self;
@@ -240,11 +271,12 @@ Tensor stft(const Tensor& self, const int64_t n_fft, const optional<int64_t> hop
240271 auto window_ = window;
241272 if (win_length < n_fft) {
242273 // pad center
243- window_ = at::zeros ({n_fft}, self.options ());
244274 auto left = (n_fft - win_length) / 2 ;
245275 if (window.defined ()) {
276+ window_ = at::zeros ({n_fft}, window.options ());
246277 window_.narrow (0 , left, win_length).copy_ (window);
247278 } else {
279+ window_ = at::zeros ({n_fft}, self.options ());
248280 window_.narrow (0 , left, win_length).fill_ (1 );
249281 }
250282 }
@@ -257,19 +289,40 @@ Tensor stft(const Tensor& self, const int64_t n_fft, const optional<int64_t> hop
257289 if (window_.defined ()) {
258290 input = input.mul (window_);
259291 }
260- // rfft and transpose to get (batch x fft_size x num_frames)
261- auto out = input.rfft (1 , normalized, onesided).transpose_ (1 , 2 );
292+
293+ // FFT and transpose to get (batch x fft_size x num_frames)
294+ const bool complex_fft = input.is_complex ();
295+ const auto onesided = onesidedOpt.value_or (!complex_fft);
296+
297+ Tensor out;
298+ if (complex_fft) {
299+ TORCH_CHECK (!onesided, " Cannot have onesided output if window or input is complex" );
300+ out = at::native::fft (at::view_as_real (input), 1 , normalized);
301+ } else {
302+ out = at::native::rfft (input, 1 , normalized, onesided);
303+ }
304+ out.transpose_ (1 , 2 );
305+
262306 if (self.dim () == 1 ) {
263- return out.squeeze_ (0 );
307+ out.squeeze_ (0 );
308+ }
309+
310+ if (return_complex) {
311+ return at::view_as_complex (out);
264312 } else {
265313 return out;
266314 }
267315}
268316
317+ /* Inverse Short-time Fourier Transform
318+ *
319+ * This is modeled after librosa but with support for complex time-domain
320+ * signals and complex windows.
321+ */
269322Tensor istft (const Tensor& self, const int64_t n_fft, const optional<int64_t > hop_lengthOpt,
270323 const optional<int64_t > win_lengthOpt, const Tensor& window,
271- const bool center, const bool normalized, const bool onesided ,
272- const optional<int64_t > lengthOpt) {
324+ const bool center, const bool normalized, const c10::optional< bool > onesidedOpt ,
325+ const optional<int64_t > lengthOpt, const bool return_complex ) {
273326 #define REPR (SS ) \
274327 SS << " istft(" << self.toString () << self.sizes () << " , n_fft=" << n_fft \
275328 << " , hop_length=" << hop_length << " , win_length=" << win_length \
@@ -279,26 +332,23 @@ Tensor istft(const Tensor& self, const int64_t n_fft, const optional<int64_t> ho
279332 } else { \
280333 SS << " None" ; \
281334 } \
282- SS << " , center=" << center << " , normalized=" << normalized << " , onesided=" << onesided << " , length=" ; \
283- if (lengthOpt.has_value ()) { \
284- SS << lengthOpt.value (); \
285- } else { \
286- SS << " None" ; \
287- } \
288- SS << " )"
335+ SS << " , center=" << center << " , normalized=" << normalized << " , onesided=" ; \
336+ write_opt (SS, onesidedOpt) << " , length=" ; \
337+ write_opt (SS, lengthOpt) << " , return_complex=" << return_complex << " ) "
289338
290339 // default_init hop_length and win_length
291340 const auto hop_length = hop_lengthOpt.value_or (n_fft >> 2 );
292341 const auto win_length = win_lengthOpt.value_or (n_fft);
293342
294- const auto input_dim = self.dim ();
295- const auto n_frames = self.size (-2 );
296- const auto fft_size = self.size (-3 );
343+ Tensor input = self.is_complex () ? at::view_as_real (self) : self;
344+ const auto input_dim = input.dim ();
345+ const auto n_frames = input.size (-2 );
346+ const auto fft_size = input.size (-3 );
297347
298348 const auto expected_output_signal_len = n_fft + hop_length * (n_frames - 1 );
299349
300- const auto options = at::device (self .device ()).dtype (self .dtype ());
301- if (self .numel () == 0 ) {
350+ const auto options = at::device (input .device ()).dtype (input .dtype ());
351+ if (input .numel () == 0 ) {
302352 std::ostringstream ss;
303353 REPR (ss) << " : input tensor cannot be empty." ;
304354 AT_ERROR (ss.str ());
@@ -308,12 +358,13 @@ Tensor istft(const Tensor& self, const int64_t n_fft, const optional<int64_t> ho
308358 REPR (ss) << " : expected a tensor with 3 or 4 dimensions, but got " << input_dim;
309359 AT_ERROR (ss.str ());
310360 }
311- if (self .size (-1 ) != 2 ) {
361+ if (input .size (-1 ) != 2 ) {
312362 std::ostringstream ss;
313363 REPR (ss) << " : expected the last dimension to be 2 (corresponding to real and imaginary parts), but got " << self.size (-1 );
314364 AT_ERROR (ss.str ());
315365 }
316366
367+ const bool onesided = onesidedOpt.value_or (fft_size != n_fft);
317368 if (onesided) {
318369 if (n_fft / 2 + 1 != fft_size) {
319370 std::ostringstream ss;
@@ -355,13 +406,21 @@ Tensor istft(const Tensor& self, const int64_t n_fft, const optional<int64_t> ho
355406 TORCH_INTERNAL_ASSERT (window_tmp.size (0 ) == n_fft);
356407 }
357408
358- Tensor input = self;
359409 if (input_dim == 3 ) {
360410 input = input.unsqueeze (0 );
361411 }
362412
363413 input = input.transpose (1 , 2 ); // size: (channel, n_frames, fft_size, 2)
364- input = at::native::irfft (input, 1 , normalized, onesided, {n_fft, }); // size: (channel, n_frames, n_fft)
414+
415+ if (return_complex) {
416+ TORCH_CHECK (!onesided, " Cannot have onesided output if window or input is complex" );
417+ input = at::native::ifft (input, 1 , normalized); // size: (channel, n_frames, n_fft)
418+ input = at::view_as_complex (input);
419+ } else {
420+ TORCH_CHECK (!window.defined () || !window.is_complex (),
421+ " Complex windows are incompatible with return_complex=False" );
422+ input = at::native::irfft (input, 1 , normalized, onesided, {n_fft,}); // size: (channel, n_frames, n_fft)
423+ }
365424 TORCH_INTERNAL_ASSERT (input.size (2 ) == n_fft);
366425
367426 Tensor y_tmp = input * window_tmp.view ({1 , 1 , n_fft}); // size: (channel, n_frames, n_fft)
@@ -408,4 +467,21 @@ Tensor istft(const Tensor& self, const int64_t n_fft, const optional<int64_t> ho
408467 #undef REPR
409468}
410469
470+ Tensor stft (const Tensor& self, const int64_t n_fft, const optional<int64_t > hop_lengthOpt,
471+ const optional<int64_t > win_lengthOpt, const Tensor& window,
472+ const bool normalized, const optional<bool > onesidedOpt) {
473+ return at::native::stft (
474+ self, n_fft, hop_lengthOpt, win_lengthOpt, window, normalized, onesidedOpt,
475+ /* return_complex=*/ c10::nullopt );
476+ }
477+
478+ Tensor istft (const Tensor& self, const int64_t n_fft, const optional<int64_t > hop_lengthOpt,
479+ const optional<int64_t > win_lengthOpt, const Tensor& window,
480+ const bool center, const bool normalized, const optional<bool > onesidedOpt,
481+ const optional<int64_t > lengthOpt) {
482+ return at::native::istft (
483+ self, n_fft, hop_lengthOpt, win_lengthOpt, window, center, normalized,
484+ onesidedOpt, lengthOpt, /* return_complex=*/ false );
485+ }
486+
411487}} // at::native
0 commit comments