47 changes: 31 additions & 16 deletions aten/src/ATen/native/SpectralOps.cpp
@@ -31,31 +31,40 @@ static inline Tensor _fft(const Tensor &self, const int64_t signal_ndim,
}
if (!at::isFloatingType(self.type().scalarType())) {
std::ostringstream ss;
ss << "Expected an input tensor of floating types, but got input "
ss << "Expected an input tensor of floating types, but got input="
<< self.type() << self.sizes();
throw std::runtime_error(ss.str());

}

auto signal_tensor_ndim = signal_ndim + static_cast<int>(complex_input); // add complex dim
if (self.dim() != signal_tensor_ndim && self.dim() != signal_tensor_ndim + 1) {
if (self.dim() < signal_tensor_ndim) {
std::ostringstream ss;
ss << "Given signal_ndim=" << signal_ndim << ", expected an input tensor "
<< "of " << signal_tensor_ndim << "D or " << signal_tensor_ndim + 1
<< "D (batch mode)";
<< "of at least" << signal_tensor_ndim << "D";
if (complex_input) {
ss << " (complex input adds an extra dimension)";
}
ss << ", but got input " << self.type() << self.sizes();
ss << ", but got input=" << self.type() << self.sizes();
throw std::runtime_error(ss.str());
}
bool is_batched = self.dim() == signal_tensor_ndim + 1;

Tensor input = self;
auto self_shape = self.sizes();
auto batch_ndim = self.dim() - signal_tensor_ndim;

if (!is_batched) {
Tensor input = self;
// flatten the batch dims
if (batch_ndim == 0) {
// slightly faster path for non-batch mode
input = input.unsqueeze(0);
} else if (batch_ndim > 1) {
std::vector<int64_t> flatten_input_shape(signal_tensor_ndim + 1);
std::copy(self_shape.begin() + batch_ndim, self_shape.end(), flatten_input_shape.begin() + 1);
flatten_input_shape[0] = -1;
input = input.reshape(flatten_input_shape);

}
// now we assume that input is batched

// now we assume that input is batched as [ B x signal_dims... ]

if (complex_input) {
if (input.size(signal_ndim + 1) != 2) {
@@ -104,8 +113,8 @@ static inline Tensor _fft(const Tensor &self, const int64_t signal_ndim,
std::ostringstream ss;
ss << "Expected given signal_sizes=" << signal_sizes << " to have same "
<< "shape with input at signal dimension " << i << ", but got "
<< "signal_sizes[" << i << "] = " << signal_sizes[i] << " and "
<< "input.size(" << i + (int)is_batched << ") = " << input_size;
<< "signal_sizes=" << signal_sizes << " and input=" << self.type()
<< self.sizes();
throw std::runtime_error(ss.str());
}
}
@@ -119,12 +128,18 @@ static inline Tensor _fft(const Tensor &self, const int64_t signal_ndim,
checked_signal_sizes, normalized, onesided,
output_sizes);

// un-batch if needed
if (!is_batched) {
return output.squeeze_(0);
} else {
return output;
// unflatten the batch dims
if (batch_ndim == 0) {
// slightly faster path for non-batch mode
output = output.squeeze(0);
} else if (batch_ndim > 1) {
auto output_ndim = self.dim() + static_cast<int>(complex_output) - static_cast<int>(complex_input);
std::vector<int64_t> unflatten_output_shape(output_ndim);
std::copy(self_shape.begin(), self_shape.begin() + batch_ndim, unflatten_output_shape.begin());
std::copy(output_sizes.begin() + 1, output_sizes.end(), unflatten_output_shape.begin() + batch_ndim);
output = output.reshape(unflatten_output_shape);
}
return output;
}

Tensor fft(const Tensor& self, const int64_t signal_ndim, const bool normalized) {
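The change above replaces the old batched/non-batched branch (`is_batched`) with a general flatten/unflatten scheme: any number of leading batch dimensions is collapsed into a single batch dimension before the low-level transform and restored afterwards. A minimal Python sketch of that pattern follows; `apply_batched` and `apply_per_batch` are hypothetical names standing in for the actual C++ `_fft_with_size` dispatch.

import torch

def apply_batched(x, signal_ndim, apply_per_batch):
    # Split the shape into leading batch dims (possibly none) and trailing signal dims.
    batch_shape = x.shape[:x.dim() - signal_ndim]
    signal_shape = x.shape[x.dim() - signal_ndim:]
    # Collapse all batch dims into one: [B, signal_dims...].
    flat = x.reshape(-1, *signal_shape)
    out = apply_per_batch(flat)
    # Restore the original batch dims around whatever signal shape came back;
    # with no batch dims this also drops the singleton batch dim again.
    return out.reshape(*batch_shape, *out.shape[1:])

For example, apply_batched(torch.randn(3, 4, 5, 5), 2, lambda t: t * 2) returns a tensor of shape (3, 4, 5, 5).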
6 changes: 3 additions & 3 deletions test/test_autograd.py
@@ -1929,9 +1929,9 @@ def rfft_irfft(x):
_test_real((2, 3, 4), 2)
_test_real((2, 3, 4, 3), 3)

_test_complex((2, 10, 2), 1)
_test_complex((2, 3, 4, 2), 2)
_test_complex((2, 3, 4, 3, 2), 3)
_test_complex((2, 2, 10, 2), 1)
_test_complex((1, 2, 3, 4, 2), 2)
_test_complex((2, 1, 3, 4, 3, 2), 3)

def test_variable_traverse(self):
def get_out_and_unrefed_cycle():
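These autograd tests now feed inputs with extra leading batch dimensions into the complex-to-complex transform. A minimal sketch of such a gradient check, assuming the pre-1.8 functional API (`Tensor.fft`) that this PR targets together with `torch.autograd.gradcheck`:

import torch
from torch.autograd import gradcheck

# Shape (2, 2, 10, 2): two batch dims, a length-10 1D signal, and a final
# dim of size 2 holding the real/imaginary parts (the old complex-as-last-dim convention).
x = torch.randn(2, 2, 10, 2, dtype=torch.double, requires_grad=True)

# gradcheck compares analytical gradients against finite differences,
# so it needs double precision input.
assert gradcheck(lambda t: t.fft(1), (x,))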
22 changes: 10 additions & 12 deletions test/test_torch.py
@@ -3485,11 +3485,7 @@ def _test_complex(sizes, signal_ndim, prepro_fn=lambda x: x):
def _test_real(sizes, signal_ndim, prepro_fn=lambda x: x):
x = prepro_fn(build_fn(*sizes))
signal_numel = 1
if x.dim() == signal_ndim:
start_dim = 0
else:
start_dim = 1
signal_sizes = x.size()[start_dim:start_dim + signal_ndim]
signal_sizes = x.size()[-signal_ndim:]
for normalized, onesided in product((True, False), repeat=2):
res = x.rfft(signal_ndim, normalized=normalized, onesided=onesided)
if not onesided: # check Hermitian symmetry
@@ -3504,10 +3500,12 @@ def test_one_sample(res, test_num=10):
if len(sizes) == signal_ndim:
test_one_sample(res)
else:
nb = res.size(0)
output_non_batch_shape = res.size()[-(signal_ndim + 1):]
flatten_batch_res = res.view(-1, *output_non_batch_shape)
nb = flatten_batch_res.size(0)
test_idxs = torch.LongTensor(min(nb, 4)).random_(nb)
for test_idx in test_idxs.tolist():
test_one_sample(res[test_idx])
test_one_sample(flatten_batch_res[test_idx])
# compare with C2C
xc = torch.stack([x, torch.zeros_like(x)], -1)
xc_res = xc.fft(signal_ndim, normalized=normalized)
@@ -3523,18 +3521,18 @@ def test_one_sample(res, test_num=10):

# contiguous case
_test_real((100,), 1)
_test_real((100, 100), 1)
_test_real((10, 1, 10, 100), 1)
_test_real((100, 100), 2)
_test_real((20, 80, 60), 2)
_test_real((2, 2, 5, 80, 60), 2)
_test_real((50, 40, 70), 3)
_test_real((30, 50, 25, 20), 3)
_test_real((30, 1, 50, 25, 20), 3)

_test_complex((100, 2), 1)
_test_complex((100, 100, 2), 1)
_test_complex((100, 100, 2), 2)
_test_complex((20, 80, 60, 2), 2)
_test_complex((1, 20, 80, 60, 2), 2)
_test_complex((50, 40, 70, 2), 3)
_test_complex((30, 50, 25, 20, 2), 3)
_test_complex((6, 5, 50, 25, 20, 2), 3)

# non-contiguous case
_test_real((165,), 1, lambda x: x.narrow(0, 25, 100)) # input is not aligned to complex type
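The updated tests flatten the leading batch dimensions of the result before sampling individual signals, and check Hermitian symmetry of the one-sided real-to-complex output. The symmetry property itself can be sketched independently of the (since removed) torch.rfft API using NumPy:

import numpy as np

# For real input, the full C2C spectrum satisfies X[k] == conj(X[-k mod N]),
# so the one-sided rfftn output is just the non-redundant half of fftn.
x = np.random.randn(3, 4, 8)                 # one batch dim, 2D signals of shape (4, 8)
full = np.fft.fftn(x, axes=(-2, -1))         # two-sided, complex-to-complex
half = np.fft.rfftn(x, axes=(-2, -1))        # one-sided along the last axis
assert np.allclose(full[..., : x.shape[-1] // 2 + 1], half)
# Hermitian symmetry over both signal axes: index reversal modulo N equals conjugation.
flipped = np.roll(np.flip(full, axis=(-2, -1)), 1, axis=(-2, -1))
assert np.allclose(full, np.conj(flipped))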
43 changes: 26 additions & 17 deletions torch/_torch_docs.py
@@ -6031,7 +6031,7 @@
Complex-to-complex Discrete Fourier Transform

This method computes the complex-to-complex discrete Fourier transform.
Ignoring the batch dimension, it computes the following expression:
Ignoring the batch dimensions, it computes the following expression:

.. math::
X[\omega_1, \dots, \omega_d] =
@@ -6044,10 +6044,10 @@
This method supports 1D, 2D and 3D complex-to-complex transforms, indicated
by :attr:`signal_ndim`. :attr:`input` must be a tensor with last dimension
of size 2, representing the real and imaginary components of complex
numbers, and should have ``signal_ndim + 1`` dimensions or ``signal_ndim + 2``
dimensions with batched data. If :attr:`normalized` is set to ``True``, this
normalizes the result by dividing it with :math:`\sqrt{\prod_{i=1}^K N_i}` so
that the operator is unitary.
numbers, and should have at least ``signal_ndim + 1`` dimensions with an
arbitrary number of optional leading batch dimensions. If :attr:`normalized` is
set to ``True``, this normalizes the result by dividing it by
:math:`\sqrt{\prod_{i=1}^K N_i}` so that the operator is unitary.

Returns the real and the imaginary parts together as one tensor of the same
shape of :attr:`input`.
@@ -6059,7 +6059,8 @@
:func:`torch.backends.mkl.is_available` to check if MKL is installed.

Arguments:
input (Tensor): the input tensor
input (Tensor): the input tensor of at least :attr:`signal_ndim` ``+ 1``
dimensions
signal_ndim (int): the number of dimensions in each signal.
:attr:`signal_ndim` can only be 1, 2 or 3
normalized (bool, optional): controls whether to return normalized results.
@@ -6119,6 +6120,12 @@
0.2740 1.3332
[torch.FloatTensor of size (4,3,2)]

>>> # arbitrary number of batch dimensions, 2D FFT
>>> x = torch.randn(3, 3, 5, 5, 2)
>>> y = torch.fft(x, 2)
>>> y.shape
torch.Size([3, 3, 5, 5, 2])

""")

add_docstr(torch.ifft,
@@ -6128,7 +6135,7 @@
Complex-to-complex Inverse Discrete Fourier Transform

This method computes the complex-to-complex inverse discrete Fourier
transform. Ignoring the batch dimension, it computes the following
transform. Ignoring the batch dimensions, it computes the following
expression:

.. math::
@@ -6155,7 +6162,8 @@
:func:`torch.backends.mkl.is_available` to check if MKL is installed.

Arguments:
input (Tensor): the input tensor
input (Tensor): the input tensor of at least :attr:`signal_ndim` ``+ 1``
dimensions
signal_ndim (int): the number of dimensions in each signal.
:attr:`signal_ndim` can only be 1, 2 or 3
normalized (bool, optional): controls whether to return normalized results.
@@ -6217,11 +6225,11 @@
formats of the input and output.

This method supports 1D, 2D and 3D real-to-complex transforms, indicated
by :attr:`signal_ndim`. :attr:`input` must be a tensor with ``signal_ndim``
dimensions or ``signal_ndim + 1`` dimensions with batched data. If
:attr:`normalized` is set to ``True``, this normalizes the result by multiplying
it with :math:`\sqrt{\prod_{i=1}^K N_i}` so that the operator is unitary, where
:math:`N_i` is the size of signal dimension :math:`i`.
by :attr:`signal_ndim`. :attr:`input` must be a tensor with at least
``signal_ndim`` dimensions and an arbitrary number of optional leading batch
dimensions. If :attr:`normalized` is set to ``True``, this normalizes the result
by multiplying it by :math:`\sqrt{\prod_{i=1}^K N_i}` so that the operator is
unitary, where :math:`N_i` is the size of signal dimension :math:`i`.

The real-to-complex Fourier transform results follow conjugate symmetry:

@@ -6243,7 +6251,7 @@
:func:`torch.backends.mkl.is_available` to check if MKL is installed.

Arguments:
input (Tensor): the input tensor
input (Tensor): the input tensor of at least :attr:`signal_ndim` dimensions
signal_ndim (int): the number of dimensions in each signal.
:attr:`signal_ndim` can only be 1, 2 or 3
normalized (bool, optional): controls whether to return normalized results.
@@ -6287,8 +6295,8 @@
``rfft(signal, onesided=True)``. In such a case, set the :attr:`onesided`
argument of this method to ``True``. Moreover, since the original signal shape
information can sometimes be lost, optionally set :attr:`signal_sizes` to be
the size of the original signal (without batch dimension if in batched mode) to
recover it with correct shape.
the size of the original signal (without the batch dimensions if in batched
mode) to recover it with the correct shape.

Therefore, to invert an :func:`~torch.rfft`, the :attr:`normalized` and
:attr:`onesided` arguments should be set identically for :func:`~torch.irfft`,
@@ -6313,7 +6321,8 @@
:func:`torch.backends.mkl.is_available` to check if MKL is installed.

Arguments:
input (Tensor): the input tensor
input (Tensor): the input tensor of at least :attr:`signal_ndim` ``+ 1``
dimensions
signal_ndim (int): the number of dimensions in each signal.
:attr:`signal_ndim` can only be 1, 2 or 3
normalized (bool, optional): controls whether to return normalized results.
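To make the round trip described in this docstring concrete, here is a minimal sketch assuming the pre-1.8 torch.rfft/torch.irfft functional API documented in this PR; :attr:`signal_sizes` pins down the original signal shape, which a one-sided spectrum alone leaves ambiguous.

import torch

x = torch.randn(3, 4, 8, 9)   # two batch dims, 2D signals of shape (8, 9)
X = torch.rfft(x, 2, normalized=True, onesided=True)
# Matching normalized/onesided flags plus signal_sizes recover the input exactly
# (up to floating-point error).
x_rec = torch.irfft(X, 2, normalized=True, onesided=True, signal_sizes=x.shape[-2:])
assert (x - x_rec).abs().max().item() < 1e-5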