Skip to content

Commit b0bd00b

Browse files
committed
Support arbitrary number of batch dimensions in *FFT
1 parent 6b7ec95 commit b0bd00b

File tree

4 files changed

+70
-48
lines changed

4 files changed

+70
-48
lines changed

aten/src/ATen/native/SpectralOps.cpp

Lines changed: 31 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -31,31 +31,40 @@ static inline Tensor _fft(const Tensor &self, const int64_t signal_ndim,
3131
}
3232
if (!at::isFloatingType(self.type().scalarType())) {
3333
std::ostringstream ss;
34-
ss << "Expected an input tensor of floating types, but got input "
34+
ss << "Expected an input tensor of floating types, but got input="
3535
<< self.type() << self.sizes();
3636
throw std::runtime_error(ss.str());
3737
}
3838

3939
auto signal_tensor_ndim = signal_ndim + static_cast<int>(complex_input); // add complex dim
40-
if (self.dim() != signal_tensor_ndim && self.dim() != signal_tensor_ndim + 1) {
40+
if (self.dim() < signal_tensor_ndim) {
4141
std::ostringstream ss;
4242
ss << "Given signal_ndim=" << signal_ndim << ", expected an input tensor "
43-
<< "of " << signal_tensor_ndim << "D or " << signal_tensor_ndim + 1
44-
<< "D (batch mode)";
43+
<< "of at least " << signal_tensor_ndim << "D";
4544
if (complex_input) {
4645
ss << " (complex input adds an extra dimension)";
4746
}
48-
ss << ", but got input " << self.type() << self.sizes();
47+
ss << ", but got input=" << self.type() << self.sizes();
4948
throw std::runtime_error(ss.str());
5049
}
51-
bool is_batched = self.dim() == signal_tensor_ndim + 1;
5250

53-
Tensor input = self;
51+
auto self_shape = self.sizes();
52+
auto batch_ndim = self.dim() - signal_tensor_ndim;
5453

55-
if (!is_batched) {
54+
Tensor input = self;
55+
// flatten the batch dims
56+
if (batch_ndim == 0) {
57+
// slightly faster path for non-batch mode
5658
input = input.unsqueeze(0);
59+
} else if (batch_ndim > 1) {
60+
std::vector<int64_t> flatten_input_shape(signal_tensor_ndim + 1);
61+
std::copy(self_shape.begin() + batch_ndim, self_shape.end(), flatten_input_shape.begin() + 1);
62+
flatten_input_shape[0] = -1;
63+
input = input.reshape(flatten_input_shape);
64+
5765
}
58-
// now we assume that input is batched
66+
67+
// now we assume that input is batched as [ B x signal_dims... ]
5968

6069
if (complex_input) {
6170
if (input.size(signal_ndim + 1) != 2) {
@@ -104,8 +113,8 @@ static inline Tensor _fft(const Tensor &self, const int64_t signal_ndim,
104113
std::ostringstream ss;
105114
ss << "Expected given signal_sizes=" << signal_sizes << " to have same "
106115
<< "shape with input at signal dimension " << i << ", but got "
107-
<< "signal_sizes[" << i << "] = " << signal_sizes[i] << " and "
108-
<< "input.size(" << i + (int)is_batched << ") = " << input_size;
116+
<< "signal_sizes=" << signal_sizes << " and input=" << self.type()
117+
<< self.sizes();
109118
throw std::runtime_error(ss.str());
110119
}
111120
}
@@ -119,12 +128,18 @@ static inline Tensor _fft(const Tensor &self, const int64_t signal_ndim,
119128
checked_signal_sizes, normalized, onesided,
120129
output_sizes);
121130

122-
// un-batch if needed
123-
if (!is_batched) {
124-
return output.squeeze_(0);
125-
} else {
126-
return output;
131+
// unflatten the batch dims
132+
if (batch_ndim == 0) {
133+
// slightly faster path for non-batch mode
134+
output = output.squeeze(0);
135+
} else if (batch_ndim > 1) {
136+
auto output_ndim = self.dim() + static_cast<int>(complex_output) - static_cast<int>(complex_input);
137+
std::vector<int64_t> unflatten_output_shape(output_ndim);
138+
std::copy(self_shape.begin(), self_shape.begin() + batch_ndim, unflatten_output_shape.begin());
139+
std::copy(output_sizes.begin() + 1, output_sizes.end(), unflatten_output_shape.begin() + batch_ndim);
140+
output = output.reshape(unflatten_output_shape);
127141
}
142+
return output;
128143
}
129144

130145
Tensor fft(const Tensor& self, const int64_t signal_ndim, const bool normalized) {

test/test_autograd.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1929,9 +1929,9 @@ def rfft_irfft(x):
19291929
_test_real((2, 3, 4), 2)
19301930
_test_real((2, 3, 4, 3), 3)
19311931

1932-
_test_complex((2, 10, 2), 1)
1933-
_test_complex((2, 3, 4, 2), 2)
1934-
_test_complex((2, 3, 4, 3, 2), 3)
1932+
_test_complex((2, 2, 10, 2), 1)
1933+
_test_complex((1, 2, 3, 4, 2), 2)
1934+
_test_complex((2, 1, 3, 4, 3, 2), 3)
19351935

19361936
def test_variable_traverse(self):
19371937
def get_out_and_unrefed_cycle():

test/test_torch.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3485,11 +3485,7 @@ def _test_complex(sizes, signal_ndim, prepro_fn=lambda x: x):
34853485
def _test_real(sizes, signal_ndim, prepro_fn=lambda x: x):
34863486
x = prepro_fn(build_fn(*sizes))
34873487
signal_numel = 1
3488-
if x.dim() == signal_ndim:
3489-
start_dim = 0
3490-
else:
3491-
start_dim = 1
3492-
signal_sizes = x.size()[start_dim:start_dim + signal_ndim]
3488+
signal_sizes = x.size()[-signal_ndim:]
34933489
for normalized, onesided in product((True, False), repeat=2):
34943490
res = x.rfft(signal_ndim, normalized=normalized, onesided=onesided)
34953491
if not onesided: # check Hermitian symmetry
@@ -3504,10 +3500,12 @@ def test_one_sample(res, test_num=10):
35043500
if len(sizes) == signal_ndim:
35053501
test_one_sample(res)
35063502
else:
3507-
nb = res.size(0)
3503+
output_non_batch_shape = res.size()[-(signal_ndim + 1):]
3504+
flatten_batch_res = res.view(-1, *output_non_batch_shape)
3505+
nb = flatten_batch_res.size(0)
35083506
test_idxs = torch.LongTensor(min(nb, 4)).random_(nb)
35093507
for test_idx in test_idxs.tolist():
3510-
test_one_sample(res[test_idx])
3508+
test_one_sample(flatten_batch_res[test_idx])
35113509
# compare with C2C
35123510
xc = torch.stack([x, torch.zeros_like(x)], -1)
35133511
xc_res = xc.fft(signal_ndim, normalized=normalized)
@@ -3523,18 +3521,18 @@ def test_one_sample(res, test_num=10):
35233521

35243522
# contiguous case
35253523
_test_real((100,), 1)
3526-
_test_real((100, 100), 1)
3524+
_test_real((10, 1, 10, 100), 1)
35273525
_test_real((100, 100), 2)
3528-
_test_real((20, 80, 60), 2)
3526+
_test_real((2, 2, 5, 80, 60), 2)
35293527
_test_real((50, 40, 70), 3)
3530-
_test_real((30, 50, 25, 20), 3)
3528+
_test_real((30, 1, 50, 25, 20), 3)
35313529

35323530
_test_complex((100, 2), 1)
35333531
_test_complex((100, 100, 2), 1)
35343532
_test_complex((100, 100, 2), 2)
3535-
_test_complex((20, 80, 60, 2), 2)
3533+
_test_complex((1, 20, 80, 60, 2), 2)
35363534
_test_complex((50, 40, 70, 2), 3)
3537-
_test_complex((30, 50, 25, 20, 2), 3)
3535+
_test_complex((6, 5, 50, 25, 20, 2), 3)
35383536

35393537
# non-contiguous case
35403538
_test_real((165,), 1, lambda x: x.narrow(0, 25, 100)) # input is not aligned to complex type

torch/_torch_docs.py

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -6031,7 +6031,7 @@
60316031
Complex-to-complex Discrete Fourier Transform
60326032
60336033
This method computes the complex-to-complex discrete Fourier transform.
6034-
Ignoring the batch dimension, it computes the following expression:
6034+
Ignoring the batch dimensions, it computes the following expression:
60356035
60366036
.. math::
60376037
X[\omega_1, \dots, \omega_d] =
@@ -6044,10 +6044,10 @@
60446044
This method supports 1D, 2D and 3D complex-to-complex transforms, indicated
60456045
by :attr:`signal_ndim`. :attr:`input` must be a tensor with last dimension
60466046
of size 2, representing the real and imaginary components of complex
6047-
numbers, and should have ``signal_ndim + 1`` dimensions or ``signal_ndim + 2``
6048-
dimensions with batched data. If :attr:`normalized` is set to ``True``, this
6049-
normalizes the result by dividing it with :math:`\sqrt{\prod_{i=1}^K N_i}` so
6050-
that the operator is unitary.
6047+
numbers, and should have at least ``signal_ndim + 1`` dimensions, optionally with an
6048+
arbitrary number of leading batch dimensions. If :attr:`normalized` is set to
6049+
``True``, this normalizes the result by dividing it with
6050+
:math:`\sqrt{\prod_{i=1}^K N_i}` so that the operator is unitary.
60516051
60526052
Returns the real and the imaginary parts together as one tensor of the same
60536053
shape as :attr:`input`.
@@ -6059,7 +6059,8 @@
60596059
:func:`torch.backends.mkl.is_available` to check if MKL is installed.
60606060
60616061
Arguments:
6062-
input (Tensor): the input tensor
6062+
input (Tensor): the input tensor of at least :attr:`signal_ndim` ``+ 1``
6063+
dimensions
60636064
signal_ndim (int): the number of dimensions in each signal.
60646065
:attr:`signal_ndim` can only be 1, 2 or 3
60656066
normalized (bool, optional): controls whether to return normalized results.
@@ -6119,6 +6120,12 @@
61196120
0.2740 1.3332
61206121
[torch.FloatTensor of size (4,3,2)]
61216122
6123+
>>> # arbitrary number of batch dimensions, 2D FFT
6124+
>>> x = torch.randn(3, 3, 5, 5, 2)
6125+
>>> y = torch.fft(x, 2)
6126+
>>> y.shape
6127+
torch.Size([3, 3, 5, 5, 2])
6128+
61226129
""")
61236130

61246131
add_docstr(torch.ifft,
@@ -6128,7 +6135,7 @@
61286135
Complex-to-complex Inverse Discrete Fourier Transform
61296136
61306137
This method computes the complex-to-complex inverse discrete Fourier
6131-
transform. Ignoring the batch dimension, it computes the following
6138+
transform. Ignoring the batch dimensions, it computes the following
61326139
expression:
61336140
61346141
.. math::
@@ -6155,7 +6162,8 @@
61556162
:func:`torch.backends.mkl.is_available` to check if MKL is installed.
61566163
61576164
Arguments:
6158-
input (Tensor): the input tensor
6165+
input (Tensor): the input tensor of at least :attr:`signal_ndim` ``+ 1``
6166+
dimensions
61596167
signal_ndim (int): the number of dimensions in each signal.
61606168
:attr:`signal_ndim` can only be 1, 2 or 3
61616169
normalized (bool, optional): controls whether to return normalized results.
@@ -6217,11 +6225,11 @@
62176225
formats of the input and output.
62186226
62196227
This method supports 1D, 2D and 3D real-to-complex transforms, indicated
6220-
by :attr:`signal_ndim`. :attr:`input` must be a tensor with ``signal_ndim``
6221-
dimensions or ``signal_ndim + 1`` dimensions with batched data. If
6222-
:attr:`normalized` is set to ``True``, this normalizes the result by multiplying
6223-
it with :math:`\sqrt{\prod_{i=1}^K N_i}` so that the operator is unitary, where
6224-
:math:`N_i` is the size of signal dimension :math:`i`.
6228+
by :attr:`signal_ndim`. :attr:`input` must be a tensor with at least
6229+
``signal_ndim`` dimensions, optionally with an arbitrary number of leading batch
6230+
dimensions. If :attr:`normalized` is set to ``True``, this normalizes the result
6231+
by dividing it by :math:`\sqrt{\prod_{i=1}^K N_i}` so that the operator is
6232+
unitary, where :math:`N_i` is the size of signal dimension :math:`i`.
62256233
62266234
The real-to-complex Fourier transform results follow conjugate symmetry:
62276235
@@ -6243,7 +6251,7 @@
62436251
:func:`torch.backends.mkl.is_available` to check if MKL is installed.
62446252
62456253
Arguments:
6246-
input (Tensor): the input tensor
6254+
input (Tensor): the input tensor of at least :attr:`signal_ndim` dimensions
62476255
signal_ndim (int): the number of dimensions in each signal.
62486256
:attr:`signal_ndim` can only be 1, 2 or 3
62496257
normalized (bool, optional): controls whether to return normalized results.
@@ -6287,8 +6295,8 @@
62876295
``rfft(signal, onesided=True)``. In such case, set the :attr:`onesided`
62886296
argument of this method to ``True``. Moreover, the original signal shape
62896297
information can sometimes be lost; optionally set :attr:`signal_sizes` to be
6290-
the size of the original signal (without batch dimension if in batched mode) to
6291-
recover it with correct shape.
6298+
the size of the original signal (without the batch dimensions if in batched
6299+
mode) to recover it with correct shape.
62926300
62936301
Therefore, to invert an :func:`~torch.rfft`, the :attr:`normalized` and
62946302
:attr:`onesided` arguments should be set identically for :func:`~torch.irfft`,
@@ -6313,7 +6321,8 @@
63136321
:func:`torch.backends.mkl.is_available` to check if MKL is installed.
63146322
63156323
Arguments:
6316-
input (Tensor): the input tensor
6324+
input (Tensor): the input tensor of at least :attr:`signal_ndim` ``+ 1``
6325+
dimensions
63176326
signal_ndim (int): the number of dimensions in each signal.
63186327
:attr:`signal_ndim` can only be 1, 2 or 3
63196328
normalized (bool, optional): controls whether to return normalized results.

0 commit comments

Comments
 (0)