Skip to content

Commit 1603573

Browse files
committed
update
1 parent 56aebc7 commit 1603573

File tree

5 files changed

+119
-68
lines changed

5 files changed

+119
-68
lines changed

docs/source/signal.windows.rst

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
.. role:: hidden
2+
:class: hidden-section
3+
4+
torch.signal.windows
5+
=============
6+
7+
The torch.signal.windows module, modeled after SciPy's `special <https://docs.scipy.org/doc/scipy/reference/signal.windows.html>`_ module.
8+
9+
.. automodule:: torch.signal.windows
10+
.. currentmodule:: torch.signal.windows
11+
12+
Functions
13+
-----------------------
14+
15+
.. autofunction:: cosine_window
16+
.. autofunction:: exponential_window
17+
.. autofunction:: gaussian_window

torch/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -846,6 +846,7 @@ def _assert(condition, message):
846846
from torch import futures as futures
847847
from torch import nested as nested
848848
from torch import nn as nn
849+
from torch.signal import windows as windows
849850
from torch import optim as optim
850851
import torch.optim._multi_tensor
851852
from torch import multiprocessing as multiprocessing

torch/signal/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
from .windows import cosine_window, exponential_window, chebyshev_window
1+
from .windows import cosine_window, exponential_window, gaussian_window
22

33
__all__ = [
44
'cosine_window',
55
'exponential_window',
6-
'chebyshev_window'
6+
'gaussian_window',
77
]

torch/signal/windows/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
from .windows import cosine_window, exponential_window, chebyshev_window
1+
from .windows import cosine_window, exponential_window, gaussian_window
22

33
__all__ = [
44
'cosine_window',
55
'exponential_window',
6-
'chebyshev_window'
6+
'gaussian_window',
77
]

torch/signal/windows/windows.py

Lines changed: 97 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -6,91 +6,73 @@
66
from torch.types import _dtype, _device, _layout
77
from torch.fft import fft
88

9-
__all__ = ['cosine_window']
9+
__all__ = [
10+
'cosine_window',
11+
'exponential_window',
12+
'gaussian_window',
13+
]
1014

1115

12-
def _window_function_checks(function_name: str, window_length: int, dtype: _dtype, layout: _layout, device: _device):
16+
def _window_function_checks(function_name: str, window_length: int, dtype: _dtype, layout: _layout):
1317
def is_floating_type(t: _dtype) -> bool:
1418
return t == torch.float32 or t == torch.bfloat16 or t == torch.float64 or t == torch.float16
1519

1620
def is_complex_type(t: _dtype) -> bool:
1721
return t == torch.complex64 or t == torch.complex128 or t == torch.complex32
1822

1923
if window_length < 0:
20-
raise RuntimeError(f'{function_name} requires non-negative window_length, got window_length: {window_length}')
21-
if layout is torch.sparse:
22-
raise RuntimeError(f'{function_name} is not implemented for sparse types, got layout: {layout}')
24+
raise RuntimeError(f'{function_name} requires non-negative window_length, got window_length={window_length}')
25+
if layout is torch.sparse_coo:
26+
raise RuntimeError(f'{function_name} is not implemented for sparse types, got: {layout}')
2327
if not is_floating_type(dtype) and not is_complex_type(dtype):
2428
raise RuntimeError(f'{function_name} expects floating point dtypes, got: {dtype}')
2529

2630

