@@ -65,6 +65,26 @@ def __call__(self, inputs, is_cuda, is_fastpath, **kwargs):
         return inputs[0] if self._is_inplace else actual


+def get_transform_func(num_tensors, dtype, device, is_fastpath):
+    def transform(t):
+        if not torch.is_tensor(t):
+            return t
+        return make_tensor(
+            (num_tensors, num_tensors), dtype=dtype, device=device,
+            requires_grad=True, noncontiguous=not is_fastpath,
+        )
+    return transform
+
+
+def clone(arg):
+    if isinstance(arg, (list, tuple)):
+        return [clone(a) for a in arg]
+    if torch.is_tensor(arg):
+        return arg.clone().detach().requires_grad_()
+    else:
+        return arg
+
+
 class TestForeach(TestCase):

     @property
@@ -82,18 +102,21 @@ def _get_funcs(self, op):
             RegularFuncWrapper(op.ref_inplace),
         )

-    def _binary_test(self, dtype, op, ref, inputs, is_fastpath, is_inplace, *, alpha=None):
+    def _binary_test(self, dtype, op, ref, inputs, is_fastpath, is_inplace, *, alpha=None, scalar_self_arg=False):
         ref_inputs = [[t.clone().detach() for t in inputs[0]], inputs[1]] if is_inplace else inputs

         try:
             actual = op(inputs, self.is_cuda, is_fastpath)
         except RuntimeError as e:
             with self.assertRaisesRegex(type(e), re.escape(str(e))):
-                ref(ref_inputs)
+                if not scalar_self_arg:
+                    ref(ref_inputs)
+                else:
+                    [ref.func(ref_inputs[0], t) for t in ref_inputs[1]]
         else:
-            expected = ref(ref_inputs)
+            expected = ref(ref_inputs) if not scalar_self_arg else [ref.func(ref_inputs[0], t) for t in ref_inputs[1]]
             self.assertEqual(actual, expected)
-        if alpha is not None:
+        if alpha is not None and not scalar_self_arg:
             kwargs = {'alpha': alpha}
             ref_inputs = inputs
             try:
@@ -112,26 +135,54 @@ def _binary_test(self, dtype, op, ref, inputs, is_fastpath, is_inplace, *, alpha
     @ops(foreach_binary_op_db)
     @parametrize("is_fastpath", (True, False))
     def test_binary_op(self, device, dtype, op, is_fastpath):
-        for sample in op.sample_inputs(device, dtype, noncontiguous=not is_fastpath):
+        scalar_self_arg_test_complete = False
+        for i, sample in enumerate(op.sample_inputs(device, dtype, noncontiguous=not is_fastpath)):
             rhs_arg, = sample.args
             kwargs = {} or sample.kwargs
             alpha = kwargs.pop("alpha", None)
             disable_fastpath = kwargs.pop("disable_fastpath") if is_fastpath else False
             wrapped_op, ref, inplace_op, inplace_ref = self._get_funcs(op)
             self._binary_test(
-                dtype, wrapped_op, ref, [sample.input, rhs_arg], is_fastpath and not disable_fastpath, False, alpha=alpha)
+                dtype, wrapped_op, ref, [sample.input, rhs_arg],
+                is_fastpath and not disable_fastpath, False, alpha=alpha)
             self._binary_test(
-                dtype, inplace_op, inplace_ref, [sample.input, rhs_arg], is_fastpath and not disable_fastpath, True, alpha=alpha)
-            if op.supports_scalar_self_arg and isinstance(rhs_arg, list) and isinstance(rhs_arg[0], torch.Tensor):
+                dtype, inplace_op, inplace_ref, [sample.input, rhs_arg],
+                is_fastpath and not disable_fastpath, True, alpha=alpha)
+
+            if op.supports_autograd and dtype in floating_types():
+                transformed_sample = sample.transform(get_transform_func(len(sample.input), dtype, device, is_fastpath))
+                tensors = transformed_sample.input
+                rhs_arg, = transformed_sample.args
+                ref_tensors, ref_rhs_arg = clone(tensors), clone(rhs_arg)
+                try:
+                    sum(wrapped_op([tensors, rhs_arg], is_cuda=False, is_fastpath=False)).mean().backward()
+                except RuntimeError:
+                    with self.assertRaises(RuntimeError):
+                        sum(ref([ref_tensors, ref_rhs_arg])).mean().backward()
+                else:
+                    sum(ref([ref_tensors, ref_rhs_arg])).mean().backward()
+                    self.assertEqual([t.grad for t in tensors], [t.grad for t in ref_tensors])
+                    if isinstance(rhs_arg, list) and isinstance(rhs_arg[0], torch.Tensor):
+                        self.assertEqual([t.grad for t in rhs_arg], [t.grad for t in ref_rhs_arg])
+            if op.supports_scalar_self_arg and isinstance(rhs_arg, Number) and (not scalar_self_arg_test_complete):
+                scalar_self_arg_test_complete = True
                 self._binary_test(
-                    dtype, wrapped_op, ref, [rhs_arg, sample.input], is_fastpath and not disable_fastpath, False, alpha=alpha)
+                    dtype, wrapped_op, ref, [rhs_arg, sample.input], is_fastpath, False,
+                    alpha=alpha, scalar_self_arg=True)
+                if op.supports_autograd and dtype == torch.float32:
+                    transformed_sample = sample.transform(
+                        get_transform_func(len(sample.input), dtype, device, is_fastpath))
+                    tensors = transformed_sample.input
+                    rhs_arg, = transformed_sample.args
+                    ref_tensors, ref_rhs_arg = clone(tensors), clone(rhs_arg)
+                    sum(wrapped_op([rhs_arg, tensors], is_cuda=False, is_fastpath=False)).mean().backward()
+                    sum([ref.func(ref_rhs_arg, t) for t in ref_tensors]).mean().backward()
+                    self.assertEqual([t.grad for t in tensors], [t.grad for t in ref_tensors])

     @ops(foreach_pointwise_op_db)
     @parametrize("is_fastpath", (True, False))
     def test_pointwise_op(self, device, dtype, op, is_fastpath):
