Skip to content

Commit 7a6f207

Browse files
committed
Update OpInfo and tests
1 parent 23485c5 commit 7a6f207

File tree

3 files changed

+89
-81
lines changed

3 files changed

+89
-81
lines changed

test/test_signal.py

Lines changed: 15 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
# Owner(s): ["module: signal"]
22

3-
import random
4-
53
import torch
64
import unittest
75
import re
@@ -13,7 +11,7 @@
1311
ops, instantiate_device_type_tests, OpDTypes
1412
)
1513
from torch.testing._internal.common_methods_invocations import (
16-
precisionOverride, signal_funcs
14+
precisionOverride, op_db
1715
)
1816
from torch.testing._internal.opinfo.core import OpInfo
1917

@@ -27,20 +25,24 @@ def _test_window(self, device, dtype, op: OpInfo, **kwargs):
2725
if op.ref is None:
2826
raise unittest.SkipTest("No reference implementation")
2927

30-
sample_inputs = op.sample_inputs(device, dtype, **kwargs)
28+
sample_inputs = op.sample_inputs(device, dtype, False, **kwargs)
3129

3230
for sample_input in sample_inputs:
3331
window_size = sample_input.input
3432
window_name = re.search(self.supported_windows, op.name).group(0)
35-
periodic = sample_input.kwargs.pop('periodic')
33+
34+
ref_kwargs = {
35+
k: sample_input.kwargs[k] for k in sample_input.kwargs if k not in ('device', 'dtype', 'requires_grad', 'periodic')
36+
}
3637

3738
expected = torch.from_numpy(
38-
op.ref((window_name, *(sample_input.kwargs.values())), window_size, fftbins=periodic)
39+
op.ref((window_name, *(ref_kwargs.values())), window_size, fftbins=sample_input.kwargs['periodic'])
3940
)
40-
actual = op(window_size, periodic=periodic, **sample_input.kwargs)
41+
actual = op(window_size, **sample_input.kwargs)
4142
self.assertEqual(actual, expected, exact_dtype=self.exact_dtype)
42-
self.assertTrue(op(3, requires_grad=True).requires_grad)
43-
self.assertFalse(op(3).requires_grad)
43+
44+
self.assertTrue(op(3, requires_grad=True).requires_grad)
45+
self.assertFalse(op(3).requires_grad)
4446

4547
def _test_window_errors(self, device, op):
4648
error_inputs = op.error_inputs(device)
@@ -50,29 +52,15 @@ def _test_window_errors(self, device, op):
5052
with self.assertRaisesRegex(error_input.error_type, error_input.error_regex):
5153
op(sample_input.input, *sample_input.args, **sample_input.kwargs)
5254

53-
@ops([op for op in signal_funcs if 'windows' in op.name], dtypes=OpDTypes.none)
55+
@ops([op for op in op_db if 'windows' in op.name], dtypes=OpDTypes.none)
5456
def test_window_errors(self, device, op):
5557
self._test_window_errors(device, op)
5658

5759
@precisionOverride({torch.bfloat16: 5e-2, torch.half: 1e-3})
58-
@ops([op for op in signal_funcs if 'windows.cosine' in op.name],
59-
allowed_dtypes=(torch.float, torch.double, torch.long))
60-
def test_cosine_window(self, device, dtype, op):
61-
self._test_window(device, dtype, op)
62-
63-
@precisionOverride({torch.bfloat16: 5e-2, torch.half: 1e-3})
64-
@ops([op for op in signal_funcs if 'windows.exponential' in op.name],
60+
@ops([op for op in op_db if 'windows' in op.name],
6561
allowed_dtypes=(torch.float, torch.double))
66-
def test_exponential_window(self, device, dtype, op):
67-
for _ in range(50):
68-
self._test_window(device, dtype, op, center=None, tau=random.uniform(0, 10))
69-
70-
@precisionOverride({torch.bfloat16: 5e-2, torch.half: 1e-3})
71-
@ops([op for op in signal_funcs if 'windows.gaussian' in op.name],
72-
allowed_dtypes=(torch.float, torch.double, torch.long))
73-
def test_gaussian_window(self, device, dtype, op):
74-
for _ in range(50):
75-
self._test_window(device, dtype, op, std=random.uniform(0, 3))
62+
def test_windows(self, device, dtype, op):
63+
self._test_window(device, dtype, op)
7664

7765

7866
instantiate_device_type_tests(TestSignalWindows, globals())

torch/testing/_internal/opinfo/definitions/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,5 +22,4 @@
2222
*fft.python_ref_db,
2323
*linalg.python_ref_db,
2424
*special.python_ref_db,
25-
*signal.python_ref_db,
2625
]

