33
44from torch import Tensor
55from torch ._prims_common import is_float_dtype
6- from torch ._torch_docs import factory_common_args
6+ from torch ._torch_docs import factory_common_args , parse_kwargs , merge_dicts
77
88__all__ = [
99 'cosine' ,
1010 'exponential' ,
1111 'gaussian' ,
1212]
1313
14+ window_common_args = merge_dicts (
15+ parse_kwargs (
16+ """
17+ M (int): the length of the output window.
18+ In other words, the number of points of the exponential window.
19+ sym (bool, optional): If `False`, returns a periodic window suitable for use in spectral analysis.
20+ If `True`, returns a symmetric window suitable for use in filter design. Default: `True`.
21+ """
22+ ),
23+ factory_common_args
24+ )
25+
1426
1527def _add_docstr (* args ):
1628 r"""Adds docstrings to a given decorated function.
@@ -23,11 +35,13 @@ def _add_docstr(*args):
2335 Args:
2436 args (str):
2537 """
38+
2639 def decorator (o ):
2740 o .__doc__ = ""
2841 for arg in args :
2942 o .__doc__ += arg
3043 return o
44+
3145 return decorator
3246
3347
@@ -62,16 +76,12 @@ def _window_function_checks(function_name: str, M: int, dtype: torch.dtype, layo
6276 r"""
6377
6478Args:
65- M (int): the length of the output window.
66- In other words, the number of points of the exponential window.
79+ {M}
6780 center (float, optional): where the center of the window will be located.
6881 In other words, at which sample the peak of the window can be found.
6982 Default: `M / 2` if `sym` is `False`, else `(M - 1) / 2`.
7083 tau (float, optional): the decay value. Default: 1.0.
71- sym (bool, optional): If `False`, returns a periodic window suitable for use in spectral analysis.
72- If `True`, returns a symmetric window suitable for use in filter design. Default: `True`.
73- """ +
74- r"""
84+ {sym}
7585
7686.. note::
7787 The window is normalized to 1 (maximum value is 1), however, the 1 doesn't appear if `M` is even
@@ -94,7 +104,7 @@ def _window_function_checks(function_name: str, M: int, dtype: torch.dtype, layo
94104 tensor([4.5400e-05, 3.3546e-04, 2.4788e-03, 1.8316e-02, 1.3534e-01, 1.0000e+00,
95105 1.3534e-01, 1.8316e-02, 2.4788e-03, 3.3546e-04])
96106 """ .format (
97- ** factory_common_args
107+ ** window_common_args
98108 ),
99109)
100110def exponential (
@@ -160,10 +170,8 @@ def exponential(
160170 r"""
161171
162172Args:
163- M (int): the length of the output window.
164- In other words, the number of points of the cosine window.
165- sym (bool, optional): If `False`, returns a periodic window suitable for use in spectral analysis.
166- If `True`, returns a symmetric window suitable for use in filter design. Default: `True`.
173+ {M}
174+ {sym}
167175
168176.. note::
169177 The window is normalized to 1 (maximum value is 1), however, the 1 doesn't appear if `M` is even
@@ -186,7 +194,7 @@ def exponential(
186194 tensor([0.1423, 0.4154, 0.6549, 0.8413, 0.9595, 1.0000, 0.9595, 0.8413, 0.6549,
187195 0.4154])
188196""" .format (
189- ** factory_common_args
197+ ** window_common_args ,
190198 ),
191199)
192200def cosine (M : int ,
@@ -237,12 +245,10 @@ def cosine(M: int,
237245 r"""
238246
239247Args:
240- M (int): the length of the output window.
241- In other words, the number of points of the cosine window.
248+ {M}
242249 std (float, optional): the standard deviation of the gaussian. It controls how narrow or wide the window is.
243250 Default: 1.0.
244- sym (bool, optional): If `False`, returns a periodic window suitable for use in spectral analysis.
245- If `True`, returns a symmetric window suitable for use in filter design. Default: `True`
251+ {sym}
246252
247253.. note::
248254 The window is normalized to 1 (maximum value is 1), however, the 1 doesn't appear if `M` is even
@@ -265,7 +271,7 @@ def cosine(M: int,
265271 tensor([1.9858e-07, 5.1365e-05, 3.8659e-03, 8.4658e-02, 5.3941e-01, 1.0000e+00,
266272 5.3941e-01, 8.4658e-02, 3.8659e-03, 5.1365e-05])
267273""" .format (
268- ** factory_common_args
274+ ** window_common_args ,
269275 ),
270276)
271277def gaussian (M : int ,
0 commit comments