-        for sample in op.sample_inputs(device, dtype):
-            if not is_fastpath:
-                sample = sample.noncontiguous()
+        for sample in op.sample_inputs(device, dtype, noncontiguous=not is_fastpath):
             assert isinstance(sample.args, tuple)
             assert len(sample.args) == 2
             inputs = [sample.input, *sample.args]
@@ -140,7 +191,27 @@ def test_pointwise_op(self, device, dtype, op, is_fastpath):
             wrapped_op, ref, inplace_op, inplace_ref = self._get_funcs(op)
             values = kwargs.pop("values")
             self._pointwise_test(wrapped_op, ref, inputs, is_fastpath and not disable_fastpath, False, values=values)
-            self._pointwise_test(inplace_op, inplace_ref, inputs, is_fastpath and not disable_fastpath, True, values=values)
+            self._pointwise_test(
+                inplace_op, inplace_ref, inputs, is_fastpath and not disable_fastpath,
+                True, values=values)
+
+            if op.supports_autograd and dtype in floating_types():
+                transformed_sample = sample.transform(
+                    get_transform_func(len(sample.input), dtype, device, is_fastpath))
+                tensors = transformed_sample.input
+                rhs_arg = transformed_sample.args
+                ref_tensors, ref_rhs_arg = clone(tensors), clone(rhs_arg)
+                try:
+                    sum(wrapped_op([tensors, *rhs_arg], is_cuda=False, is_fastpath=False)).mean().backward()
+                except RuntimeError:
+                    with self.assertRaises(RuntimeError):
+                        sum(ref([ref_tensors, *ref_rhs_arg])).mean().backward()
+                else:
+                    sum(ref([ref_tensors, *ref_rhs_arg])).mean().backward()
+                    self.assertEqual([t.grad for t in tensors], [t.grad for t in ref_tensors])
+                    for op_list, ref_list in zip(rhs_arg, ref_rhs_arg):
+                        if isinstance(op_list, list) and isinstance(op_list[0], torch.Tensor):
+                            self.assertEqual([t.grad for t in op_list], [t.grad for t in ref_list])

             if is_fastpath and isinstance(values, list):
                 sample = sample.transform(lambda t: t.clone().detach() if torch.is_tensor(t) else t)
@@ -224,24 +295,6 @@ def _inplace_unary_test(self, inplace, inplace_ref, inputs, is_fastpath):
             inplace_ref(copied_inputs),
             self.assertEqual(copied_inputs, inputs)

-    def _test_unary(self, device, dtype, opinfo, N, is_fastpath):
-        op, ref, inplace_op, inplace_ref = self._get_funcs(opinfo, 1)
-        inputs = opinfo.sample_inputs(device, dtype, N, noncontiguous=not is_fastpath),
-        # note(mkozuki): Complex inputs for `_foreach_abs` go through slowpath.
-        if opinfo.name == "_foreach_abs" and dtype in complex_types():
-            is_fastpath = False
-        self._regular_unary_test(dtype, op, ref, inputs, is_fastpath)
-        self._inplace_unary_test(dtype, inplace_op, inplace_ref, inputs, is_fastpath)
-
-        if opinfo.supports_autograd and dtype in floating_types():
-            tensors = opinfo.sample_inputs(device, dtype, N, noncontiguous=not is_fastpath, same_size=True)
-            tensors = [t.requires_grad_() for t in tensors]
-            ref_tensors = [t.clone().detach().requires_grad_() for t in tensors]
-
-            sum(op.func(tensors)).mean().backward()
-            sum([ref.func(t) for t in ref_tensors]).mean().backward()
-            self.assertEqual([t.grad for t in tensors], [t.grad for t in ref_tensors])
-
     @skipMeta
     @ops(foreach_unary_op_db)
     @parametrize("is_fastpath", (True, False))