torch/testing/_internal/opinfo/definitions/signal.py

Lines changed: 74 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1+
import random
12
import unittest
3+
4+
from itertools import product
25
from typing import List
36

47
import torch
@@ -15,24 +18,39 @@
1518
import scipy.signal
1619

1720

18-
def _sample_input_windows(sample_input, *args, **kwargs):
19-
for size in [0, 1, 2, 5, 10, 50, 100, 1024, 2048]:
20-
for periodic in [True, False]:
21-
kwargs.update(
22-
{
23-
"periodic": periodic,
24-
}
25-
)
26-
yield sample_input(size, args=args, kwargs=kwargs)
21+
def sample_inputs_window(op_info, device, dtype, requires_grad, *, kws=({},), **kwargs):
22+
_kwargs = dict(
23+
kwargs, **{"device": device, "dtype": dtype, "requires_grad": requires_grad}
24+
)
25+
sizes = [0, 1, 2, 5, 10, 50, 100, 1024, 2048]
26+
for size, periodic, k in product(sizes, (True, False), kws):
27+
yield SampleInput(size, periodic=periodic, **k, **_kwargs)
28+
29+
30+
def sample_inputs_gaussian_window(op_info, device, dtype, requires_grad, **kwargs):
31+
kws = [{"std": random.uniform(0, 3)} for _ in range(50)]
32+
yield from sample_inputs_window(
33+
op_info, device, dtype, requires_grad, kws=kws, **kwargs
34+
)
2735

2836

29-
def sample_inputs_window(op_info, *args, **kwargs):
30-
return _sample_input_windows(SampleInput, *args, **kwargs)
37+
def sample_inputs_exponential_window(op_info, device, dtype, requires_grad, **kwargs):
38+
kws = [{"center": None, "tau": random.uniform(0, 10)} for _ in range(50)]
39+
yield from sample_inputs_window(
40+
op_info, device, dtype, requires_grad, kws=kws, **kwargs
41+
)
3142

3243

33-
def error_inputs_window(op_info, *args, **kwargs):
44+
def error_inputs_window(op_info, device, **kwargs):
45+
tmp_kwargs = dict(
46+
kwargs,
47+
**{
48+
"device": device,
49+
},
50+
)
51+
3452
yield ErrorInput(
35-
SampleInput(-1, args=args, kwargs=kwargs),
53+
SampleInput(-1, kwargs=tmp_kwargs),
3654
error_type=ValueError,
3755
error_regex="requires non-negative window_length, got window_length=-1",
3856
)
@@ -45,71 +63,88 @@ def error_inputs_window(op_info, *args, **kwargs):
4563
)
4664

4765
yield ErrorInput(
48-
SampleInput(3, args=args, kwargs=tmp_kwargs),
66+
SampleInput(3, kwargs=tmp_kwargs),
4967
error_type=ValueError,
5068
error_regex="is not implemented for sparse types, got: torch.sparse_coo",
5169
)
5270

5371
tmp_kwargs = kwargs
54-
tmp_kwargs["dtype"] = torch.long
72+
tmp_kwargs.update(
73+
{
74+
"dtype": torch.long,
75+
}
76+
)
5577

5678
yield ErrorInput(
57-
SampleInput(3, args=args, kwargs=tmp_kwargs),
79+
SampleInput(3, kwargs=tmp_kwargs),
5880
error_type=ValueError,
5981
error_regex="expects floating point dtypes, got: torch.int64",
6082
)
6183

6284

63-
def error_inputs_exponential_window(op_info, device, *args, **kwargs):
64-
tmp_kwargs = dict(kwargs, **{"dtype": torch.float32, "device": device})
85+
def error_inputs_exponential_window(op_info, device, **kwargs):
86+
tmp_kwargs = dict(kwargs, **{"dtype": torch.float32})
87+
88+
yield from error_inputs_window(op_info, device, **tmp_kwargs)
6589

66-
for error_input in error_inputs_window(op_info, *args, **kwargs):
67-
yield error_input
90+
tmp_kwargs.update({"device": device})
6891

6992
tmp_kwargs = dict(tmp_kwargs, **{"tau": -1})
7093

7194
yield ErrorInput(
72-
SampleInput(3, args=args, kwargs=tmp_kwargs),
95+
SampleInput(3, kwargs=tmp_kwargs),
7396
error_type=ValueError,
7497
error_regex="Tau must be positive, got: -1 instead.",
7598
)
7699

77100
tmp_kwargs = dict(tmp_kwargs, **{"center": 1})
78101