27-
def chebyshev_window(window_length: int,
28-
attenuation: float,
29-
periodic: bool = True,
30-
dtype: _dtype = None,
31-
layout: _layout = torch.strided,
32-
device: _device = None) -> Tensor:
33-
_window_function_checks('chebyshev_window', window_length, dtype, layout, device)
34-
35-
if window_length == 0:
36-
return torch.empty((0,), dtype=dtype, layout=layout, device=device)
37-
38-
if window_length == 1:
39-
return torch.ones((1,), dtype=dtype, layout=layout, device=device)
40-
41-
if not periodic:
42-
window_length += 1
43-
44-
k = torch.arange(window_length, dtype=dtype, layout=layout, device=device)
45-
46-
order = window_length - 1
47-
beta = np.cosh(1.0 / order * np.arccosh(np.power(10, attenuation / 20.0)))
48-
49-
x = beta * torch.cos(torch.pi * k / window_length)
50-
window = torch.special.chebyshev_polynomial_t(x, order) / np.power(10, attenuation / 20.0)
51-
52-
if window_length % 2 != 0:
53-
window = torch.real(fft(window))
54-
n = (window_length + 1) // 2
55-
window = torch.concat((torch.flip(window[1:n], (0,)), window[:n]))
56-
else:
57-
window = window * torch.exp(1.j * torch.pi / window_length * torch.arange(window_length))
58-
window = torch.real(fft(window))
59-
n = window_length // 2 + 1
60-
window = torch.concat((torch.flip(window[1:n], (0,)), window[1:n]))
61-
62-
window /= torch.max(window)
63-
64-
return window if periodic else window[:window_length - 1]
65-
66-
6731
def exponential_window(window_length: int,
6832
periodic: bool = True,
6933
center: float = None,
7034
tau: float = 1.0,
7135
dtype: _dtype = None,
7236
layout: _layout = torch.strided,
73-
device: _device = None) -> Tensor:
37+
device: _device = None,
38+
requires_grad: bool = False) -> Tensor:
7439
"""r
75-
Computes a window with a simple cosine waveform.
40+
Computes a window with an exponential form. The window
41+
is also known as Poisson window.
42+
43+
The exponential window is defined as follows:
44+
45+
.. math::
46+
w(n) = \exp{-\frac{|n - center|}{\tau}}
7647
7748
Args:
78-
window_length:
49+
window_length: the length of the output window. In other words, the number of points of the cosine window.
50+
periodic: If `True`, returns a periodic window suitable for use in spectral analysis. If `False`,
51+
returns a symmetric window suitable for use in filter design.
52+
center: this value defines where the center of the window will be located. In other words, at which
53+
sample the peak of the window can be found.
54+
tau: the decay value. For `center = 0`, it's suggested to use `tau = -(M - 1) / ln(x)`, if `` is
55+
the fraction of the window remaining at the end.
7956
8057
Keyword args:
8158
{dtype}
59+
layout (:class:`torch.layout`, optional): the desired layout of returned window tensor. Only
60+
`torch.strided` (dense layout) is supported.
8261
{device}
83-
62+
{requires_grad}
8463
"""
64+
if dtype is None:
65+
dtype = torch.get_default_dtype()
66+
8567
_window_function_checks('exponential_window', window_length, dtype, layout, device)
8668

8769
if window_length == 0:
88-
return torch.empty((0,), dtype=dtype, layout=layout, device=device)
70+
return torch.empty((0,), dtype=dtype, layout=layout, device=device, requires_grad=requires_grad)
8971

9072
if window_length == 1:
91-
return torch.ones((1,), dtype=dtype, layout=layout, device=device)
73+
return torch.ones((1,), dtype=dtype, layout=layout, device=device, requires_grad=requires_grad)
9274

93-
if not periodic:
75+
if periodic:
9476
window_length += 1
9577