@@ -259,19 +312,39 @@ def test_unary_op(self, device, dtype, op, is_fastpath):
             )
             self.assertEqual(ref(inputs), wrapped_op(inputs, self.is_cuda, is_fastpath and not disable_fastpath))
             self._inplace_unary_test(inplace_op, inplace_ref, [sample.input], is_fastpath and not disable_fastpath)
+            if op.supports_autograd and dtype in floating_types():
+                num_tensors = len(sample.input)
+                tensors = [
+                    make_tensor(
+                        (num_tensors, num_tensors), dtype=dtype, device=device,
+                        requires_grad=True, noncontiguous=not is_fastpath,
+                    )
+                    for _ in range(num_tensors)
+                ]
+                ref_tensors = [t.clone().detach().requires_grad_() for t in tensors]
+                sum(wrapped_op.func(tensors)).mean().backward()
+                sum([ref.func(t) for t in ref_tensors]).mean().backward()
+                self.assertEqual([t.grad for t in tensors], [t.grad for t in ref_tensors])

     @ops(foreach_reduce_op_db)
     @parametrize("is_fastpath", (True, False))
     def test_reduce_op(self, device, dtype, op, is_fastpath):
-        for sample in op.sample_inputs(device, dtype):
-            if not is_fastpath:
-                sample = sample.noncontiguous()
+        for sample in op.sample_inputs(device, dtype, noncontiguous=not is_fastpath):
             ord = sample.kwargs.pop("ord")
             disable_fastpath = sample.kwargs.pop("disable_fastpath", False)

             inputs = (sample.input,)
             wrapped_op, ref, _, _ = self._get_funcs(op)
             self.assertEqual(ref(inputs, ord=ord), wrapped_op(inputs, self.is_cuda, is_fastpath and not disable_fastpath, ord=ord))
+            if op.supports_autograd and dtype in floating_types():
+                transformed_sample = sample.transform(get_transform_func(len(sample.input), dtype, device, is_fastpath))
+                tensors = transformed_sample.input
+                ref_tensors = clone(tensors)
+                sum(wrapped_op((tensors,), False, False, ord=ord)).backward()
+                sum(ref((ref_tensors,), ord=ord)).backward()
+                self.assertEqual(
+                    [t.grad for t in tensors], [t.grad for t in ref_tensors],
+                )

     @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
     def test_add_scalar_with_empty_list_and_empty_tensor(self, device, dtype):
@@ -285,7 +358,6 @@ def test_add_scalar_with_empty_list_and_empty_tensor(self, device, dtype):

     @ops(foreach_binary_op_db, dtypes=OpDTypes.supported)
     def test_binary_op_scalar_with_overlapping_tensors(self, device, dtype, op):
-        print(op, device, dtype)
         foreach_op, ref = op.method_variant, op.ref
         tensors = [torch.ones(1, 1, device=device, dtype=dtype).expand(2, 1, 3)]

@@ -533,7 +605,6 @@ def test_foreach_l2_large_value_input(self, device, dtype, op):
     def test_lerp(self, device, dtype, op, is_fastpath):
         for sample in op.sample_inputs(device, dtype, noncontiguous=not is_fastpath):
             wrapped_op, ref, inplace_op, _ = self._get_funcs(op)
-
             args = [*sample.args]
             inputs = [sample.input, args[0]]

@@ -559,6 +630,24 @@ def test_lerp(self, device, dtype, op, is_fastpath):
             inplace_actual = inplace_op(inplace_inputs, self.is_cuda, is_fastpath, **kwargs)
             self.assertEqual(inplace_actual, expected)

+            if op.supports_autograd and dtype in floating_types():
+                transformed_sample = sample.transform(get_transform_func(len(sample.input), dtype, device, is_fastpath))
+                args = [*transformed_sample.args]
+                inputs = [transformed_sample.input, args[0]]
+
+                kwargs, ref_kwargs = {}, {}
+                if isinstance(args[1], list):
+                    inputs.append(args[1])
+                else:
+                    kwargs = ref_kwargs = {"weight": args[1]}
+                ref_tensors = clone(transformed_sample.input)
+                sum(wrapped_op((transformed_sample.input, *inputs[1:]), False, False, **kwargs)).mean().backward()
+                sum(ref((ref_tensors, *inputs[1:]), **ref_kwargs)).mean().backward()
+                self.assertEqual(
+                    [t.grad for t in transformed_sample.input], [t.grad for t in ref_tensors],
+                    msg=f"{transformed_sample.input[0].grad[:2, :2]}, {ref_tensors[0].grad[:2, :2]}"
+                )
+
     @onlyCUDA
     @ops(foreach_reduce_op_db)
     def test_foreach_reduce_large_input(self, device, dtype, op):