Commit bed250e
committed
[generate_vmap_rule] Add generate_vmap_rule to autograd.Function
Design document:
https://docs.google.com/document/d/1bIQkWXy3J35_20c_a5kchikabBW5M8_uRAhl0BIMwU4/edit
This PR adds a `generate_vmap_rule` option (default False) to autograd.Function.
By setting it to True, a user promises to us that their autograd.Function's
{forward, backward, jvp}, if defined, only uses PyTorch operations, in addition to the other
limitations of autograd.Function+functorch (such as the user not
capturing any Tensors being transformed over from outside of the
autograd.Function).
Concretely, the approach is:
- we update `custom_function_call` to accept an additional
`generate_vmap_rule` argument.
- The vmap rule for `custom_function_call` and `generate_vmap_rule=True`
is: we construct a vmapped version of the autograd.Function and dispatch
on it.
- The vmapped version of the autograd.Function can be thought of like
the following: if we have an autograd.Function Foo, then
VmappedFoo.apply(in_dims, ...) has the same semantics as
vmap(Foo.apply, in_dims...)
- VmappedFoo's forward, setup_context, and backward staticmethod are
vmapped versions of Foo's staticmethods.
- See the design doc for more motivation and explanation
Test Plan:
- This PR introduces additional autograd.Function with the suffix "GenVmap" to
autograd_function_db.
- There are also some minor UX tests
Future:
- jvp support
- likely more testing to come, but please let me know if you have
cases that you want me to test here.
ghstack-source-id: 6905e60
Pull Request resolved: #909661 parent 53cb80d commit bed250e
File tree
7 files changed
+497
-21
lines changed- aten/src/ATen/functorch
- test/functorch
- torch
- _C
- _functorch
- autograd
- testing/_internal
7 files changed
+497
-21
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
4 | 4 | | |
5 | 5 | | |
6 | 6 | | |
| 7 | + | |
7 | 8 | | |
8 | 9 | | |
9 | 10 | | |
| |||
88 | 89 | | |
89 | 90 | | |
90 | 91 | | |
91 | | - | |
92 | | - | |
| 92 | + | |
| 93 | + | |
93 | 94 | | |
94 | | - | |
| 95 | + | |
95 | 96 | | |
96 | 97 | | |
97 | 98 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1116 | 1116 | | |
1117 | 1117 | | |
1118 | 1118 | | |
1119 | | - | |
| 1119 | + | |
1120 | 1120 | | |
1121 | 1121 | | |
1122 | 1122 | | |
| |||
1136 | 1136 | | |
1137 | 1137 | | |
1138 | 1138 | | |
| 1139 | + | |
| 1140 | + | |
| 1141 | + | |
| 1142 | + | |
| 1143 | + | |
| 1144 | + | |
| 1145 | + | |
| 1146 | + | |
| 1147 | + | |
| 1148 | + | |
| 1149 | + | |
| 1150 | + | |
| 1151 | + | |
| 1152 | + | |
| 1153 | + | |
| 1154 | + | |
| 1155 | + | |
| 1156 | + | |
| 1157 | + | |
| 1158 | + | |
| 1159 | + | |
| 1160 | + | |
| 1161 | + | |
| 1162 | + | |
| 1163 | + | |
| 1164 | + | |
| 1165 | + | |
1139 | 1166 | | |
1140 | 1167 | | |
1141 | 1168 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1349 | 1349 | | |
1350 | 1350 | | |
1351 | 1351 | | |
| 1352 | + | |
| 1353 | + | |
1352 | 1354 | | |
1353 | 1355 | | |
1354 | 1356 | | |
| |||
1517 | 1519 | | |
1518 | 1520 | | |
1519 | 1521 | | |
| 1522 | + | |
| 1523 | + | |
| 1524 | + | |
1520 | 1525 | | |
1521 | 1526 | | |
1522 | 1527 | | |
| |||
1962 | 1967 | | |
1963 | 1968 | | |
1964 | 1969 | | |
1965 | | - | |
1966 | | - | |
1967 | | - | |
| 1970 | + | |
| 1971 | + | |
| 1972 | + | |
1968 | 1973 | | |
1969 | 1974 | | |
1970 | 1975 | | |
| |||
1982 | 1987 | | |
1983 | 1988 | | |
1984 | 1989 | | |
| 1990 | + | |
1985 | 1991 | | |
1986 | 1992 | | |
1987 | 1993 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
16 | 16 | | |
17 | 17 | | |
18 | 18 | | |
| 19 | + | |
19 | 20 | | |
20 | 21 | | |
21 | 22 | | |
| |||
0 commit comments