Skip to content

Commit 19ac119

Browse files
committed
Update gaussian and exp windows
1 parent 8666c9d commit 19ac119

File tree

1 file changed

+16
-13
lines changed

1 file changed

+16
-13
lines changed

torch/signal/windows/windows.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ def _window_function_checks(function_name: str, window_length: int, dtype: _dtyp
2020
dtype (:class:`torch.dtype`): the desired data type of the window tensor.
2121
layout (:class:`torch.layout`): the desired layout of the window tensor.
2222
"""
23+
2324
def is_floating_type(t: _dtype) -> bool:
2425
return t == torch.float32 or t == torch.bfloat16 or t == torch.float64 or t == torch.float16
2526

@@ -95,19 +96,19 @@ def exponential_window(window_length: int,
9596
if window_length == 1:
9697
return torch.ones((1,), dtype=dtype, layout=layout, device=device, requires_grad=requires_grad)
9798

98-
if periodic:
99-
window_length += 1
100-
10199
if periodic and center is not None:
102100
raise ValueError('Center must be \'None\' for periodic equal True')
103101

104102
if center is None:
105-
center = (window_length - 1) / 2
103+
center = -(window_length if periodic else window_length - 1) / 2.0
106104

107-
k = torch.arange(window_length, dtype=dtype, layout=layout, device=device, requires_grad=requires_grad)
108-
window = torch.exp(-torch.abs(k - center) / tau)
105+
k = torch.arange(start=center,
106+
end=center + window_length,
107+
dtype=dtype, layout=layout,
108+
device=device,
109+
requires_grad=requires_grad)
109110

110-
return window[:-1] if periodic else window
111+
return torch.exp(-torch.abs(k) / tau)
111112

112113

113114
def cosine_window(window_length: int,
@@ -234,11 +235,13 @@ def gaussian_window(window_length: int,
234235
if window_length == 1:
235236
return torch.ones((1,), dtype=dtype, layout=layout, device=device, requires_grad=requires_grad)
236237

237-
if periodic:
238-
window_length += 1
238+
start = -(window_length if periodic else window_length - 1) / 2.0
239239

240-
k = torch.arange(window_length, dtype=dtype, layout=layout, device=device, requires_grad=requires_grad)
241-
k = k - (window_length - 1.0) / 2.0
242-
window = torch.exp(-(k / std) ** 2 / 2)
240+
k = torch.arange(start,
241+
start + window_length,
242+
dtype=dtype,
243+
layout=layout,
244+
device=device,
245+
requires_grad=requires_grad)
243246

244-
return window[:-1] if periodic else window
247+
return torch.exp(-(k / std) ** 2 / 2)

0 commit comments

Comments
 (0)