79102
yield ErrorInput(
80-
SampleInput(3, args=args, kwargs=tmp_kwargs),
103+
SampleInput(3, kwargs=tmp_kwargs),
81104
error_type=ValueError,
82105
error_regex="Center must be 'None' for periodic equal True",
83106
)
84107

85108

86-
def error_inputs_gaussian_window(op_info, device, *args, **kwargs):
109+
def error_inputs_gaussian_window(op_info, device, **kwargs):
87110
tmp_kwargs = dict(kwargs, **{"dtype": torch.float32, "device": device})
88111

89-
for error_input in error_inputs_window(op_info, *args, **kwargs):
90-
yield error_input
112+
yield from error_inputs_window(op_info, device, **kwargs)
91113

92114
tmp_kwargs = dict(tmp_kwargs, **{"std": -1})
93115

94116
yield ErrorInput(
95-
SampleInput(3, args=args, kwargs=tmp_kwargs),
117+
SampleInput(3, kwargs=tmp_kwargs),
96118
error_type=ValueError,
97119
error_regex="Standard deviation must be positive, got: -1 instead.",
98120
)
99121

100122

101-
op_db: List[OpInfo] = [
102-
OpInfo(
103-
"signal.windows.cosine",
123+
def make_signal_windows_opinfo(
124+
name, variant_test_name, sample_inputs_func, error_inputs_func, *, skips=()
125+
):
126+
return OpInfo(
127+
name=name,
128+
variant_test_name=variant_test_name,
104129
ref=scipy.signal.get_window if TEST_SCIPY else None,
105130
dtypes=all_types_and(torch.float, torch.double, torch.long),
106131
dtypesIfCUDA=all_types_and(
107132
torch.float, torch.double, torch.bfloat16, torch.half, torch.long
108133
),
109-
sample_inputs_func=sample_inputs_window,
110-
error_inputs_func=error_inputs_window,
134+
sample_inputs_func=sample_inputs_func,
135+
error_inputs_func=error_inputs_func,
111136
supports_out=False,
112137
supports_autograd=False,
138+
skips=skips,
139+
)
140+
141+
142+
op_db: List[OpInfo] = [
143+
make_signal_windows_opinfo(
144+
"signal.windows.cosine",
145+
"signal.windows.cosine_default",
146+
sample_inputs_window,
147+
error_inputs_window,
113148
skips=(
114149
DecorateInfo(
115150
unittest.expectedFailure,
@@ -147,17 +182,11 @@ def error_inputs_gaussian_window(op_info, device, *args, **kwargs):
147182
DecorateInfo(unittest.expectedFailure, "TestCommon", "test_out_warning"),
148183
),
149184
),
150-
OpInfo(
185+
make_signal_windows_opinfo(
151186
"signal.windows.exponential",
152-
ref=scipy.signal.get_window if TEST_SCIPY else None,
153-
dtypes=all_types_and(torch.float, torch.double, torch.long),
154-
dtypesIfCUDA=all_types_and(
155-
torch.float, torch.double, torch.bfloat16, torch.half, torch.long
156-
),
157-
sample_inputs_func=sample_inputs_window,
158-
error_inputs_func=error_inputs_exponential_window,
159-
supports_out=False,
160-
supports_autograd=False,
187+
"signal.windows.exponential_default",
188+
sample_inputs_exponential_window,
189+
error_inputs_exponential_window,
161190
skips=(
162191
DecorateInfo(
163192
unittest.expectedFailure,
@@ -195,17 +224,11 @@ def error_inputs_gaussian_window(op_info, device, *args, **kwargs):
195224
DecorateInfo(unittest.expectedFailure, "TestCommon", "test_out_warning"),
196225
),
197226
),
198-
OpInfo(
227+
make_signal_windows_opinfo(
199228
"signal.windows.gaussian",
200-
ref=scipy.signal.get_window if TEST_SCIPY else None,
201-
dtypes=all_types_and(torch.float, torch.double, torch.long),
202-
dtypesIfCUDA=all_types_and(
203-
torch.float, torch.double, torch.bfloat16, torch.half, torch.long
204-
),
205-
sample_inputs_func=sample_inputs_window,
206-
error_inputs_func=error_inputs_gaussian_window,
207-
supports_out=False,
208-
supports_autograd=False,
229+
"signal.windows.gaussian_default",
230+
sample_inputs_gaussian_window,
231+
error_inputs_gaussian_window,
209232
skips=(
210233
DecorateInfo(
211234
unittest.expectedFailure,
@@ -244,5 +267,3 @@ def error_inputs_gaussian_window(op_info, device, *args, **kwargs):
244267
),
245268
),
246269
]
247-
248-
python_ref_db: List[OpInfo] = []

0 commit comments

Comments
 (0)