
Commit f54233e

crcrpar authored and pytorchmergebot committed
[foreach] bump tensor's version and define backward via torchgen (as possible) (#93901)
## Summary

- Increment tensor versions in in-place foreach functions.
- Add logic to handle `ArrayRef<Scalar>` arguments in the torchgen-generated backwards.

Related: #58833, #89591

Pull Request resolved: #93901
Approved by: https://github.com/albanD
1 parent 83b5eb4 commit f54233e
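
For context, the user-visible effect of the version bump can be shown with a minimal Python sketch (assuming a build that includes this change): an in-place foreach op should now advance each tensor's `_version` counter, so autograd can detect that a tensor saved for backward was modified in place.

```python
import torch

# Sketch of the behavior this commit enables; assumes a PyTorch build with this change.
a = torch.rand(3, requires_grad=True)
b = a.clone()
before = b._version

c = b.sin()                    # sin() saves its input `b` for backward
torch._foreach_add_([b], 1.0)  # in-place foreach op; should bump b's version counter

assert b._version > before     # version was incremented by the in-place op

# Backward should now fail because a saved tensor was modified in place.
try:
    c.sum().backward()
except RuntimeError as e:
    print(e)  # "... has been modified by an inplace operation ..."
```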

12 files changed: +432 −115 lines


aten/src/ATen/native/cuda/ForeachBinaryOpList.cu

Lines changed: 1 addition & 0 deletions

@@ -59,6 +59,7 @@ void foreach_tensor_list_op_(TensorList tensors1, TensorList tensors2, const Sca
       /* res_arg_index */ 0>(),
       Op<opmath_t>(),
       alpha.to<opmath_t>());
+  increment_version(tensors1);
 }
 
 template<template<class> class Op>

aten/src/ATen/native/cuda/ForeachBinaryOpScalar.cu

Lines changed: 1 addition & 0 deletions

@@ -57,6 +57,7 @@ void foreach_binary_op_(TensorList tensors, const Scalar& scalar) {
       /* res_arg_index */ 0>(),
       Op<opmath_t>(),
       scalar.to<opmath_t>());
+  increment_version(tensors);
 }
 
 template<template<class> class Op>

aten/src/ATen/native/cuda/ForeachBinaryOpScalarList.cu

Lines changed: 1 addition & 0 deletions

@@ -58,6 +58,7 @@ void foreach_binary_op_(TensorList tensors, at::ArrayRef<Scalar> scalars) {
       /* r_args_depth */ 1,
       /* res_arg_index */ 0>(),
       Op<opmath_t>());
+  increment_version(tensors);
 }
 
 template<template<class> class Op>

aten/src/ATen/native/cuda/ForeachPointwiseOp.cu

Lines changed: 2 additions & 0 deletions

@@ -66,6 +66,7 @@ void foreach_pointwise_op_(TensorList input, TensorList tensors1, TensorList ten
       Op<opmath_t>(),
       scalar.to<opmath_t>());
   });
+  increment_version(input);
 }
 
 template<template<class> class Op>

@@ -86,6 +87,7 @@ void foreach_pointwise_op_(TensorList input, TensorList tensors1, TensorList ten
       /* res_arg_index */ 0>(),
       Op<opmath_t>());
   });
+  increment_version(input);
 }
 
 template<template<class> class Op>

aten/src/ATen/native/cuda/ForeachTernaryOp.cu

Lines changed: 1 addition & 0 deletions

@@ -66,6 +66,7 @@ void foreach_tensor_lerp_ternary_cuda_(TensorList tensors1, TensorList tensors2,
         LerpFunctor<opmath_t>());
     }
   );
+  increment_version(tensors1);
 }
 
 std::vector<at::Tensor> foreach_tensor_lerp_list_cuda(TensorList tensors1, TensorList tensors2, const Scalar& weight) {

test/test_autograd.py

Lines changed: 8 additions & 0 deletions

@@ -6135,6 +6135,14 @@ def test_grad_fn_attr_bindings(self):
         with self.assertRaisesRegex(RuntimeError, "after they have already been freed"):
             out.grad_fn._saved_weight
 
+        num_tensors = 3
+        input_tensors = [torch.ones(2, 2, requires_grad=True) for _ in range(num_tensors)]
+        scalars = [0.0 for _ in range(num_tensors)]  # ArrayRef<Scalar> -> Tuple[Scalar, ...]
+        results = torch._foreach_maximum(input_tensors, scalars)
+        for t in results:
+            self.assertEqual(t.grad_fn._saved_scalars, scalars)
+
+
     def test_cant_create_saved_tensors(self):
         with self.assertRaisesRegex(RuntimeError, "Trying to create a SavedTensor object from Python is forbidden"):
             torch.autograd.SavedTensor()
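
The new assertion above relies on the torchgen-generated bindings exposing a saved `ArrayRef<Scalar>` argument as a tuple of Python scalars on the `grad_fn`. A minimal interactive sketch of the same idea, assuming a build with this change:

```python
import torch

tensors = [torch.ones(2, 2, requires_grad=True) for _ in range(3)]
scalars = [0.0, 0.0, 0.0]

# Forward through a foreach op whose backward is generated by torchgen.
outs = torch._foreach_maximum(tensors, scalars)

# The saved ArrayRef<Scalar> argument is surfaced on the grad_fn as a tuple of scalars.
print(outs[0].grad_fn._saved_scalars)  # expected: (0.0, 0.0, 0.0)

# The generated backward produces per-tensor gradients.
sum(o.sum() for o in outs).backward()
print([t.grad is not None for t in tensors])  # expected: [True, True, True]
```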

test/test_foreach.py

Lines changed: 125 additions & 36 deletions
@@ -65,6 +65,26 @@ def __call__(self, inputs, is_cuda, is_fastpath, **kwargs):
         return inputs[0] if self._is_inplace else actual
 
 
+def get_transform_func(num_tensors, dtype, device, is_fastpath):
+    def transform(t):
+        if not torch.is_tensor(t):
+            return t
+        return make_tensor(
+            (num_tensors, num_tensors), dtype=dtype, device=device,
+            requires_grad=True, noncontiguous=not is_fastpath,
+        )
+    return transform
+
+
+def clone(arg):
+    if isinstance(arg, (list, tuple)):
+        return [clone(a) for a in arg]
+    if torch.is_tensor(arg):
+        return arg.clone().detach().requires_grad_()
+    else:
+        return arg
+
+
 class TestForeach(TestCase):
 
     @property
@@ -82,18 +102,21 @@ def _get_funcs(self, op):
             RegularFuncWrapper(op.ref_inplace),
         )
 
-    def _binary_test(self, dtype, op, ref, inputs, is_fastpath, is_inplace, *, alpha=None):
+    def _binary_test(self, dtype, op, ref, inputs, is_fastpath, is_inplace, *, alpha=None, scalar_self_arg=False):
         ref_inputs = [[t.clone().detach() for t in inputs[0]], inputs[1]] if is_inplace else inputs
 
         try:
            actual = op(inputs, self.is_cuda, is_fastpath)
        except RuntimeError as e:
            with self.assertRaisesRegex(type(e), re.escape(str(e))):
-                ref(ref_inputs)
+                if not scalar_self_arg:
+                    ref(ref_inputs)
+                else:
+                    [ref.func(ref_inputs[0], t) for t in ref_inputs[1]]
        else:
-            expected = ref(ref_inputs)
+            expected = ref(ref_inputs) if not scalar_self_arg else [ref.func(ref_inputs[0], t) for t in ref_inputs[1]]
             self.assertEqual(actual, expected)
-        if alpha is not None:
+        if alpha is not None and not scalar_self_arg:
             kwargs = {'alpha': alpha}
             ref_inputs = inputs
             try:
@@ -112,26 +135,54 @@ def _binary_test(self, dtype, op, ref, inputs, is_fastpath, is_inplace, *, alpha
     @ops(foreach_binary_op_db)
     @parametrize("is_fastpath", (True, False))
     def test_binary_op(self, device, dtype, op, is_fastpath):
-        for sample in op.sample_inputs(device, dtype, noncontiguous=not is_fastpath):
+        scalar_self_arg_test_complete = False
+        for i, sample in enumerate(op.sample_inputs(device, dtype, noncontiguous=not is_fastpath)):
             rhs_arg, = sample.args
             kwargs = {} or sample.kwargs
             alpha = kwargs.pop("alpha", None)
             disable_fastpath = kwargs.pop("disable_fastpath") if is_fastpath else False
             wrapped_op, ref, inplace_op, inplace_ref = self._get_funcs(op)
             self._binary_test(
-                dtype, wrapped_op, ref, [sample.input, rhs_arg], is_fastpath and not disable_fastpath, False, alpha=alpha)
+                dtype, wrapped_op, ref, [sample.input, rhs_arg],
+                is_fastpath and not disable_fastpath, False, alpha=alpha)
             self._binary_test(
-                dtype, inplace_op, inplace_ref, [sample.input, rhs_arg], is_fastpath and not disable_fastpath, True, alpha=alpha)
-            if op.supports_scalar_self_arg and isinstance(rhs_arg, list) and isinstance(rhs_arg[0], torch.Tensor):
+                dtype, inplace_op, inplace_ref, [sample.input, rhs_arg],
+                is_fastpath and not disable_fastpath, True, alpha=alpha)
+
+            if op.supports_autograd and dtype in floating_types():
+                transformed_sample = sample.transform(get_transform_func(len(sample.input), dtype, device, is_fastpath))
+                tensors = transformed_sample.input
+                rhs_arg, = transformed_sample.args
+                ref_tensors, ref_rhs_arg = clone(tensors), clone(rhs_arg)
+                try:
+                    sum(wrapped_op([tensors, rhs_arg], is_cuda=False, is_fastpath=False)).mean().backward()
+                except RuntimeError:
+                    with self.assertRaises(RuntimeError):
+                        sum(ref([ref_tensors, ref_rhs_arg])).mean().backward()
+                else:
+                    sum(ref([ref_tensors, ref_rhs_arg])).mean().backward()
+                    self.assertEqual([t.grad for t in tensors], [t.grad for t in ref_tensors])
+                    if isinstance(rhs_arg, list) and isinstance(rhs_arg[0], torch.Tensor):
+                        self.assertEqual([t.grad for t in rhs_arg], [t.grad for t in ref_rhs_arg])
+            if op.supports_scalar_self_arg and isinstance(rhs_arg, Number) and (not scalar_self_arg_test_complete):
+                scalar_self_arg_test_complete = True
                 self._binary_test(
-                    dtype, wrapped_op, ref, [rhs_arg, sample.input], is_fastpath and not disable_fastpath, False, alpha=alpha)
+                    dtype, wrapped_op, ref, [rhs_arg, sample.input], is_fastpath, False,
+                    alpha=alpha, scalar_self_arg=True)
+                if op.supports_autograd and dtype == torch.float32:
+                    transformed_sample = sample.transform(
+                        get_transform_func(len(sample.input), dtype, device, is_fastpath))
+                    tensors = transformed_sample.input
+                    rhs_arg, = transformed_sample.args
+                    ref_tensors, ref_rhs_arg = clone(tensors), clone(rhs_arg)
+                    sum(wrapped_op([rhs_arg, tensors], is_cuda=False, is_fastpath=False)).mean().backward()
+                    sum([ref.func(ref_rhs_arg, t) for t in ref_tensors]).mean().backward()
+                    self.assertEqual([t.grad for t in tensors], [t.grad for t in ref_tensors])
 
     @ops(foreach_pointwise_op_db)
     @parametrize("is_fastpath", (True, False))
     def test_pointwise_op(self, device, dtype, op, is_fastpath):
-        for sample in op.sample_inputs(device, dtype):
-            if not is_fastpath:
-                sample = sample.noncontiguous()
+        for sample in op.sample_inputs(device, dtype, noncontiguous=not is_fastpath):
             assert isinstance(sample.args, tuple)
             assert len(sample.args) == 2
             inputs = [sample.input, *sample.args]
@@ -140,7 +191,27 @@ def test_pointwise_op(self, device, dtype, op, is_fastpath):
             wrapped_op, ref, inplace_op, inplace_ref = self._get_funcs(op)
             values = kwargs.pop("values")
             self._pointwise_test(wrapped_op, ref, inputs, is_fastpath and not disable_fastpath, False, values=values)
-            self._pointwise_test(inplace_op, inplace_ref, inputs, is_fastpath and not disable_fastpath, True, values=values)
+            self._pointwise_test(
+                inplace_op, inplace_ref, inputs, is_fastpath and not disable_fastpath,
+                True, values=values)
+
+            if op.supports_autograd and dtype in floating_types():
+                transformed_sample = sample.transform(
+                    get_transform_func(len(sample.input), dtype, device, is_fastpath))
+                tensors = transformed_sample.input
+                rhs_arg = transformed_sample.args
+                ref_tensors, ref_rhs_arg = clone(tensors), clone(rhs_arg)
+                try:
+                    sum(wrapped_op([tensors, *rhs_arg], is_cuda=False, is_fastpath=False)).mean().backward()
+                except RuntimeError:
+                    with self.assertRaises(RuntimeError):
+                        sum(ref([ref_tensors, *ref_rhs_arg])).mean().backward()
+                else:
+                    sum(ref([ref_tensors, *ref_rhs_arg])).mean().backward()
+                    self.assertEqual([t.grad for t in tensors], [t.grad for t in ref_tensors])
+                    for op_list, ref_list in zip(rhs_arg, ref_rhs_arg):
+                        if isinstance(op_list, list) and isinstance(op_list[0], torch.Tensor):
+                            self.assertEqual([t.grad for t in op_list], [t.grad for t in ref_list])
 
             if is_fastpath and isinstance(values, list):
                 sample = sample.transform(lambda t: t.clone().detach() if torch.is_tensor(t) else t)
@@ -224,24 +295,6 @@ def _inplace_unary_test(self, inplace, inplace_ref, inputs, is_fastpath):
             inplace_ref(copied_inputs),
             self.assertEqual(copied_inputs, inputs)
 
-    def _test_unary(self, device, dtype, opinfo, N, is_fastpath):
-        op, ref, inplace_op, inplace_ref = self._get_funcs(opinfo, 1)
-        inputs = opinfo.sample_inputs(device, dtype, N, noncontiguous=not is_fastpath),
-        # note(mkozuki): Complex inputs for `_foreach_abs` go through slowpath.
-        if opinfo.name == "_foreach_abs" and dtype in complex_types():
-            is_fastpath = False
-        self._regular_unary_test(dtype, op, ref, inputs, is_fastpath)
-        self._inplace_unary_test(dtype, inplace_op, inplace_ref, inputs, is_fastpath)
-
-        if opinfo.supports_autograd and dtype in floating_types():
-            tensors = opinfo.sample_inputs(device, dtype, N, noncontiguous=not is_fastpath, same_size=True)
-            tensors = [t.requires_grad_() for t in tensors]
-            ref_tensors = [t.clone().detach().requires_grad_() for t in tensors]
-
-            sum(op.func(tensors)).mean().backward()
-            sum([ref.func(t) for t in ref_tensors]).mean().backward()
-            self.assertEqual([t.grad for t in tensors], [t.grad for t in ref_tensors])
-
     @skipMeta
     @ops(foreach_unary_op_db)
     @parametrize("is_fastpath", (True, False))
@@ -259,19 +312,39 @@ def test_unary_op(self, device, dtype, op, is_fastpath):
             )
             self.assertEqual(ref(inputs), wrapped_op(inputs, self.is_cuda, is_fastpath and not disable_fastpath))
             self._inplace_unary_test(inplace_op, inplace_ref, [sample.input], is_fastpath and not disable_fastpath)
+            if op.supports_autograd and dtype in floating_types():
+                num_tensors = len(sample.input)
+                tensors = [
+                    make_tensor(
+                        (num_tensors, num_tensors), dtype=dtype, device=device,
+                        requires_grad=True, noncontiguous=not is_fastpath,
+                    )
+                    for _ in range(num_tensors)
+                ]
+                ref_tensors = [t.clone().detach().requires_grad_() for t in tensors]
+                sum(wrapped_op.func(tensors)).mean().backward()
+                sum([ref.func(t) for t in ref_tensors]).mean().backward()
+                self.assertEqual([t.grad for t in tensors], [t.grad for t in ref_tensors])
 
     @ops(foreach_reduce_op_db)
     @parametrize("is_fastpath", (True, False))
     def test_reduce_op(self, device, dtype, op, is_fastpath):
-        for sample in op.sample_inputs(device, dtype):
-            if not is_fastpath:
-                sample = sample.noncontiguous()
+        for sample in op.sample_inputs(device, dtype, noncontiguous=not is_fastpath):
             ord = sample.kwargs.pop("ord")
             disable_fastpath = sample.kwargs.pop("disable_fastpath", False)
 
             inputs = (sample.input,)
             wrapped_op, ref, _, _ = self._get_funcs(op)
             self.assertEqual(ref(inputs, ord=ord), wrapped_op(inputs, self.is_cuda, is_fastpath and not disable_fastpath, ord=ord))
+            if op.supports_autograd and dtype in floating_types():
+                transformed_sample = sample.transform(get_transform_func(len(sample.input), dtype, device, is_fastpath))
+                tensors = transformed_sample.input
+                ref_tensors = clone(tensors)
+                sum(wrapped_op((tensors,), False, False, ord=ord)).backward()
+                sum(ref((ref_tensors,), ord=ord)).backward()
+                self.assertEqual(
+                    [t.grad for t in tensors], [t.grad for t in ref_tensors],
+                )
 
     @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
     def test_add_scalar_with_empty_list_and_empty_tensor(self, device, dtype):
@@ -285,7 +358,6 @@ def test_add_scalar_with_empty_list_and_empty_tensor(self, device, dtype):
 
     @ops(foreach_binary_op_db, dtypes=OpDTypes.supported)
     def test_binary_op_scalar_with_overlapping_tensors(self, device, dtype, op):
-        print(op, device, dtype)
         foreach_op, ref = op.method_variant, op.ref
         tensors = [torch.ones(1, 1, device=device, dtype=dtype).expand(2, 1, 3)]
 
@@ -533,7 +605,6 @@ def test_foreach_l2_large_value_input(self, device, dtype, op):
     def test_lerp(self, device, dtype, op, is_fastpath):
         for sample in op.sample_inputs(device, dtype, noncontiguous=not is_fastpath):
             wrapped_op, ref, inplace_op, _ = self._get_funcs(op)
-
             args = [*sample.args]
             inputs = [sample.input, args[0]]
 
@@ -559,6 +630,24 @@ def test_lerp(self, device, dtype, op, is_fastpath):
                 inplace_actual = inplace_op(inplace_inputs, self.is_cuda, is_fastpath, **kwargs)
                 self.assertEqual(inplace_actual, expected)
 
+            if op.supports_autograd and dtype in floating_types():
+                transformed_sample = sample.transform(get_transform_func(len(sample.input), dtype, device, is_fastpath))
+                args = [*transformed_sample.args]
+                inputs = [transformed_sample.input, args[0]]
+
+                kwargs, ref_kwargs = {}, {}
+                if isinstance(args[1], list):
+                    inputs.append(args[1])
+                else:
+                    kwargs = ref_kwargs = {"weight": args[1]}
+                ref_tensors = clone(transformed_sample.input)
+                sum(wrapped_op((transformed_sample.input, *inputs[1:]), False, False, **kwargs)).mean().backward()
+                sum(ref((ref_tensors, *inputs[1:]), **ref_kwargs)).mean().backward()
+                self.assertEqual(
+                    [t.grad for t in transformed_sample.input], [t.grad for t in ref_tensors],
+                    msg=f"{transformed_sample.input[0].grad[:2, :2]}, {ref_tensors[0].grad[:2, :2]}"
+                )
+
     @onlyCUDA
     @ops(foreach_reduce_op_db)
     def test_foreach_reduce_large_input(self, device, dtype, op):
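
The pattern these new tests follow can be reduced to a standalone sketch: run a foreach op on leaf tensors, backward through a scalar reduction, and compare the gradients against the per-tensor reference op. This is a hypothetical example, assuming the foreach backward support added by this commit.

```python
import torch

# Differentiable foreach op vs. per-tensor reference, mirroring the tests' gradient check.
tensors = [torch.rand(3, 3, requires_grad=True) for _ in range(4)]
ref_tensors = [t.clone().detach().requires_grad_() for t in tensors]

sum(torch._foreach_exp(tensors)).mean().backward()
sum(torch.exp(t) for t in ref_tensors).mean().backward()

for t, r in zip(tensors, ref_tensors):
    torch.testing.assert_close(t.grad, r.grad)
```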
