@@ -20,6 +20,7 @@ def _window_function_checks(function_name: str, window_length: int, dtype: _dtyp
2020 dtype (:class:`torch.dtype`): the desired data type of the window tensor.
2121 layout (:class:`torch.layout`): the desired layout of the window tensor.
2222 """
23+
2324 def is_floating_type (t : _dtype ) -> bool :
2425 return t == torch .float32 or t == torch .bfloat16 or t == torch .float64 or t == torch .float16
2526
@@ -95,19 +96,19 @@ def exponential_window(window_length: int,
9596 if window_length == 1 :
9697 return torch .ones ((1 ,), dtype = dtype , layout = layout , device = device , requires_grad = requires_grad )
9798
98- if periodic :
99- window_length += 1
100-
10199 if periodic and center is not None :
102100 raise ValueError ('Center must be \' None\' for periodic equal True' )
103101
104102 if center is None :
105- center = (window_length - 1 ) / 2
103+ center = - (window_length if periodic else window_length - 1 ) / 2.0
106104
107- k = torch .arange (window_length , dtype = dtype , layout = layout , device = device , requires_grad = requires_grad )
108- window = torch .exp (- torch .abs (k - center ) / tau )
105+ k = torch .arange (start = center ,
106+ end = center + window_length ,
107+ dtype = dtype , layout = layout ,
108+ device = device ,
109+ requires_grad = requires_grad )
109110
110- return window [: - 1 ] if periodic else window
111+ return torch . exp ( - torch . abs ( k ) / tau )
111112
112113
113114def cosine_window (window_length : int ,
@@ -234,11 +235,13 @@ def gaussian_window(window_length: int,
234235 if window_length == 1 :
235236 return torch .ones ((1 ,), dtype = dtype , layout = layout , device = device , requires_grad = requires_grad )
236237
237- if periodic :
238- window_length += 1
238+ start = - (window_length if periodic else window_length - 1 ) / 2.0
239239
240- k = torch .arange (window_length , dtype = dtype , layout = layout , device = device , requires_grad = requires_grad )
241- k = k - (window_length - 1.0 ) / 2.0
242- window = torch .exp (- (k / std ) ** 2 / 2 )
240+ k = torch .arange (start ,
241+ start + window_length ,
242+ dtype = dtype ,
243+ layout = layout ,
244+ device = device ,
245+ requires_grad = requires_grad )
243246
244- return window [: - 1 ] if periodic else window
247+ return torch . exp ( - ( k / std ) ** 2 / 2 )
0 commit comments