Commit 77bd4d3

mattip authored and facebook-github-bot committed
MAINT: speed up istft by using col2im (the original python code used … (#42826)
Summary:
Fixes #42213

The [original python code](https://github.com/pytorch/audio/blob/v0.5.0/torchaudio/functional.py#L178) from `torchaudio` was converted to a native function, but used `eye` to allocate a Tensor and was much slower. Using `at::col2im` (which is the equivalent of `torch.nn.functional.fold`) solved the slowdown.

Pull Request resolved: #42826

Reviewed By: smessmer

Differential Revision: D23043673

Pulled By: mthrok

fbshipit-source-id: 3f5d0779a87379b002340ea19c9ae5042a43e94e
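For readers unfamiliar with `fold`/`col2im`: with a kernel of `{1, n_fft}` and a stride of `{1, hop_length}`, the operation amounts to an overlap-add of the windowed frames, which is also what a transposed convolution with an identity kernel produces. The following is a minimal standard-library sketch of that overlap-add, not the ATen implementation. The helper name `overlap_add` and the `frames[t][k]` layout are assumptions made for illustration (the commit's `y_tmp` is laid out as `(channel, n_fft, n_frames)`).

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper (not part of ATen): overlap-add of equally sized frames.
// frames[t][k] is sample k of frame t; consecutive frames start `hop` samples apart.
// The output length follows the commit's formula (n_frames - 1) * hop_length + n_fft.
std::vector<double> overlap_add(const std::vector<std::vector<double>>& frames,
                                std::size_t hop) {
  if (frames.empty()) return {};
  const std::size_t n_frames = frames.size();
  const std::size_t n_fft = frames[0].size();
  std::vector<double> out((n_frames - 1) * hop + n_fft, 0.0);
  for (std::size_t t = 0; t < n_frames; ++t) {
    assert(frames[t].size() == n_fft);  // every frame must have n_fft samples
    for (std::size_t k = 0; k < n_fft; ++k) {
      // Sample k of frame t lands at position t * hop + k; overlaps are summed.
      out[t * hop + k] += frames[t][k];
    }
  }
  return out;
}
```

Calling `overlap_add` on three frames of length 4 with a hop of 2 yields 8 samples, with the overlapping halves of neighbouring frames summed. This does the same arithmetic as the identity-kernel transposed convolution without materialising an `n_fft` by `n_fft` matrix.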
1 parent 4665f3f commit 77bd4d3

1 file changed: +15 -9 lines changed

aten/src/ATen/native/SpectralOps.cpp

Lines changed: 15 additions & 9 deletions
@@ -367,16 +367,22 @@ Tensor istft(const Tensor& self, const int64_t n_fft, const optional<int64_t> ho
   Tensor y_tmp = input * window_tmp.view({1, 1, n_fft});  // size: (channel, n_frames, n_fft)
   y_tmp = y_tmp.transpose(1, 2);  // size: (channel, n_fft, frame)
 
-  const Tensor eye = at::native::eye(n_fft, options).unsqueeze(1);
-  Tensor y = at::conv_transpose1d(y_tmp, eye,
-                                  /*bias*/ Tensor(),
-                                  /*stride*/ {hop_length,},
-                                  /*padding*/{0,});  // size: (channel, n_frames, n_fft)
+  Tensor y = at::col2im(y_tmp,
+                        /*output_size*/ {1, (n_frames - 1) * hop_length + n_fft},
+                        /*kernel_size*/ {1, n_fft},
+                        /*dilation*/    {1, 1},
+                        /*padding*/     {0, 0},
+                        /*stride*/      {1, hop_length}
+                       ).squeeze(2);
   window_tmp = window_tmp.pow(2).view({n_fft, 1}).repeat({1, n_frames}).unsqueeze(0);  // size: (1, n_fft, n_frames)
-  Tensor window_envelop = at::conv_transpose1d(window_tmp, eye,
-                                               /*bias*/ Tensor(),
-                                               /*stride*/ {hop_length, },
-                                               /*padding*/{0, });  // size: (1, 1, expected_output_signal_len)
+  Tensor window_envelop = at::col2im(window_tmp,
+                                     /*output_size*/ {1, (n_frames - 1) * hop_length + n_fft},
+                                     /*kernel_size*/ {1, n_fft},
+                                     /*dilation*/    {1, 1},
+                                     /*padding*/     {0, 0},
+                                     /*stride*/      {1, hop_length}
+                                    ).squeeze(2);  // size: (1, 1, expected_output_signal_len)
+
   TORCH_INTERNAL_ASSERT(expected_output_signal_len == y.size(2));
   TORCH_INTERNAL_ASSERT(expected_output_signal_len == window_envelop.size(2));
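The second `col2im` call in the diff applies the same folding to the squared window, repeated once per frame, to produce `window_envelop`. As a rough sketch of what that tensor contains (assuming a single channel and plain `std::vector` storage; the helper name `window_envelope` is hypothetical and not part of ATen), each output sample accumulates the squared window values of every frame that overlaps it:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper (not part of ATen): per-sample sum of squared window
// values over all frames covering that sample. The same window is used for
// every frame, mirroring window_tmp.pow(2)...repeat({1, n_frames}) in the diff.
std::vector<double> window_envelope(const std::vector<double>& window,
                                    std::size_t n_frames,
                                    std::size_t hop) {
  if (n_frames == 0) return {};
  const std::size_t n_fft = window.size();
  // Same length as the overlap-added signal: (n_frames - 1) * hop + n_fft.
  std::vector<double> env((n_frames - 1) * hop + n_fft, 0.0);
  for (std::size_t t = 0; t < n_frames; ++t) {
    for (std::size_t k = 0; k < n_fft; ++k) {
      env[t * hop + k] += window[k] * window[k];
    }
  }
  return env;
}
```

The two `TORCH_INTERNAL_ASSERT` lines in the diff check exactly this length relationship: both the folded signal and the folded envelope must have `expected_output_signal_len` samples.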
