Skip to content

Commit f5e1c45

Browse files
committed
Update sqrt and ref kwargs
1 parent 7a6f207 commit f5e1c45

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

test/test_signal.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@ def _test_window(self, device, dtype, op: OpInfo, **kwargs):
3232
window_name = re.search(self.supported_windows, op.name).group(0)
3333

3434
ref_kwargs = {
35-
k: sample_input.kwargs[k] for k in sample_input.kwargs if k not in ('device', 'dtype', 'requires_grad', 'periodic')
35+
k: sample_input.kwargs[k] for k in sample_input.kwargs
36+
if k not in ('device', 'dtype', 'requires_grad', 'periodic')
3637
}
3738

3839
expected = torch.from_numpy(

torch/signal/windows/windows.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import torch
2-
import numpy as np
32

43
from torch import Tensor
54
from torch._torch_docs import factory_common_args
@@ -299,7 +298,7 @@ def gaussian(window_length: int,
299298

300299
start = -(window_length if periodic else window_length - 1) / 2.0
301300

302-
constant = 1 / (std * np.sqrt(2))
301+
constant = 1 / (std * torch.sqrt(torch.tensor(2)).item())
303302

304303
"""
305304
Note that non-integer step is subject to floating point rounding errors when comparing against end;

0 commit comments

Comments
 (0)