Skip to content

Commit bb198a2

Browse files
committed
Update tests
1 parent 7aa10c4 commit bb198a2

File tree

3 files changed

+58
-114
lines changed

3 files changed

+58
-114
lines changed

torch/signal/windows/windows.py

Lines changed: 25 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from torch import Tensor
55
from torch._prims_common import is_float_dtype
66
from torch._torch_docs import factory_common_args
7-
from torch.types import _dtype, _device, _layout
87

98
__all__ = [
109
'cosine',
@@ -32,18 +31,18 @@ def decorator(o):
3231
return decorator
3332

3433

35-
def _window_function_checks(function_name: str, window_length: int, dtype: _dtype, layout: _layout) -> None:
34+
def _window_function_checks(function_name: str, M: int, dtype: torch.dtype, layout: torch.layout) -> None:
3635
r"""Performs common checks for all the defined windows.
3736
This function should be called before computing any window
3837
3938
Args:
4039
function_name (str): name of the window function.
41-
window_length (int): length of the window.
40+
M (int): length of the window.
4241
dtype (:class:`torch.dtype`): the desired data type of the window tensor.
4342
layout (:class:`torch.layout`): the desired layout of the window tensor.
4443
"""
45-
if window_length < 0:
46-
raise ValueError(f'{function_name} requires non-negative window_length, got window_length={window_length}')
44+
if M < 0:
45+
raise ValueError(f'{function_name} requires non-negative window_length, got window_length={M}')
4746
if layout is not torch.strided:
4847
raise ValueError(f'{function_name} is implemented for strided tensors only, got: {layout}')
4948
if not is_float_dtype(dtype):
@@ -68,9 +67,7 @@ def _window_function_checks(function_name: str, window_length: int, dtype: _dtyp
6867
center (float, optional): where the center of the window will be located.
6968
In other words, at which sample the peak of the window can be found.
7069
Default: `window_length / 2` if `periodic` is `True` (default), else `(window_length - 1) / 2`.
71-
tau (float, optional): the decay value.
72-
For `center = 0`, it's suggested to use :math:`\tau = -\frac{(M - 1)}{\ln(x)}`,
73-
if `x` is the fraction of the window remaining at the end. Default: 1.0.
70+
tau (float, optional): the decay value. Default: 1.0.
7471
sym (bool, optional): If `False`, returns a periodic window suitable for use in spectral analysis.
7572
If `True`, returns a symmetric window suitable for use in filter design. Default: `True`.
7673
""" +
@@ -100,8 +97,17 @@ def _window_function_checks(function_name: str, window_length: int, dtype: _dtyp
10097
**factory_common_args
10198
),
10299
)
103-
def exponential(M: int, center: float = None, tau: float = 1.0, sym: bool = True, *, dtype: _dtype = None,
104-
layout: _layout = torch.strided, device: _device = None, requires_grad: bool = False) -> Tensor:
100+
def exponential(
101+
M: int,
102+
center: float = None,
103+
tau: float = 1.0,
104+
sym: bool = True,
105+
*,
106+
dtype: torch.dtype = None,
107+
layout: torch.layout = torch.strided,
108+
device: torch.device = None,
109+
requires_grad: bool = False
110+
) -> Tensor:
105111
if dtype is None:
106112
dtype = torch.get_default_dtype()
107113

@@ -186,9 +192,9 @@ def exponential(M: int, center: float = None, tau: float = 1.0, sym: bool = True
186192
def cosine(M: int,
187193
sym: bool = True,
188194
*,
189-
dtype: _dtype = None,
190-
layout: _layout = torch.strided,
191-
device: _device = None,
195+
dtype: torch.dtype = None,
196+
layout: torch.layout = torch.strided,
197+
device: torch.device = None,
192198
requires_grad: bool = False) -> Tensor:
193199
if dtype is None:
194200
dtype = torch.get_default_dtype()
@@ -233,7 +239,8 @@ def cosine(M: int,
233239
Args:
234240
M (int): the length of the output window.
235241
In other words, the number of points of the cosine window.
236-
std (float): the standard deviation of the gaussian. It controls how narrow or wide the window is.
242+
std (float, optional): the standard deviation of the gaussian. It controls how narrow or wide the window is.
243+
Default: 1.0.
237244
sym (bool, optional): If `False`, returns a periodic window suitable for use in spectral analysis.
238245
If `True`, returns a symmetric window suitable for use in filter design. Default: `True`
239246
@@ -262,12 +269,12 @@ def cosine(M: int,
262269
),
263270
)
264271
def gaussian(M: int,
265-
std: float,
272+
std: float = 1.0,
266273
sym: bool = True,
267274
*,
268-
dtype: _dtype = None,
269-
layout: _layout = torch.strided,
270-
device: _device = None,
275+
dtype: torch.dtype = None,
276+
layout: torch.layout = torch.strided,
277+
device: torch.device = None,
271278
requires_grad: bool = False) -> Tensor:
272279
if dtype is None:
273280
dtype = torch.get_default_dtype()

torch/testing/_internal/common_utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2341,6 +2341,12 @@ def safeToDense(self, t):
23412341
def compare_with_reference(self, torch_fn, ref_fn, sample_input, **kwargs):
23422342
numpy_sample = sample_input.numpy()
23432343
n_inp, n_args, n_kwargs = numpy_sample.input, numpy_sample.args, numpy_sample.kwargs
2344+
2345+
# Remove torch-specific kwargs
2346+
for torch_key in {'device', 'layout', 'dtype', 'requires_grad'}:
2347+
if torch_key in n_kwargs:
2348+
n_kwargs.pop(torch_key)
2349+
23442350
t_inp, t_args, t_kwargs = sample_input.input, sample_input.args, sample_input.kwargs
23452351

23462352
actual = torch_fn(t_inp, *t_args, **t_kwargs)

torch/testing/_internal/opinfo/definitions/signal.py

Lines changed: 27 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def sample_inputs_window(op_info, device, dtype, requires_grad, *args, **kwargs)
2828

2929
# Test a window size of length zero and one.
3030
# If it's either symmetric or not doesn't matter in these sample inputs.
31-
for size in [0, 1]:
31+
for size in range(2):
3232
yield SampleInput(
3333
size,
3434
*args,
@@ -40,8 +40,7 @@ def sample_inputs_window(op_info, device, dtype, requires_grad, *args, **kwargs)
4040

4141
# For sizes larger than 1 we need to test both symmetric and non-symmetric windows.
4242
# Note: sample input tensors must be kept rather small.
43-
sizes = [2, 5, 10, 50]
44-
for size, sym in product(sizes, (True, False)):
43+
for size, sym in product(list(range(2, 6)), (True, False)):
4544
yield SampleInput(
4645
size,
4746
*args,
@@ -53,12 +52,6 @@ def sample_inputs_window(op_info, device, dtype, requires_grad, *args, **kwargs)
5352
)
5453

5554

56-
def sample_inputs_gaussian_window(op_info, device, dtype, requires_grad, **kwargs):
57-
yield from sample_inputs_window(
58-
op_info, device, dtype, requires_grad, random.uniform(0, 3), **kwargs # std,
59-
)
60-
61-
6255
def error_inputs_window(op_info, device, *args, **kwargs):
6356
# Tests for windows that have a negative size
6457
yield ErrorInput(
@@ -91,7 +84,7 @@ def error_inputs_window(op_info, device, *args, **kwargs):
9184

9285
def error_inputs_exponential_window(op_info, device, **kwargs):
9386
# Yield common error inputs
94-
yield from error_inputs_window(op_info, device, 0.5, **kwargs)
87+
yield from error_inputs_window(op_info, device, **kwargs)
9588

9689
# Tests for negative decay values.
9790
yield ErrorInput(
@@ -110,41 +103,29 @@ def error_inputs_exponential_window(op_info, device, **kwargs):
110103

111104
def error_inputs_gaussian_window(op_info, device, **kwargs):
112105
# Yield common error inputs
113-
yield from error_inputs_window(op_info, device, 0.5, **kwargs) # std
106+
yield from error_inputs_window(op_info, device, std=0.5, **kwargs)
114107

115108
# Tests for negative standard deviations
116109
yield ErrorInput(
117-
SampleInput(3, -1, dtype=torch.float32, device=device, **kwargs),
110+
SampleInput(3, std=-1, dtype=torch.float32, device=device, **kwargs),
118111
error_type=ValueError,
119112
error_regex="Standard deviation must be positive, got: -1 instead.",
120113
)
121114

122115

123116
def make_signal_windows_opinfo(
124-
name, variant_test_name, ref, sample_inputs_func, error_inputs_func, *, skips=()
117+
name, ref, sample_inputs_func, error_inputs_func, *, skips=()
125118
):
126119
r"""Helper function to create OpInfo objects related to different windows."""
127120
return OpInfo(
128121
name=name,
129-
variant_test_name=variant_test_name,
130122
ref=ref if TEST_SCIPY else None,
131123
dtypes=floating_types_and(torch.bfloat16, torch.float16),
132124
dtypesIfCUDA=floating_types_and(torch.bfloat16, torch.float16),
133125
sample_inputs_func=sample_inputs_func,
134126
error_inputs_func=error_inputs_func,
135127
supports_out=False,
136128
supports_autograd=False,
137-
skips=skips,
138-
)
139-
140-
141-
op_db: List[OpInfo] = [
142-
make_signal_windows_opinfo(
143-
name="signal.windows.cosine",
144-
variant_test_name="",
145-
ref=scipy.signal.windows.cosine,
146-
sample_inputs_func=sample_inputs_window,
147-
error_inputs_func=error_inputs_window,
148129
skips=(
149130
DecorateInfo(
150131
unittest.expectedFailure,
@@ -178,88 +159,38 @@ def make_signal_windows_opinfo(
178159
unittest.skip("Skipped!"), "TestMathBits", "test_neg_conj_view"
179160
),
180161
DecorateInfo(unittest.skip("Skipped!"), "TestMathBits", "test_neg_view"),
162+
DecorateInfo(
163+
unittest.skip("Skipped!"),
164+
"TestVmapOperatorsOpInfo",
165+
"test_vmap_exhaustive",
166+
),
167+
DecorateInfo(
168+
unittest.skip("Skipped!"),
169+
"TestVmapOperatorsOpInfo",
170+
"test_op_has_batch_rule",
171+
),
172+
*skips,
181173
),
174+
)
175+
176+
177+
op_db: List[OpInfo] = [
178+
make_signal_windows_opinfo(
179+
name="signal.windows.cosine",
180+
ref=scipy.signal.windows.cosine,
181+
sample_inputs_func=sample_inputs_window,
182+
error_inputs_func=error_inputs_window,
182183
),
183184
make_signal_windows_opinfo(
184185
name="signal.windows.exponential",
185-
variant_test_name="",
186186
ref=scipy.signal.windows.exponential,
187187
sample_inputs_func=partial(sample_inputs_window, tau=random.uniform(0, 10)),
188188
error_inputs_func=error_inputs_exponential_window,
189-
skips=(
190-
DecorateInfo(
191-
unittest.expectedFailure,
192-
"TestNormalizeOperators",
193-
"test_normalize_operator_exhaustive",
194-
),
195-
# TODO: same as this?
196-
# https://github.com/pytorch/pytorch/issues/81774
197-
# also see: arange, new_full
198-
# fails to match any schemas despite working in the interpreter
199-
DecorateInfo(
200-
unittest.expectedFailure,
201-
"TestOperatorSignatures",
202-
"test_get_torch_func_signature_exhaustive",
203-
),
204-
# fails to match any schemas despite working in the interpreter
205-
DecorateInfo(
206-
unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
207-
),
208-
# skip these tests since we have non tensor input
209-
DecorateInfo(
210-
unittest.skip("Skipped!"), "TestCommon", "test_noncontiguous_samples"
211-
),
212-
DecorateInfo(
213-
unittest.skip("Skipped!"),
214-
"TestCommon",
215-
"test_variant_consistency_eager",
216-
),
217-
DecorateInfo(unittest.skip("Skipped!"), "TestMathBits", "test_conj_view"),
218-
DecorateInfo(
219-
unittest.skip("Skipped!"), "TestMathBits", "test_neg_conj_view"
220-
),
221-
DecorateInfo(unittest.skip("Skipped!"), "TestMathBits", "test_neg_view"),
222-
),
223189
),
224190
make_signal_windows_opinfo(
225191
name="signal.windows.gaussian",
226-
variant_test_name="",
227192
ref=scipy.signal.windows.gaussian,
228-
sample_inputs_func=sample_inputs_gaussian_window,
193+
sample_inputs_func=partial(sample_inputs_window, std=random.uniform(0, 3)),
229194
error_inputs_func=error_inputs_gaussian_window,
230-
skips=(
231-
DecorateInfo(
232-
unittest.expectedFailure,
233-
"TestNormalizeOperators",
234-
"test_normalize_operator_exhaustive",
235-
),
236-
# TODO: same as this?
237-
# https://github.com/pytorch/pytorch/issues/81774
238-
# also see: arange, new_full
239-
# fails to match any schemas despite working in the interpreter
240-
DecorateInfo(
241-
unittest.expectedFailure,
242-
"TestOperatorSignatures",
243-
"test_get_torch_func_signature_exhaustive",
244-
),
245-
# fails to match any schemas despite working in the interpreter
246-
DecorateInfo(
247-
unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
248-
),
249-
# skip these tests since we have non tensor input
250-
DecorateInfo(
251-
unittest.skip("Skipped!"), "TestCommon", "test_noncontiguous_samples"
252-
),
253-
DecorateInfo(
254-
unittest.skip("Skipped!"),
255-
"TestCommon",
256-
"test_variant_consistency_eager",
257-
),
258-
DecorateInfo(unittest.skip("Skipped!"), "TestMathBits", "test_conj_view"),
259-
DecorateInfo(
260-
unittest.skip("Skipped!"), "TestMathBits", "test_neg_conj_view"
261-
),
262-
DecorateInfo(unittest.skip("Skipped!"), "TestMathBits", "test_neg_view"),
263-
),
264195
),
265196
]

0 commit comments

Comments
 (0)