Skip to content

Commit 3fbecdd

Browse files
committed
Update docstrings
1 parent f564293 commit 3fbecdd

File tree

1 file changed

+24
-18
lines changed

1 file changed

+24
-18
lines changed

torch/signal/windows/windows.py

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,26 @@
33

44
from torch import Tensor
55
from torch._prims_common import is_float_dtype
6-
from torch._torch_docs import factory_common_args
6+
from torch._torch_docs import factory_common_args, parse_kwargs, merge_dicts
77

88
__all__ = [
99
'cosine',
1010
'exponential',
1111
'gaussian',
1212
]
1313

14+
window_common_args = merge_dicts(
15+
parse_kwargs(
16+
"""
17+
M (int): the length of the output window.
18+
In other words, the number of points of the exponential window.
19+
sym (bool, optional): If `False`, returns a periodic window suitable for use in spectral analysis.
20+
If `True`, returns a symmetric window suitable for use in filter design. Default: `True`.
21+
"""
22+
),
23+
factory_common_args
24+
)
25+
1426

1527
def _add_docstr(*args):
1628
r"""Adds docstrings to a given decorated function.
@@ -23,11 +35,13 @@ def _add_docstr(*args):
2335
Args:
2436
args (str):
2537
"""
38+
2639
def decorator(o):
2740
o.__doc__ = ""
2841
for arg in args:
2942
o.__doc__ += arg
3043
return o
44+
3145
return decorator
3246

3347

@@ -62,16 +76,12 @@ def _window_function_checks(function_name: str, M: int, dtype: torch.dtype, layo
6276
r"""
6377
6478
Args:
65-
M (int): the length of the output window.
66-
In other words, the number of points of the exponential window.
79+
{M}
6780
center (float, optional): where the center of the window will be located.
6881
In other words, at which sample the peak of the window can be found.
6982
Default: `M / 2` if `sym` is `False`, else `(M - 1) / 2`.
7083
tau (float, optional): the decay value. Default: 1.0.
71-
sym (bool, optional): If `False`, returns a periodic window suitable for use in spectral analysis.
72-
If `True`, returns a symmetric window suitable for use in filter design. Default: `True`.
73-
""" +
74-
r"""
84+
{sym}
7585
7686
.. note::
7787
The window is normalized to 1 (maximum value is 1), however, the 1 doesn't appear if `M` is even
@@ -94,7 +104,7 @@ def _window_function_checks(function_name: str, M: int, dtype: torch.dtype, layo
94104
tensor([4.5400e-05, 3.3546e-04, 2.4788e-03, 1.8316e-02, 1.3534e-01, 1.0000e+00,
95105
1.3534e-01, 1.8316e-02, 2.4788e-03, 3.3546e-04])
96106
""".format(
97-
**factory_common_args
107+
**window_common_args
98108
),
99109
)
100110
def exponential(
@@ -160,10 +170,8 @@ def exponential(
160170
r"""
161171
162172
Args:
163-
M (int): the length of the output window.
164-
In other words, the number of points of the cosine window.
165-
sym (bool, optional): If `False`, returns a periodic window suitable for use in spectral analysis.
166-
If `True`, returns a symmetric window suitable for use in filter design. Default: `True`.
173+
{M}
174+
{sym}
167175
168176
.. note::
169177
The window is normalized to 1 (maximum value is 1), however, the 1 doesn't appear if `M` is even
@@ -186,7 +194,7 @@ def exponential(
186194
tensor([0.1423, 0.4154, 0.6549, 0.8413, 0.9595, 1.0000, 0.9595, 0.8413, 0.6549,
187195
0.4154])
188196
""".format(
189-
**factory_common_args
197+
**window_common_args,
190198
),
191199
)
192200
def cosine(M: int,
@@ -237,12 +245,10 @@ def cosine(M: int,
237245
r"""
238246
239247
Args:
240-
M (int): the length of the output window.
241-
In other words, the number of points of the cosine window.
248+
{M}
242249
std (float, optional): the standard deviation of the gaussian. It controls how narrow or wide the window is.
243250
Default: 1.0.
244-
sym (bool, optional): If `False`, returns a periodic window suitable for use in spectral analysis.
245-
If `True`, returns a symmetric window suitable for use in filter design. Default: `True`
251+
{sym}
246252
247253
.. note::
248254
The window is normalized to 1 (maximum value is 1), however, the 1 doesn't appear if `M` is even
@@ -265,7 +271,7 @@ def cosine(M: int,
265271
tensor([1.9858e-07, 5.1365e-05, 3.8659e-03, 8.4658e-02, 5.3941e-01, 1.0000e+00,
266272
5.3941e-01, 8.4658e-02, 3.8659e-03, 5.1365e-05])
267273
""".format(
268-
**factory_common_args
274+
**window_common_args,
269275
),
270276
)
271277
def gaussian(M: int,

0 commit comments

Comments
 (0)