
Commit 3087622

crcrpar authored and pytorchmergebot committed
[mta] Backward of unary foreach functions (#89591)
As per the title, this PR defines the backward of the unary foreach functions. It does not implement forward-mode automatic differentiation, as [the current codegen](https://github.com/pytorch/pytorch/blob/a747326423ed4731996769e3b8eb73eecbdee2d4/tools/autograd/gen_variable_type.py#L1513) doesn't seem to handle `ArrayRef<Tensor>`.

Related:
- #53796
- #58833

Pull Request resolved: #89591
Approved by: https://github.com/albanD
1 parent 32b2d80 commit 3087622
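To see what this enables in practice, here is a minimal sketch (assuming a build that includes this PR): gradients now flow through an out-of-place unary foreach call, and since d/dx exp(x) = exp(x), each input's gradient should equal the corresponding output.

```python
import torch

tensors = [torch.randn(3, requires_grad=True) for _ in range(4)]

# One fused call instead of a Python loop over torch.exp(t).
outs = torch._foreach_exp(tensors)

# Reduce to a scalar and backpropagate; each input gets a .grad.
sum(o.sum() for o in outs).backward()

# d/dx exp(x) = exp(x), so each gradient should equal the output.
for t, o in zip(tensors, outs):
    torch.testing.assert_close(t.grad, o.detach())
```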

File tree

10 files changed: +292 −16 lines


aten/src/ATen/native/cuda/ForeachFunctors.cuh

Lines changed: 8 additions & 0 deletions

@@ -7,6 +7,14 @@ namespace at { namespace native {
 
 namespace {
 
+// TODO(crcrpar): Handle version bump in codegen.
+// rel: https://github.com/pytorch/pytorch/blob/9cf84347767c8abb8feba18a9a1baba321eeb8b9/tools/autograd/gen_inplace_or_view_type.py#L481-L482
+inline void increment_version(TensorList tensors) {
+  for (const auto & t : tensors) {
+    t.unsafeGetTensorImpl()->bump_version();
+  }
+}
+
 // Initializes args and checks if all args are aligned
 template<int depth, typename T>
 __device__ bool init_args(
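For context on why the kernel must bump versions: autograd records a tensor's version when saving it for backward and re-checks it at backward time to catch in-place modification. A rough sketch of the observable effect (`Tensor._version` is an internal counter, so this is illustrative only):

```python
import torch

t = torch.randn(3)
before = t._version
torch._foreach_exp_([t])  # in-place foreach op
assert t._version == before + 1  # the op must bump the version counter

# Without the bump, backward could silently consume stale saved values
# instead of raising the usual "modified by an inplace operation" error.
```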

aten/src/ATen/native/cuda/ForeachUnaryOp.cu

Lines changed: 1 addition & 0 deletions

@@ -73,6 +73,7 @@ template <typename scalar_t, template<class> class Op> void foreach_unary_op_(Te
           /* r_args_depth */ 1,
           /* res_arg_index */ 0>(),
       Op<opmath_t>());
+  increment_version(tensors);
 }
 
 template <template<class> class Op>

docs/source/torch.rst

Lines changed: 67 additions & 0 deletions

@@ -597,6 +597,73 @@ BLAS and LAPACK Operations
     triangular_solve
     vdot
 
+Foreach Operations
+~~~~~~~~~~~~~~~~~~
+
+.. warning::
+    This API is in beta and subject to future changes.
+    Forward-mode AD is not supported.
+
+.. autosummary::
+    :toctree: generated
+    :nosignatures:
+
+    _foreach_abs
+    _foreach_abs_
+    _foreach_acos
+    _foreach_acos_
+    _foreach_asin
+    _foreach_asin_
+    _foreach_atan
+    _foreach_atan_
+    _foreach_ceil
+    _foreach_ceil_
+    _foreach_cos
+    _foreach_cos_
+    _foreach_cosh
+    _foreach_cosh_
+    _foreach_erf
+    _foreach_erf_
+    _foreach_erfc
+    _foreach_erfc_
+    _foreach_exp
+    _foreach_exp_
+    _foreach_expm1
+    _foreach_expm1_
+    _foreach_floor
+    _foreach_floor_
+    _foreach_log
+    _foreach_log_
+    _foreach_log10
+    _foreach_log10_
+    _foreach_log1p
+    _foreach_log1p_
+    _foreach_log2
+    _foreach_log2_
+    _foreach_neg
+    _foreach_neg_
+    _foreach_tan
+    _foreach_tan_
+    _foreach_sin
+    _foreach_sin_
+    _foreach_sinh
+    _foreach_sinh_
+    _foreach_round
+    _foreach_round_
+    _foreach_sqrt
+    _foreach_sqrt_
+    _foreach_lgamma
+    _foreach_lgamma_
+    _foreach_frac
+    _foreach_frac_
+    _foreach_reciprocal
+    _foreach_reciprocal_
+    _foreach_sigmoid
+    _foreach_sigmoid_
+    _foreach_trunc
+    _foreach_trunc_
+    _foreach_zero_
+
 Utilities
 ----------------------------------
 .. autosummary::
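As the naming convention above suggests, each op comes in an out-of-place form that returns a new list and an in-place form (trailing underscore) that mutates its inputs; a quick sketch:

```python
import torch

xs = [torch.rand(2), torch.rand(3)]

ys = torch._foreach_sin(xs)  # out-of-place: new tensors, xs unchanged
torch._foreach_sin_(xs)      # in-place: xs now holds sin of its old values

for x, y in zip(xs, ys):
    torch.testing.assert_close(x, y)
```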

test/test_foreach.py

Lines changed: 18 additions & 0 deletions

@@ -222,6 +222,24 @@ def _inplace_unary_test(self, inplace, inplace_ref, inputs, is_fastpath):
             inplace_ref(copied_inputs),
         self.assertEqual(copied_inputs, inputs)
 
+    def _test_unary(self, device, dtype, opinfo, N, is_fastpath):
+        op, ref, inplace_op, inplace_ref = self._get_funcs(opinfo, 1)
+        inputs = opinfo.sample_inputs(device, dtype, N, noncontiguous=not is_fastpath),
+        # note(mkozuki): Complex inputs for `_foreach_abs` go through slowpath.
+        if opinfo.name == "_foreach_abs" and dtype in complex_types():
+            is_fastpath = False
+        self._regular_unary_test(dtype, op, ref, inputs, is_fastpath)
+        self._inplace_unary_test(dtype, inplace_op, inplace_ref, inputs, is_fastpath)
+
+        if opinfo.supports_autograd and dtype in floating_types():
+            tensors = opinfo.sample_inputs(device, dtype, N, noncontiguous=not is_fastpath, same_size=True)
+            tensors = [t.requires_grad_() for t in tensors]
+            ref_tensors = [t.clone().detach().requires_grad_() for t in tensors]
+
+            sum(op.func(tensors)).mean().backward()
+            sum([ref.func(t) for t in ref_tensors]).mean().backward()
+            self.assertEqual([t.grad for t in tensors], [t.grad for t in ref_tensors])
+
     @skipMeta
     @ops(foreach_unary_op_db)
     @parametrize("is_fastpath", (True, False))
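Outside the OpInfo machinery, the gradient-parity check that `_test_unary` performs boils down to the following sketch (same-sized inputs assumed so the lists sum elementwise):

```python
import torch

tensors = [torch.randn(5, requires_grad=True) for _ in range(3)]
ref_tensors = [t.clone().detach().requires_grad_() for t in tensors]

# Foreach op on the whole list vs. the per-tensor reference op.
sum(torch._foreach_sin(tensors)).mean().backward()
sum(torch.sin(t) for t in ref_tensors).mean().backward()

assert all(
    torch.allclose(t.grad, r.grad) for t, r in zip(tensors, ref_tensors)
)
```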

tools/autograd/derivatives.yaml

Lines changed: 1 addition & 1 deletion

@@ -42,7 +42,7 @@
 #   to that argument could exist. You should either:
 #     - Specify the formula for that gradient
 #     - Specify not_implemented("function_name") as a formula to say that this is not
-#       implement yet (but might be in the future and the user can request that on an issue)
+#       implemented yet (but might be in the future and the user can request that on an issue)
 #     - If that argument is not differentiable, because it is not a floating point dtype or the
 #       function is not differentiable with respect to that argument for
 #       example. You should either:
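Entries in this file pair a function with its gradient formula, which the foreach codegen (next file) reuses per tensor. A quick numerical check of one such textbook rule, d/dx reciprocal(x) = -1/x², through the foreach path (assuming a build with this PR):

```python
import torch

x = torch.rand(4, requires_grad=True) + 0.5  # keep values away from zero
(g,) = torch.autograd.grad(torch._foreach_reciprocal([x])[0].sum(), x)
torch.testing.assert_close(g, -1 / x.detach() ** 2)
```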

tools/autograd/gen_autograd_functions.py

Lines changed: 22 additions & 1 deletion

@@ -98,6 +98,23 @@
 """
 )
 
+# note(crcrpar): The `self` argument and other optional positional arguments
+# of foreach functions are basically lists of n `Tensor`s, so we iterate over
+# `grads` in order to apply the existing derivative definitions to each
+# `Tensor` of `self` and the others.
+DERIVATIVE_SINGLE_FOREACH = CodeTemplate(
+    """\
+if (task_should_compute_output({ ${name}_ix })) {
+  std::vector<Tensor> grad_result;
+  grad_result.reserve(grads.size());
+  for (const auto & i : c10::irange(grads.size())) {
+    grad_result.emplace_back(${derivative});
+  }
+  copy_range(grad_inputs, ${name}_ix, grad_result);
+}
+"""
+)
+
 DERIVATIVE_MULTI_COPY_RANGE = CodeTemplate(
     """\
 if (task_should_compute_output({ ${name}_ix })) {

@@ -709,9 +726,13 @@ def emit_derivative(
             ) in ("Tensor", "Tensor?"):
                 formula = "any_grad_defined ? (" + formula + ") : Tensor()"
                 checks_any_grad_defined = True
+            if info.name.startswith("_foreach_"):
+                derivative_template = DERIVATIVE_SINGLE_FOREACH
+            else:
+                derivative_template = DERIVATIVE_SINGLE
             return (
                 checks_any_grad_defined,
-                DERIVATIVE_SINGLE.substitute(name=var_names[0], derivative=formula),
+                derivative_template.substitute(name=var_names[0], derivative=formula),
             )
         else:
             if "grad_input_mask" in formula:

torch/_torch_docs.py

Lines changed: 56 additions & 0 deletions

@@ -14003,3 +14003,59 @@ def merge_dicts(*dicts):
 are freshly created instead of aliasing the input.
 """,
 )
+
+for unary_base_func_name in (
+    "exp",
+    "sqrt",
+    "abs",
+    "acos",
+    "asin",
+    "atan",
+    "ceil",
+    "cos",
+    "cosh",
+    "erf",
+    "erfc",
+    "expm1",
+    "floor",
+    "log",
+    "log10",
+    "log1p",
+    "log2",
+    "neg",
+    "tan",
+    "tanh",
+    "sin",
+    "sinh",
+    "round",
+    "lgamma",
+    "frac",
+    "reciprocal",
+    "sigmoid",
+    "trunc",
+    "zero",
+):
+    unary_foreach_func_name = f"_foreach_{unary_base_func_name}"
+    if hasattr(torch, unary_foreach_func_name):
+        add_docstr(
+            getattr(torch, unary_foreach_func_name),
+            r"""
+{}(self: List[Tensor]) -> List[Tensor]
+
+Apply :func:`torch.{}` to each Tensor of the input list.
+""".format(
+                unary_foreach_func_name, unary_base_func_name
+            ),
+        )
+    unary_inplace_foreach_func_name = f"{unary_foreach_func_name}_"
+    if hasattr(torch, unary_inplace_foreach_func_name):
+        add_docstr(
+            getattr(torch, unary_inplace_foreach_func_name),
+            r"""
+{}(self: List[Tensor]) -> None
+
+Apply :func:`torch.{}` to each Tensor of the input list.
+""".format(
+                unary_inplace_foreach_func_name, unary_base_func_name
+            ),
+        )
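Since the docstrings are attached at import time by the loop above, they can be inspected like any other; the expected shape follows the template:

```python
import torch

print(torch._foreach_sin.__doc__)
# _foreach_sin(self: List[Tensor]) -> List[Tensor]
#
# Apply :func:`torch.sin` to each Tensor of the input list.
```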

torch/testing/_internal/common_methods_invocations.py

Lines changed: 27 additions & 13 deletions

@@ -8074,109 +8074,122 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs):
 
 
 foreach_unary_op_db: List[OpInfo] = [
-    ForeachFuncInfo('exp', sample_inputs_func=foreach_inputs_sample_func(1, False, False)),
-    ForeachFuncInfo('acos', sample_inputs_func=foreach_inputs_sample_func(1, False, False)),
-    ForeachFuncInfo('asin', sample_inputs_func=foreach_inputs_sample_func(1, False, False)),
-    ForeachFuncInfo('atan', sample_inputs_func=foreach_inputs_sample_func(1, False, False)),
-    ForeachFuncInfo('cos', sample_inputs_func=foreach_inputs_sample_func(1, False, False)),
-    ForeachFuncInfo('cosh', sample_inputs_func=foreach_inputs_sample_func(1, False, False)),
-    ForeachFuncInfo('log', sample_inputs_func=foreach_inputs_sample_func(1, False, False)),
-    ForeachFuncInfo('log10', sample_inputs_func=foreach_inputs_sample_func(1, False, False)),
-    ForeachFuncInfo('log2', sample_inputs_func=foreach_inputs_sample_func(1, False, False)),
-    ForeachFuncInfo('tan', sample_inputs_func=foreach_inputs_sample_func(1, False, False)),
-    ForeachFuncInfo('tanh', sample_inputs_func=foreach_inputs_sample_func(1, False, False)),
-    ForeachFuncInfo('sin', sample_inputs_func=foreach_inputs_sample_func(1, False, False)),
-    ForeachFuncInfo('sinh', sample_inputs_func=foreach_inputs_sample_func(1, False, False)),
+    ForeachFuncInfo('exp', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True),
+    ForeachFuncInfo('acos', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True),
+    ForeachFuncInfo('asin', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True),
+    ForeachFuncInfo('atan', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True),
+    ForeachFuncInfo('cos', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True),
+    ForeachFuncInfo('cosh', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True),
+    ForeachFuncInfo('log', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True),
+    ForeachFuncInfo('log10', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True),
+    ForeachFuncInfo('log2', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True),
+    ForeachFuncInfo('tan', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True),
+    ForeachFuncInfo('tanh', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True),
+    ForeachFuncInfo('sin', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True),
+    ForeachFuncInfo('sinh', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True),
 
     ForeachFuncInfo(
         'neg',
         dtypes=all_types_and_complex(),
         dtypesIfCUDA=all_types_and_complex(),
         sample_inputs_func=foreach_inputs_sample_func(1, False, False),
+        supports_autograd=True,
     ),
 
     ForeachFuncInfo(
         'sqrt',
         dtypes=floating_and_complex_types_and(torch.bfloat16),
        dtypesIfCUDA=floating_and_complex_types_and(torch.half),
         sample_inputs_func=foreach_inputs_sample_func(1, False, False),
+        supports_autograd=True,
     ),
 
     ForeachFuncInfo(
         'ceil',
         dtypes=all_types_and(torch.bfloat16),
         dtypesIfCUDA=all_types_and(torch.half, torch.bfloat16),
         sample_inputs_func=foreach_inputs_sample_func(1, False, False),
+        supports_autograd=True,
     ),
 
     ForeachFuncInfo(
         'erf',
         dtypes=floating_types_and(torch.bfloat16),
         dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
         sample_inputs_func=foreach_inputs_sample_func(1, False, False),
+        supports_autograd=True,
     ),
 
     ForeachFuncInfo(
         'erfc',
         dtypes=floating_types_and(torch.bfloat16),
         dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
         sample_inputs_func=foreach_inputs_sample_func(1, False, False),
+        supports_autograd=True,
     ),
 
     ForeachFuncInfo(
         'expm1',
         dtypes=floating_types_and(torch.bfloat16),
         dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
         sample_inputs_func=foreach_inputs_sample_func(1, False, False),
+        supports_autograd=True,
     ),
 
     ForeachFuncInfo(
         'floor',
         dtypes=all_types_and(torch.bfloat16),
         dtypesIfCUDA=all_types_and(torch.half, torch.bfloat16),
         sample_inputs_func=foreach_inputs_sample_func(1, False, False),
+        supports_autograd=True,
     ),
 
     ForeachFuncInfo(
         'log1p',
         dtypes=floating_and_complex_types_and(torch.bfloat16),
         dtypesIfCUDA=floating_and_complex_types_and(torch.half),
         sample_inputs_func=foreach_inputs_sample_func(1, False, False),
+        supports_autograd=True,
     ),
 
     ForeachFuncInfo(
         'round',
         dtypes=all_types_and(torch.bfloat16),
         dtypesIfCUDA=all_types_and(torch.half, torch.bfloat16),
         sample_inputs_func=foreach_inputs_sample_func(1, False, False),
+        supports_autograd=True,
     ),
 
     ForeachFuncInfo(
         'frac',
         dtypes=floating_types_and(torch.bfloat16),
         dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
         sample_inputs_func=foreach_inputs_sample_func(1, False, False),
+        supports_autograd=True,
     ),
 
     ForeachFuncInfo(
         'reciprocal',
         dtypes=floating_types_and(torch.bfloat16),
         dtypesIfCUDA=floating_types_and(torch.half),
         sample_inputs_func=foreach_inputs_sample_func(1, False, False),
+        supports_autograd=True,
     ),
 
     ForeachFuncInfo(
         'sigmoid',
         dtypes=floating_types_and(torch.bfloat16),
         dtypesIfCUDA=floating_types_and(torch.half),
         sample_inputs_func=foreach_inputs_sample_func(1, False, False),
+        supports_autograd=True,
     ),
 
     ForeachFuncInfo(
         'trunc',
         dtypes=all_types_and(torch.bfloat16),
         dtypesIfCUDA=all_types_and(torch.half, torch.bfloat16),
         sample_inputs_func=foreach_inputs_sample_func(1, False, False),
+        supports_autograd=True,
     ),
 
     ForeachFuncInfo(

@@ -8186,6 +8199,7 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs):
         supports_forward_ad=True,
         supports_fwgrad_bwgrad=True,
         sample_inputs_func=foreach_inputs_sample_func(1, False, False),
+        supports_autograd=True,
     ),
 ]
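A sketch of how the new flag can be consumed: gradient tests only run for entries that opt in, so filtering the db by `supports_autograd` (an internal module, subject to change) should now cover every unary foreach op:

```python
from torch.testing._internal.common_methods_invocations import (
    foreach_unary_op_db,
)

grad_ops = [op.name for op in foreach_unary_op_db if op.supports_autograd]
print(grad_ops)  # after this PR: all unary foreach ops in the db
```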

torch/testing/_internal/opinfo/core.py

Lines changed: 2 additions & 0 deletions

@@ -2571,6 +2571,7 @@ def __init__(
         dtypesIfROCM=None,
         supports_alpha_param=False,
         sample_inputs_func=sample_inputs_foreach,
+        supports_autograd=False,
         **kwargs,
     ):
         super().__init__(
@@ -2579,6 +2580,7 @@ def __init__(
             dtypesIfCUDA=dtypesIfCUDA,
             dtypesIfROCM=dtypesIfROCM,
             sample_inputs_func=sample_inputs_func,
+            supports_autograd=supports_autograd,
             **kwargs,
         )