66from torch .types import _dtype , _device , _layout
77from torch .fft import fft
88
9- __all__ = ['cosine_window' ]
9+ __all__ = [
10+ 'cosine_window' ,
11+ 'exponential_window' ,
12+ 'gaussian_window' ,
13+ ]
1014
1115
12- def _window_function_checks (function_name : str , window_length : int , dtype : _dtype , layout : _layout , device : _device ):
16+ def _window_function_checks (function_name : str , window_length : int , dtype : _dtype , layout : _layout ):
1317 def is_floating_type (t : _dtype ) -> bool :
1418 return t == torch .float32 or t == torch .bfloat16 or t == torch .float64 or t == torch .float16
1519
1620 def is_complex_type (t : _dtype ) -> bool :
1721 return t == torch .complex64 or t == torch .complex128 or t == torch .complex32
1822
1923 if window_length < 0 :
20- raise RuntimeError (f'{ function_name } requires non-negative window_length, got window_length: { window_length } ' )
21- if layout is torch .sparse :
22- raise RuntimeError (f'{ function_name } is not implemented for sparse types, got layout : { layout } ' )
24+ raise RuntimeError (f'{ function_name } requires non-negative window_length, got window_length= { window_length } ' )
25+ if layout is torch .sparse_coo :
26+ raise RuntimeError (f'{ function_name } is not implemented for sparse types, got: { layout } ' )
2327 if not is_floating_type (dtype ) and not is_complex_type (dtype ):
2428 raise RuntimeError (f'{ function_name } expects floating point dtypes, got: { dtype } ' )
2529
2630
27- def chebyshev_window (window_length : int ,
28- attenuation : float ,
29- periodic : bool = True ,
30- dtype : _dtype = None ,
31- layout : _layout = torch .strided ,
32- device : _device = None ) -> Tensor :
33- _window_function_checks ('chebyshev_window' , window_length , dtype , layout , device )
34-
35- if window_length == 0 :
36- return torch .empty ((0 ,), dtype = dtype , layout = layout , device = device )
37-
38- if window_length == 1 :
39- return torch .ones ((1 ,), dtype = dtype , layout = layout , device = device )
40-
41- if not periodic :
42- window_length += 1
43-
44- k = torch .arange (window_length , dtype = dtype , layout = layout , device = device )
45-
46- order = window_length - 1
47- beta = np .cosh (1.0 / order * np .arccosh (np .power (10 , attenuation / 20.0 )))
48-
49- x = beta * torch .cos (torch .pi * k / window_length )
50- window = torch .special .chebyshev_polynomial_t (x , order ) / np .power (10 , attenuation / 20.0 )
51-
52- if window_length % 2 != 0 :
53- window = torch .real (fft (window ))
54- n = (window_length + 1 ) // 2
55- window = torch .concat ((torch .flip (window [1 :n ], (0 ,)), window [:n ]))
56- else :
57- window = window * torch .exp (1.j * torch .pi / window_length * torch .arange (window_length ))
58- window = torch .real (fft (window ))
59- n = window_length // 2 + 1
60- window = torch .concat ((torch .flip (window [1 :n ], (0 ,)), window [1 :n ]))
61-
62- window /= torch .max (window )
63-
64- return window if periodic else window [:window_length - 1 ]
65-
66-
6731def exponential_window (window_length : int ,
6832 periodic : bool = True ,
6933 center : float = None ,
7034 tau : float = 1.0 ,
7135 dtype : _dtype = None ,
7236 layout : _layout = torch .strided ,
73- device : _device = None ) -> Tensor :
37+ device : _device = None ,
38+ requires_grad : bool = False ) -> Tensor :
7439 """r
75- Computes a window with a simple cosine waveform.
40+ Computes a window with an exponential form. The window
41+ is also known as Poisson window.
42+
43+ The exponential window is defined as follows:
44+
45+ .. math::
46+ w(n) = \exp{-\f rac{|n - center|}{\t au}}
7647
7748 Args:
78- window_length:
49+ window_length: the length of the output window. In other words, the number of points of the cosine window.
50+ periodic: If `True`, returns a periodic window suitable for use in spectral analysis. If `False`,
51+ returns a symmetric window suitable for use in filter design.
52+ center: this value defines where the center of the window will be located. In other words, at which
53+ sample the peak of the window can be found.
54+ tau: the decay value. For `center = 0`, it's suggested to use `tau = -(M - 1) / ln(x)`, if `` is
55+ the fraction of the window remaining at the end.
7956
8057 Keyword args:
8158 {dtype}
59+ layout (:class:`torch.layout`, optional): the desired layout of returned window tensor. Only
60+ `torch.strided` (dense layout) is supported.
8261 {device}
83-
62+ {requires_grad}
8463 """
64+ if dtype is None :
65+ dtype = torch .get_default_dtype ()
66+
8567 _window_function_checks ('exponential_window' , window_length , dtype , layout , device )
8668
8769 if window_length == 0 :
88- return torch .empty ((0 ,), dtype = dtype , layout = layout , device = device )
70+ return torch .empty ((0 ,), dtype = dtype , layout = layout , device = device , requires_grad = requires_grad )
8971
9072 if window_length == 1 :
91- return torch .ones ((1 ,), dtype = dtype , layout = layout , device = device )
73+ return torch .ones ((1 ,), dtype = dtype , layout = layout , device = device , requires_grad = requires_grad )
9274
93- if not periodic :
75+ if periodic :
9476 window_length += 1
9577
9678 if periodic and center is not None :
@@ -99,17 +81,18 @@ def exponential_window(window_length: int,
9981 if center is None :
10082 center = (window_length - 1 ) / 2
10183
102- k = torch .arange (window_length , dtype = dtype , layout = layout , device = device )
84+ k = torch .arange (window_length , dtype = dtype , layout = layout , device = device , requires_grad = requires_grad )
10385 window = torch .exp (- torch .abs (k - center ) / tau )
10486
105- return window if periodic else window [: window_length - 1 ]
87+ return window [: - 1 ] if periodic else window
10688
10789
10890def cosine_window (window_length : int ,
10991 periodic : bool = True ,
11092 dtype : _dtype = None ,
11193 layout : _layout = torch .strided ,
112- device : _device = None ) -> Tensor :
94+ device : _device = None ,
95+ requires_grad : bool = False ) -> Tensor :
11396 """r
11497 Computes a window with a simple cosine waveform.
11598
@@ -119,7 +102,7 @@ def cosine_window(window_length: int,
119102 .. math::
120103 w(n) = \cos{(\f rac{\pi n}{M}) - \f rac{\pi}{2})} = \sin{(\f rac{\pi n}{M})}
121104
122- Where `M
105+ Where `M` is the window length.
123106
124107
125108 Args:
@@ -134,20 +117,70 @@ def cosine_window(window_length: int,
134117 {device}
135118 {requires_grad}
136119 """
120+ if dtype is None :
121+ dtype = torch .get_default_dtype ()
122+
137123 _window_function_checks ('cosine_window' , window_length , dtype , layout , device )
138124
139125 if window_length == 0 :
140- return torch .empty ((0 ,), dtype = dtype , layout = layout , device = device )
126+ return torch .empty ((0 ,), dtype = dtype , layout = layout , device = device , requires_grad = requires_grad )
141127
142128 if window_length == 1 :
143- return torch .ones ((1 ,), dtype = dtype , layout = layout , device = device )
129+ return torch .ones ((1 ,), dtype = dtype , layout = layout , device = device , requires_grad = requires_grad )
144130
145- if not periodic :
131+ if periodic :
146132 window_length += 1
147133
148- # k = torch.arange(window_length, dtype=dtype, layout=layout, device=device)
149- k = np .arange (0 , window_length )
150- window = np .sin (np .pi / window_length * (k + .5 ))
151- window = torch .from_numpy (window )
134+ k = torch .arange (window_length , dtype = dtype , layout = layout , device = device , requires_grad = requires_grad )
135+ window = torch .sin (torch .pi / window_length * (k + .5 ))
136+ return window [:- 1 ] if periodic else window
137+
138+
139+ def gaussian_window (window_length : int ,
140+ periodic : bool = True ,
141+ std : float = 0.5 ,
142+ dtype : _dtype = None ,
143+ layout : _layout = torch .strided ,
144+ device : _device = None ,
145+ requires_grad : bool = False ) -> Tensor :
146+ """r
147+ Computes a window with a gaussian waveform.
148+
149+ The gaussian window is defined as follows:
150+
151+ .. math::
152+ w(n) = \exp{-\f rac{1}{2}\f rac{n}{\sigma}^2}
153+
154+ Args:
155+ window_length: the length of the output window. In other words, the number of points of the cosine window.
156+ periodic: If `True`, returns a periodic window suitable for use in spectral analysis. If `False`,
157+ returns a symmetric window suitable for use in filter design.
158+ std: the standard deviation of the gaussian. It controls how narrow or wide the window is.
159+
160+
161+ Keyword args:
162+ {dtype}
163+ layout (:class:`torch.layout`, optional): the desired layout of returned window tensor. Only
164+ `torch.strided` (dense layout) is supported.
165+ {device}
166+ {requires_grad}
167+ """
168+ if dtype is None :
169+ dtype = torch .get_default_dtype ()
170+
171+ _window_function_checks ('cosine_window' , window_length , dtype , layout , device )
172+
173+ if window_length == 0 :
174+ return torch .empty ((0 ,), dtype = dtype , layout = layout , device = device , requires_grad = requires_grad )
175+
176+ if window_length == 1 :
177+ return torch .ones ((1 ,), dtype = dtype , layout = layout , device = device , requires_grad = requires_grad )
178+
179+ if periodic :
180+ window_length += 1
152181
153- return window if periodic else window [:window_length - 1 ]
182+ k = torch .arange (window_length , dtype = dtype , layout = layout , device = device , requires_grad = requires_grad )
183+ k = k - (window_length - 1.0 ) / 2.0
184+ sig2 = 2 * std * std
185+ window = torch .exp (- k ** 2 / sig2 )
186+ return window [:- 1 ] if periodic else window
0 commit comments