Skip to content

Commit 1e2e1a6

Browse files
committed
Address review
1 parent e030843 commit 1e2e1a6

File tree

2 files changed

+60
-91
lines changed

2 files changed

+60
-91
lines changed

torch/signal/windows/windows.py

Lines changed: 54 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,15 @@
1414
window_common_args = merge_dicts(
1515
parse_kwargs(
1616
"""
17-
M (int): the length of the output window.
17+
M (int): the length of the window.
1818
In other words, the number of points of the exponential window.
1919
sym (bool, optional): If `False`, returns a periodic window suitable for use in spectral analysis.
2020
If `True`, returns a symmetric window suitable for use in filter design. Default: `True`.
2121
"""
2222
),
23-
factory_common_args
23+
factory_common_args,
24+
{"normalization": "The window is normalized to 1 (maximum value is 1). However, the 1 doesn't appear if "
25+
"`M` is even and `sym` is `True`."}
2426
)
2527

2628

@@ -37,9 +39,7 @@ def _add_docstr(*args):
3739
"""
3840

3941
def decorator(o):
40-
o.__doc__ = ""
41-
for arg in args:
42-
o.__doc__ += arg
42+
o.__doc__ = "".join(args)
4343
return o
4444

4545
return decorator
@@ -71,48 +71,45 @@ def _window_function_checks(function_name: str, M: int, dtype: torch.dtype, layo
7171
The exponential window is defined as follows:
7272
7373
.. math::
74-
w(n) = \exp{\left(-\frac{|n - center|}{\tau}\right)}
74+
w(n) = \exp{\left(-\frac{|n - c|}{\tau}\right)}
75+
76+
Where `c` is the center of the window.
7577
""",
7678
r"""
7779
80+
{normalization}
81+
7882
Args:
7983
{M}
84+
85+
Keyword args:
8086
center (float, optional): where the center of the window will be located.
81-
In other words, at which sample the peak of the window can be found.
8287
Default: `M / 2` if `sym` is `False`, else `(M - 1) / 2`.
8388
tau (float, optional): the decay value. Default: 1.0.
8489
{sym}
85-
86-
.. note::
87-
The window is normalized to 1 (maximum value is 1), however, the 1 doesn't appear if `M` is even
88-
and `sym` is `True`.
89-
90-
Keyword args:
9190
{dtype}
9291
{layout}
9392
{device}
9493
{requires_grad}
9594
96-
Examples:
95+
Examples::
9796
>>> # Generate an exponential window without keyword args.
9897
>>> torch.signal.windows.exponential(10)
99-
tensor([0.0111, 0.0302, 0.0821, 0.2231, 0.6065, 0.6065, 0.2231, 0.0821, 0.0302,
100-
0.0111])
98+
tensor([0.0111, 0.0302, 0.0821, 0.2231, 0.6065, 0.6065, 0.2231, 0.0821, 0.0302, 0.0111])
10199
102100
>>> # Generate a periodic exponential window and decay factor equal to .5
103101
>>> torch.signal.windows.exponential(10,sym=False,tau=.5)
104-
tensor([4.5400e-05, 3.3546e-04, 2.4788e-03, 1.8316e-02, 1.3534e-01, 1.0000e+00,
105-
1.3534e-01, 1.8316e-02, 2.4788e-03, 3.3546e-04])
102+
tensor([4.5400e-05, 3.3546e-04, 2.4788e-03, 1.8316e-02, 1.3534e-01, 1.0000e+00, 1.3534e-01, 1.8316e-02, 2.4788e-03, 3.3546e-04])
106103
""".format(
107104
**window_common_args
108105
),
109106
)
110107
def exponential(
111108
M: int,
109+
*,
112110
center: float = None,
113111
tau: float = 1.0,
114112
sym: bool = True,
115-
*,
116113
dtype: torch.dtype = None,
117114
layout: torch.layout = torch.strided,
118115
device: torch.device = None,
@@ -126,17 +123,14 @@ def exponential(
126123
if M == 0:
127124
return torch.empty((0,), dtype=dtype, layout=layout, device=device, requires_grad=requires_grad)
128125

129-
if M == 1:
130-
return torch.ones((1,), dtype=dtype, layout=layout, device=device, requires_grad=requires_grad)
131-
132126
if tau <= 0:
133127
raise ValueError(f'Tau must be positive, got: {tau} instead.')
134128

135129
if not sym and center is not None:
136-
raise ValueError('Center must be \'None\' for non-symmetric windows')
130+
raise ValueError('Center must be None for non-symmetric windows')
137131

138132
if center is None:
139-
center = -(M if not sym else M - 1) / 2.0
133+
center = -(M if not sym and M > 1 else M - 1) / 2.0
140134

141135
constant = 1 / tau
142136

@@ -164,46 +158,42 @@ def exponential(
164158
165159
.. math::
166160
w(n) = \cos{\left(\frac{\pi n}{M} - \frac{\pi}{2}\right)} = \sin{\left(\frac{\pi n}{M}\right)}
167-
168-
Where `M` is the length of the window.
169161
""",
170162
r"""
171163
164+
{normalization}
165+
172166
Args:
173167
{M}
174-
{sym}
175-
176-
.. note::
177-
The window is normalized to 1 (maximum value is 1), however, the 1 doesn't appear if `M` is even
178-
and `sym` is `True`.
179168
180169
Keyword args:
170+
{sym}
181171
{dtype}
182172
{layout}
183173
{device}
184174
{requires_grad}
185175
186-
Examples:
176+
Examples::
187177
>>> # Generate a cosine window without keyword args.
188178
>>> torch.signal.windows.cosine(10)
189-
tensor([0.1564, 0.4540, 0.7071, 0.8910, 0.9877, 0.9877, 0.8910, 0.7071, 0.4540,
190-
0.1564])
179+
tensor([0.1564, 0.4540, 0.7071, 0.8910, 0.9877, 0.9877, 0.8910, 0.7071, 0.4540, 0.1564])
191180
192181
>>> # Generate a periodic cosine window.
193182
>>> torch.signal.windows.cosine(10,sym=False)
194-
tensor([0.1423, 0.4154, 0.6549, 0.8413, 0.9595, 1.0000, 0.9595, 0.8413, 0.6549,
195-
0.4154])
183+
tensor([0.1423, 0.4154, 0.6549, 0.8413, 0.9595, 1.0000, 0.9595, 0.8413, 0.6549, 0.4154])
196184
""".format(
197185
**window_common_args,
198186
),
199187
)
200-
def cosine(M: int,
201-
sym: bool = True,
202-
*,
203-
dtype: torch.dtype = None,
204-
layout: torch.layout = torch.strided,
205-
device: torch.device = None,
206-
requires_grad: bool = False) -> Tensor:
188+
def cosine(
189+
M: int,
190+
*,
191+
sym: bool = True,
192+
dtype: torch.dtype = None,
193+
layout: torch.layout = torch.strided,
194+
device: torch.device = None,
195+
requires_grad: bool = False
196+
) -> Tensor:
207197
if dtype is None:
208198
dtype = torch.get_default_dtype()
209199

@@ -212,11 +202,8 @@ def cosine(M: int,
212202
if M == 0:
213203
return torch.empty((0,), dtype=dtype, layout=layout, device=device, requires_grad=requires_grad)
214204

215-
if M == 1:
216-
return torch.ones((1,), dtype=dtype, layout=layout, device=device, requires_grad=requires_grad)
217-
218205
start = 0.5
219-
constant = torch.pi / (M + 1 if not sym else M)
206+
constant = torch.pi / (M + 1 if not sym and M > 1 else M)
220207

221208
"""
222209
Note that non-integer step is subject to floating point rounding errors when comparing against end;
@@ -241,47 +228,45 @@ def cosine(M: int,
241228
242229
.. math::
243230
w(n) = \exp{\left(-\left(\frac{n}{2\sigma}\right)^2\right)}
244-
""",
245-
r"""
246231
232+
{normalization}
233+
""",
234+
r"""
235+
247236
Args:
248237
{M}
238+
239+
Keyword args:
249240
std (float, optional): the standard deviation of the gaussian. It controls how narrow or wide the window is.
250241
Default: 1.0.
251242
{sym}
252-
253-
.. note::
254-
The window is normalized to 1 (maximum value is 1), however, the 1 doesn't appear if `M` is even
255-
and `sym` is `True`.
256-
257-
Keyword args:
258243
{dtype}
259244
{layout}
260245
{device}
261246
{requires_grad}
262247
263-
Examples:
248+
Examples::
264249
>>> # Generate a gaussian window without keyword args.
265250
>>> torch.signal.windows.gaussian(10)
266-
tensor([4.0065e-05, 2.1875e-03, 4.3937e-02, 3.2465e-01, 8.8250e-01, 8.8250e-01,
267-
3.2465e-01, 4.3937e-02, 2.1875e-03, 4.0065e-05])
251+
tensor([4.0065e-05, 2.1875e-03, 4.3937e-02, 3.2465e-01, 8.8250e-01, 8.8250e-01, 3.2465e-01, 4.3937e-02, 2.1875e-03, 4.0065e-05])
268252
269253
>>> # Generate a periodic gaussian window and standard deviation equal to 0.9.
270254
>>> torch.signal.windows.gaussian(10,sym=False,std=0.9)
271-
tensor([1.9858e-07, 5.1365e-05, 3.8659e-03, 8.4658e-02, 5.3941e-01, 1.0000e+00,
272-
5.3941e-01, 8.4658e-02, 3.8659e-03, 5.1365e-05])
255+
tensor([1.9858e-07, 5.1365e-05, 3.8659e-03, 8.4658e-02, 5.3941e-01, 1.0000e+00, 5.3941e-01, 8.4658e-02, 3.8659e-03, 5.1365e-05])
273256
""".format(
274257
**window_common_args,
275258
),
276259
)
277-
def gaussian(M: int,
278-
std: float = 1.0,
279-
sym: bool = True,
280-
*,
281-
dtype: torch.dtype = None,
282-
layout: torch.layout = torch.strided,
283-
device: torch.device = None,
284-
requires_grad: bool = False) -> Tensor:
260+
def gaussian(
261+
M: int,
262+
*,
263+
std: float = 1.0,
264+
sym: bool = True,
265+
dtype: torch.dtype = None,
266+
layout: torch.layout = torch.strided,
267+
device: torch.device = None,
268+
requires_grad: bool = False
269+
) -> Tensor:
285270
if dtype is None:
286271
dtype = torch.get_default_dtype()
287272

@@ -290,13 +275,10 @@ def gaussian(M: int,
290275
if M == 0:
291276
return torch.empty((0,), dtype=dtype, layout=layout, device=device, requires_grad=requires_grad)
292277

293-
if M == 1:
294-
return torch.ones((1,), dtype=dtype, layout=layout, device=device, requires_grad=requires_grad)
295-
296278
if std <= 0:
297279
raise ValueError(f'Standard deviation must be positive, got: {std} instead.')
298280

299-
start = -(M if not sym else M - 1) / 2.0
281+
start = -(M if not sym and M > 1 else M - 1) / 2.0
300282

301283
constant = 1 / (std * sqrt(2))
302284

torch/testing/_internal/opinfo/definitions/signal.py

Lines changed: 6 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -26,21 +26,8 @@ def sample_inputs_window(op_info, device, dtype, requires_grad, *args, **kwargs)
2626
additional keyword arguments.
2727
"""
2828

29-
# Test a window size of length zero and one.
30-
# If it's either symmetric or not doesn't matter in these sample inputs.
31-
for size in range(2):
32-
yield SampleInput(
33-
size,
34-
*args,
35-
device=device,
36-
dtype=dtype,
37-
requires_grad=requires_grad,
38-
**kwargs,
39-
)
40-
41-
# For sizes larger than 1 we need to test both symmetric and non-symmetric windows.
42-
# Note: sample input tensors must be kept rather small.
43-
for size, sym in product(list(range(2, 6)), (True, False)):
29+
# Tests window sizes up to 5 samples.
30+
for size, sym in product(range(6), (True, False)):
4431
yield SampleInput(
4532
size,
4633
*args,
@@ -97,7 +84,7 @@ def error_inputs_exponential_window(op_info, device, **kwargs):
9784
yield ErrorInput(
9885
SampleInput(3, center=1, sym=False, dtype=torch.float32, device=device),
9986
error_type=ValueError,
100-
error_regex="Center must be 'None' for non-symmetric windows",
87+
error_regex="Center must be None for non-symmetric windows",
10188
)
10289

10390

@@ -203,19 +190,19 @@ def make_signal_windows_opinfo(
203190
op_db: List[OpInfo] = [
204191
make_signal_windows_opinfo(
205192
name="signal.windows.cosine",
206-
ref=make_signal_windows_ref(scipy.signal.windows.cosine),
193+
ref=make_signal_windows_ref(scipy.signal.windows.cosine) if TEST_SCIPY else None,
207194
sample_inputs_func=sample_inputs_window,
208195
error_inputs_func=error_inputs_window,
209196
),
210197
make_signal_windows_opinfo(
211198
name="signal.windows.exponential",
212-
ref=make_signal_windows_ref(scipy.signal.windows.exponential),
199+
ref=make_signal_windows_ref(scipy.signal.windows.exponential) if TEST_SCIPY else None,
213200
sample_inputs_func=partial(sample_inputs_window, tau=random.uniform(0, 10)),
214201
error_inputs_func=error_inputs_exponential_window,
215202
),
216203
make_signal_windows_opinfo(
217204
name="signal.windows.gaussian",
218-
ref=make_signal_windows_ref(scipy.signal.windows.gaussian),
205+
ref=make_signal_windows_ref(scipy.signal.windows.gaussian) if TEST_SCIPY else None,
219206
sample_inputs_func=partial(sample_inputs_window, std=random.uniform(0, 3)),
220207
error_inputs_func=error_inputs_gaussian_window,
221208
),

0 commit comments

Comments
 (0)