1+ import random
12import unittest
3+
4+ from itertools import product
25from typing import List
36
47import torch
1518 import scipy .signal
1619
1720
18- def _sample_input_windows (sample_input , * args , ** kwargs ):
19- for size in [0 , 1 , 2 , 5 , 10 , 50 , 100 , 1024 , 2048 ]:
20- for periodic in [True , False ]:
21- kwargs .update (
22- {
23- "periodic" : periodic ,
24- }
25- )
26- yield sample_input (size , args = args , kwargs = kwargs )
21+ def sample_inputs_window (op_info , device , dtype , requires_grad , * , kws = ({},), ** kwargs ):
22+ _kwargs = dict (
23+ kwargs , ** {"device" : device , "dtype" : dtype , "requires_grad" : requires_grad }
24+ )
25+ sizes = [0 , 1 , 2 , 5 , 10 , 50 , 100 , 1024 , 2048 ]
26+ for size , periodic , k in product (sizes , (True , False ), kws ):
27+ yield SampleInput (size , periodic = periodic , ** k , ** _kwargs )
28+
29+
30+ def sample_inputs_gaussian_window (op_info , device , dtype , requires_grad , ** kwargs ):
31+ kws = [{"std" : random .uniform (0 , 3 )} for _ in range (50 )]
32+ yield from sample_inputs_window (
33+ op_info , device , dtype , requires_grad , kws = kws , ** kwargs
34+ )
2735
2836
29- def sample_inputs_window (op_info , * args , ** kwargs ):
30- return _sample_input_windows (SampleInput , * args , ** kwargs )
37+ def sample_inputs_exponential_window (op_info , device , dtype , requires_grad , ** kwargs ):
38+ kws = [{"center" : None , "tau" : random .uniform (0 , 10 )} for _ in range (50 )]
39+ yield from sample_inputs_window (
40+ op_info , device , dtype , requires_grad , kws = kws , ** kwargs
41+ )
3142
3243
33- def error_inputs_window (op_info , * args , ** kwargs ):
44+ def error_inputs_window (op_info , device , ** kwargs ):
45+ tmp_kwargs = dict (
46+ kwargs ,
47+ ** {
48+ "device" : device ,
49+ },
50+ )
51+
3452 yield ErrorInput (
35- SampleInput (- 1 , args = args , kwargs = kwargs ),
53+ SampleInput (- 1 , kwargs = tmp_kwargs ),
3654 error_type = ValueError ,
3755 error_regex = "requires non-negative window_length, got window_length=-1" ,
3856 )
@@ -45,71 +63,88 @@ def error_inputs_window(op_info, *args, **kwargs):
4563 )
4664
4765 yield ErrorInput (
48- SampleInput (3 , args = args , kwargs = tmp_kwargs ),
66+ SampleInput (3 , kwargs = tmp_kwargs ),
4967 error_type = ValueError ,
5068 error_regex = "is not implemented for sparse types, got: torch.sparse_coo" ,
5169 )
5270
5371 tmp_kwargs = kwargs
54- tmp_kwargs ["dtype" ] = torch .long
72+ tmp_kwargs .update (
73+ {
74+ "dtype" : torch .long ,
75+ }
76+ )
5577
5678 yield ErrorInput (
57- SampleInput (3 , args = args , kwargs = tmp_kwargs ),
79+ SampleInput (3 , kwargs = tmp_kwargs ),
5880 error_type = ValueError ,
5981 error_regex = "expects floating point dtypes, got: torch.int64" ,
6082 )
6183
6284
63- def error_inputs_exponential_window (op_info , device , * args , ** kwargs ):
64- tmp_kwargs = dict (kwargs , ** {"dtype" : torch .float32 , "device" : device })
85+ def error_inputs_exponential_window (op_info , device , ** kwargs ):
86+ tmp_kwargs = dict (kwargs , ** {"dtype" : torch .float32 })
87+
88+ yield from error_inputs_window (op_info , device , ** tmp_kwargs )
6589
66- for error_input in error_inputs_window (op_info , * args , ** kwargs ):
67- yield error_input
90+ tmp_kwargs .update ({"device" : device })
6891
6992 tmp_kwargs = dict (tmp_kwargs , ** {"tau" : - 1 })
7093
7194 yield ErrorInput (
72- SampleInput (3 , args = args , kwargs = tmp_kwargs ),
95+ SampleInput (3 , kwargs = tmp_kwargs ),
7396 error_type = ValueError ,
7497 error_regex = "Tau must be positive, got: -1 instead." ,
7598 )
7699
77100 tmp_kwargs = dict (tmp_kwargs , ** {"center" : 1 })
78101
79102 yield ErrorInput (
80- SampleInput (3 , args = args , kwargs = tmp_kwargs ),
103+ SampleInput (3 , kwargs = tmp_kwargs ),
81104 error_type = ValueError ,
82105 error_regex = "Center must be 'None' for periodic equal True" ,
83106 )
84107
85108
86- def error_inputs_gaussian_window (op_info , device , * args , * *kwargs ):
109+ def error_inputs_gaussian_window (op_info , device , ** kwargs ):
87110 tmp_kwargs = dict (kwargs , ** {"dtype" : torch .float32 , "device" : device })
88111
89- for error_input in error_inputs_window (op_info , * args , ** kwargs ):
90- yield error_input
112+ yield from error_inputs_window (op_info , device , ** kwargs )
91113
92114 tmp_kwargs = dict (tmp_kwargs , ** {"std" : - 1 })
93115
94116 yield ErrorInput (
95- SampleInput (3 , args = args , kwargs = tmp_kwargs ),
117+ SampleInput (3 , kwargs = tmp_kwargs ),
96118 error_type = ValueError ,
97119 error_regex = "Standard deviation must be positive, got: -1 instead." ,
98120 )
99121
100122
101- op_db : List [OpInfo ] = [
102- OpInfo (
103- "signal.windows.cosine" ,
123+ def make_signal_windows_opinfo (
124+ name , variant_test_name , sample_inputs_func , error_inputs_func , * , skips = ()
125+ ):
126+ return OpInfo (
127+ name = name ,
128+ variant_test_name = variant_test_name ,
104129 ref = scipy .signal .get_window if TEST_SCIPY else None ,
105130 dtypes = all_types_and (torch .float , torch .double , torch .long ),
106131 dtypesIfCUDA = all_types_and (
107132 torch .float , torch .double , torch .bfloat16 , torch .half , torch .long
108133 ),
109- sample_inputs_func = sample_inputs_window ,
110- error_inputs_func = error_inputs_window ,
134+ sample_inputs_func = sample_inputs_func ,
135+ error_inputs_func = error_inputs_func ,
111136 supports_out = False ,
112137 supports_autograd = False ,
138+ skips = skips ,
139+ )
140+
141+
142+ op_db : List [OpInfo ] = [
143+ make_signal_windows_opinfo (
144+ "signal.windows.cosine" ,
145+ "signal.windows.cosine_default" ,
146+ sample_inputs_window ,
147+ error_inputs_window ,
113148 skips = (
114149 DecorateInfo (
115150 unittest .expectedFailure ,
@@ -147,17 +182,11 @@ def error_inputs_gaussian_window(op_info, device, *args, **kwargs):
147182 DecorateInfo (unittest .expectedFailure , "TestCommon" , "test_out_warning" ),
148183 ),
149184 ),
150- OpInfo (
185+ make_signal_windows_opinfo (
151186 "signal.windows.exponential" ,
152- ref = scipy .signal .get_window if TEST_SCIPY else None ,
153- dtypes = all_types_and (torch .float , torch .double , torch .long ),
154- dtypesIfCUDA = all_types_and (
155- torch .float , torch .double , torch .bfloat16 , torch .half , torch .long
156- ),
157- sample_inputs_func = sample_inputs_window ,
158- error_inputs_func = error_inputs_exponential_window ,
159- supports_out = False ,
160- supports_autograd = False ,
187+ "signal.windows.exponential_default" ,
188+ sample_inputs_exponential_window ,
189+ error_inputs_exponential_window ,
161190 skips = (
162191 DecorateInfo (
163192 unittest .expectedFailure ,
@@ -195,17 +224,11 @@ def error_inputs_gaussian_window(op_info, device, *args, **kwargs):
195224 DecorateInfo (unittest .expectedFailure , "TestCommon" , "test_out_warning" ),
196225 ),
197226 ),
198- OpInfo (
227+ make_signal_windows_opinfo (
199228 "signal.windows.gaussian" ,
200- ref = scipy .signal .get_window if TEST_SCIPY else None ,
201- dtypes = all_types_and (torch .float , torch .double , torch .long ),
202- dtypesIfCUDA = all_types_and (
203- torch .float , torch .double , torch .bfloat16 , torch .half , torch .long
204- ),
205- sample_inputs_func = sample_inputs_window ,
206- error_inputs_func = error_inputs_gaussian_window ,
207- supports_out = False ,
208- supports_autograd = False ,
229+ "signal.windows.gaussian_default" ,
230+ sample_inputs_gaussian_window ,
231+ error_inputs_gaussian_window ,
209232 skips = (
210233 DecorateInfo (
211234 unittest .expectedFailure ,
@@ -244,5 +267,3 @@ def error_inputs_gaussian_window(op_info, device, *args, **kwargs):
244267 ),
245268 ),
246269]
247-
248- python_ref_db : List [OpInfo ] = []
0 commit comments