9678
if periodic and center is not None:
@@ -99,17 +81,18 @@ def exponential_window(window_length: int,
9981
if center is None:
10082
center = (window_length - 1) / 2
10183

102-
k = torch.arange(window_length, dtype=dtype, layout=layout, device=device)
84+
k = torch.arange(window_length, dtype=dtype, layout=layout, device=device, requires_grad=requires_grad)
10385
window = torch.exp(-torch.abs(k - center) / tau)
10486

105-
return window if periodic else window[:window_length - 1]
87+
return window[:-1] if periodic else window
10688

10789

10890
def cosine_window(window_length: int,
10991
periodic: bool = True,
11092
dtype: _dtype = None,
11193
layout: _layout = torch.strided,
112-
device: _device = None) -> Tensor:
94+
device: _device = None,
95+
requires_grad: bool = False) -> Tensor:
11396
"""r
11497
Computes a window with a simple cosine waveform.
11598
@@ -119,7 +102,7 @@ def cosine_window(window_length: int,
119102
.. math::
120103
w(n) = \cos{(\frac{\pi n}{M}) - \frac{\pi}{2})} = \sin{(\frac{\pi n}{M})}
121104
122-
Where `M
105+
Where `M` is the window length.
123106
124107
125108
Args:
@@ -134,20 +117,70 @@ def cosine_window(window_length: int,
134117
{device}
135118
{requires_grad}
136119
"""
120+
if dtype is None:
121+
dtype = torch.get_default_dtype()
122+
137123
_window_function_checks('cosine_window', window_length, dtype, layout, device)
138124

139125
if window_length == 0:
140-
return torch.empty((0,), dtype=dtype, layout=layout, device=device)
126+
return torch.empty((0,), dtype=dtype, layout=layout, device=device, requires_grad=requires_grad)
141127

142128
if window_length == 1:
143-
return torch.ones((1,), dtype=dtype, layout=layout, device=device)
129+
return torch.ones((1,), dtype=dtype, layout=layout, device=device, requires_grad=requires_grad)
144130

145-
if not periodic:
131+
if periodic:
146132
window_length += 1
147133

148-
# k = torch.arange(window_length, dtype=dtype, layout=layout, device=device)
149-
k = np.arange(0, window_length)
150-
window = np.sin(np.pi / window_length * (k + .5))
151-
window = torch.from_numpy(window)
134+
k = torch.arange(window_length, dtype=dtype, layout=layout, device=device, requires_grad=requires_grad)
135+
window = torch.sin(torch.pi / window_length * (k + .5))
136+
return window[:-1] if periodic else window
137+
138+
139+
def gaussian_window(window_length: int,
140+
periodic: bool = True,
141+
std: float = 0.5,
142+
dtype: _dtype = None,
143+
layout: _layout = torch.strided,
144+
device: _device = None,
145+
requires_grad: bool = False) -> Tensor:
146+
"""r
147+
Computes a window with a gaussian waveform.
148+
149+
The gaussian window is defined as follows:
150+
151+
.. math::
152+
w(n) = \exp{-\frac{1}{2}\frac{n}{\sigma}^2}
153+
154+
Args:
155+
window_length: the length of the output window. In other words, the number of points of the cosine window.
156+
periodic: If `True`, returns a periodic window suitable for use in spectral analysis. If `False`,
157+
returns a symmetric window suitable for use in filter design.
158+
std: the standard deviation of the gaussian. It controls how narrow or wide the window is.
159+
160+
161+
Keyword args:
162+
{dtype}
163+
layout (:class:`torch.layout`, optional): the desired layout of returned window tensor. Only
164+
`torch.strided` (dense layout) is supported.
165+
{device}
166+
{requires_grad}
167+
"""
168+
if dtype is None:
169+
dtype = torch.get_default_dtype()
170+
171+
_window_function_checks('cosine_window', window_length, dtype, layout, device)
172+
173+
if window_length == 0:
174+
return torch.empty((0,), dtype=dtype, layout=layout, device=device, requires_grad=requires_grad)
175+
176+
if window_length == 1:
177+
return torch.ones((1,), dtype=dtype, layout=layout, device=device, requires_grad=requires_grad)
178+
179+
if periodic:
180+
window_length += 1
152181

153-
return window if periodic else window[:window_length - 1]
182+
k = torch.arange(window_length, dtype=dtype, layout=layout, device=device, requires_grad=requires_grad)
183+
k = k - (window_length - 1.0) / 2.0
184+
sig2 = 2 * std * std
185+
window = torch.exp(-k ** 2 / sig2)
186+
return window[:-1] if periodic else window

0 commit comments

Comments
 (0)