@@ -28,7 +28,7 @@ def sample_inputs_window(op_info, device, dtype, requires_grad, *args, **kwargs)
2828
2929 # Test a window size of length zero and one.
3030 # If it's either symmetric or not doesn't matter in these sample inputs.
31- for size in [ 0 , 1 ] :
31+ for size in range ( 2 ) :
3232 yield SampleInput (
3333 size ,
3434 * args ,
@@ -40,8 +40,7 @@ def sample_inputs_window(op_info, device, dtype, requires_grad, *args, **kwargs)
4040
4141 # For sizes larger than 1 we need to test both symmetric and non-symmetric windows.
4242 # Note: sample input tensors must be kept rather small.
43- sizes = [2 , 5 , 10 , 50 ]
44- for size , sym in product (sizes , (True , False )):
43+ for size , sym in product (list (range (2 , 6 )), (True , False )):
4544 yield SampleInput (
4645 size ,
4746 * args ,
@@ -53,12 +52,6 @@ def sample_inputs_window(op_info, device, dtype, requires_grad, *args, **kwargs)
5352 )
5453
5554
56- def sample_inputs_gaussian_window (op_info , device , dtype , requires_grad , ** kwargs ):
57- yield from sample_inputs_window (
58- op_info , device , dtype , requires_grad , random .uniform (0 , 3 ), ** kwargs # std,
59- )
60-
61-
6255def error_inputs_window (op_info , device , * args , ** kwargs ):
6356 # Tests for windows that have a negative size
6457 yield ErrorInput (
@@ -91,7 +84,7 @@ def error_inputs_window(op_info, device, *args, **kwargs):
9184
9285def error_inputs_exponential_window (op_info , device , ** kwargs ):
9386 # Yield common error inputs
94- yield from error_inputs_window (op_info , device , 0.5 , ** kwargs )
87+ yield from error_inputs_window (op_info , device , ** kwargs )
9588
9689 # Tests for negative decay values.
9790 yield ErrorInput (
@@ -110,41 +103,29 @@ def error_inputs_exponential_window(op_info, device, **kwargs):
110103
111104def error_inputs_gaussian_window (op_info , device , ** kwargs ):
112105 # Yield common error inputs
113- yield from error_inputs_window (op_info , device , 0.5 , ** kwargs ) # std
106+ yield from error_inputs_window (op_info , device , std = 0.5 , ** kwargs )
114107
115108 # Tests for negative standard deviations
116109 yield ErrorInput (
117- SampleInput (3 , - 1 , dtype = torch .float32 , device = device , ** kwargs ),
110+ SampleInput (3 , std = - 1 , dtype = torch .float32 , device = device , ** kwargs ),
118111 error_type = ValueError ,
119112 error_regex = "Standard deviation must be positive, got: -1 instead." ,
120113 )
121114
122115
123116def make_signal_windows_opinfo (
124- name , variant_test_name , ref , sample_inputs_func , error_inputs_func , * , skips = ()
117+ name , ref , sample_inputs_func , error_inputs_func , * , skips = ()
125118):
126119 r"""Helper function to create OpInfo objects related to different windows."""
127120 return OpInfo (
128121 name = name ,
129- variant_test_name = variant_test_name ,
130122 ref = ref if TEST_SCIPY else None ,
131123 dtypes = floating_types_and (torch .bfloat16 , torch .float16 ),
132124 dtypesIfCUDA = floating_types_and (torch .bfloat16 , torch .float16 ),
133125 sample_inputs_func = sample_inputs_func ,
134126 error_inputs_func = error_inputs_func ,
135127 supports_out = False ,
136128 supports_autograd = False ,
137- skips = skips ,
138- )
139-
140-
141- op_db : List [OpInfo ] = [
142- make_signal_windows_opinfo (
143- name = "signal.windows.cosine" ,
144- variant_test_name = "" ,
145- ref = scipy .signal .windows .cosine ,
146- sample_inputs_func = sample_inputs_window ,
147- error_inputs_func = error_inputs_window ,
148129 skips = (
149130 DecorateInfo (
150131 unittest .expectedFailure ,
@@ -178,88 +159,38 @@ def make_signal_windows_opinfo(
178159 unittest .skip ("Skipped!" ), "TestMathBits" , "test_neg_conj_view"
179160 ),
180161 DecorateInfo (unittest .skip ("Skipped!" ), "TestMathBits" , "test_neg_view" ),
162+ DecorateInfo (
163+ unittest .skip ("Skipped!" ),
164+ "TestVmapOperatorsOpInfo" ,
165+ "test_vmap_exhaustive" ,
166+ ),
167+ DecorateInfo (
168+ unittest .skip ("Skipped!" ),
169+ "TestVmapOperatorsOpInfo" ,
170+ "test_op_has_batch_rule" ,
171+ ),
172+ * skips ,
181173 ),
174+ )
175+
176+
177+ op_db : List [OpInfo ] = [
178+ make_signal_windows_opinfo (
179+ name = "signal.windows.cosine" ,
180+ ref = scipy .signal .windows .cosine ,
181+ sample_inputs_func = sample_inputs_window ,
182+ error_inputs_func = error_inputs_window ,
182183 ),
183184 make_signal_windows_opinfo (
184185 name = "signal.windows.exponential" ,
185- variant_test_name = "" ,
186186 ref = scipy .signal .windows .exponential ,
187187 sample_inputs_func = partial (sample_inputs_window , tau = random .uniform (0 , 10 )),
188188 error_inputs_func = error_inputs_exponential_window ,
189- skips = (
190- DecorateInfo (
191- unittest .expectedFailure ,
192- "TestNormalizeOperators" ,
193- "test_normalize_operator_exhaustive" ,
194- ),
195- # TODO: same as this?
196- # https://github.com/pytorch/pytorch/issues/81774
197- # also see: arange, new_full
198- # fails to match any schemas despite working in the interpreter
199- DecorateInfo (
200- unittest .expectedFailure ,
201- "TestOperatorSignatures" ,
202- "test_get_torch_func_signature_exhaustive" ,
203- ),
204- # fails to match any schemas despite working in the interpreter
205- DecorateInfo (
206- unittest .expectedFailure , "TestJit" , "test_variant_consistency_jit"
207- ),
208- # skip these tests since we have non tensor input
209- DecorateInfo (
210- unittest .skip ("Skipped!" ), "TestCommon" , "test_noncontiguous_samples"
211- ),
212- DecorateInfo (
213- unittest .skip ("Skipped!" ),
214- "TestCommon" ,
215- "test_variant_consistency_eager" ,
216- ),
217- DecorateInfo (unittest .skip ("Skipped!" ), "TestMathBits" , "test_conj_view" ),
218- DecorateInfo (
219- unittest .skip ("Skipped!" ), "TestMathBits" , "test_neg_conj_view"
220- ),
221- DecorateInfo (unittest .skip ("Skipped!" ), "TestMathBits" , "test_neg_view" ),
222- ),
223189 ),
224190 make_signal_windows_opinfo (
225191 name = "signal.windows.gaussian" ,
226- variant_test_name = "" ,
227192 ref = scipy .signal .windows .gaussian ,
228- sample_inputs_func = sample_inputs_gaussian_window ,
193+ sample_inputs_func = partial ( sample_inputs_window , std = random . uniform ( 0 , 3 )) ,
229194 error_inputs_func = error_inputs_gaussian_window ,
230- skips = (
231- DecorateInfo (
232- unittest .expectedFailure ,
233- "TestNormalizeOperators" ,
234- "test_normalize_operator_exhaustive" ,
235- ),
236- # TODO: same as this?
237- # https://github.com/pytorch/pytorch/issues/81774
238- # also see: arange, new_full
239- # fails to match any schemas despite working in the interpreter
240- DecorateInfo (
241- unittest .expectedFailure ,
242- "TestOperatorSignatures" ,
243- "test_get_torch_func_signature_exhaustive" ,
244- ),
245- # fails to match any schemas despite working in the interpreter
246- DecorateInfo (
247- unittest .expectedFailure , "TestJit" , "test_variant_consistency_jit"
248- ),
249- # skip these tests since we have non tensor input
250- DecorateInfo (
251- unittest .skip ("Skipped!" ), "TestCommon" , "test_noncontiguous_samples"
252- ),
253- DecorateInfo (
254- unittest .skip ("Skipped!" ),
255- "TestCommon" ,
256- "test_variant_consistency_eager" ,
257- ),
258- DecorateInfo (unittest .skip ("Skipped!" ), "TestMathBits" , "test_conj_view" ),
259- DecorateInfo (
260- unittest .skip ("Skipped!" ), "TestMathBits" , "test_neg_conj_view"
261- ),
262- DecorateInfo (unittest .skip ("Skipped!" ), "TestMathBits" , "test_neg_view" ),
263- ),
264195 ),
265196]
0 commit comments