Commit f74207c

Allow autograd to work even when the shape of values cannot be determined (#8641)
This commit implements the solution proposed in #8410 to work around the need to create zero tensors with the same shape as inputs. It introduces the concept of a LinearBlock, which marks places in the code where we know that if all the inputs to the node are zero, then the outputs of the node are also zero. Autodiff introduces LinearBlocks around backward functions, which have this property. specializeUndef then propagates Undef nodes using this information.

Notes:

* Since we do not always specialize, there is a pass LowerLinearBlocks that replaces the block with an if statement that dynamically guards the Undef case.
* We introduce AutogradAdd, an addition op that still works when its inputs might be undefined. In cases where we specialize, it is removed in favor of a normal add, but there are cases where gradient graphs do not specialize (e.g. when they are not differentiable but a derivative is required), so it is important for this op to be executable.
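As a rough illustration of the AutogradAdd semantics described above, here is a minimal C++ sketch over at::Tensor, where an undefined tensor stands in for an Undef (known-zero) gradient. The helper name autograd_add is hypothetical; this is not the actual interpreter implementation.

#include <ATen/ATen.h>

// Hypothetical sketch: the runtime behaviour AutogradAdd needs.
// An undefined tensor represents a zero gradient (Undef).
at::Tensor autograd_add(const at::Tensor& a, const at::Tensor& b) {
  if (!a.defined() && !b.defined()) return at::Tensor(); // Undef + Undef -> Undef
  if (!a.defined()) return b;                            // Undef + b     -> b
  if (!b.defined()) return a;                            // a + Undef     -> a
  return a + b;                                          // both defined  -> normal add
}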
1 parent 7a61479 commit f74207c

16 files changed: +368 −208 lines

setup.py
Lines changed: 2 additions & 0 deletions

@@ -768,12 +768,14 @@ def run(self):
     "torch/csrc/jit/passes/dead_code_elimination.cpp",
     "torch/csrc/jit/passes/remove_expands.cpp",
     "torch/csrc/jit/passes/lower_tuples.cpp",
+    "torch/csrc/jit/passes/lower_grad_of.cpp",
     "torch/csrc/jit/passes/common_subexpression_elimination.cpp",
     "torch/csrc/jit/passes/peephole.cpp",
     "torch/csrc/jit/passes/inplace_check.cpp",
     "torch/csrc/jit/passes/canonicalize.cpp",
     "torch/csrc/jit/passes/batch_mm.cpp",
     "torch/csrc/jit/passes/decompose_addmm.cpp",
+    "torch/csrc/jit/passes/specialize_undef.cpp",
     "torch/csrc/jit/passes/erase_number_types.cpp",
     "torch/csrc/jit/passes/loop_unrolling.cpp",
     "torch/csrc/jit/passes/onnx/peephole.cpp",

test/expect/TestJit.test_cpp.expect
Lines changed: 37 additions & 21 deletions

@@ -88,18 +88,26 @@ graph(%0 : Float(2, 3, 4)
       %2 : Float(2, 3, 4)
       %3 : Float(2, 3, 4)
       %4 : Float(2, 3, 4)) {
-  %5 : Float(2, 3, 4!) = prim::Constant[value=<Tensor>, is_zero=1]()
-  %6 : Dynamic = prim::ReplaceIfUndef(%0, %5)
-  %7 : Float(2, 3, 4!) = prim::Constant[value=<Tensor>, is_zero=1]()
-  %8 : Dynamic = prim::ReplaceIfUndef(%1, %7)
-  %9 : Dynamic = aten::mul(%6, %2)
-  %10 : Dynamic = aten::add[alpha={1}](%8, %9)
-  %11 : Dynamic = aten::mul(%6, %4)
-  %12 : Dynamic = aten::mul(%10, %3)
-  %13 : Dynamic = aten::mul(%10, %2)
-  %14 : Dynamic = aten::add[alpha={1}](%11, %12)
-  %15 : Dynamic = aten::add[alpha={1}](%6, %13)
-  return (%14, %15);
+  %5 : Float(2, 3, 4), %6 : Float(2, 3, 4) = prim::GradOf[name=aten::add](%0)
+    block0() {
+      -> (%0, %0)
+    }
+  %7 : Float(2, 3, 4), %8 : Float(2, 3, 4) = prim::GradOf[name=aten::mul](%5)
+    block0() {
+      %9 : Float(2, 3, 4) = aten::mul(%5, %2)
+      %10 : Float(2, 3, 4) = aten::mul(%5, %4)
+      -> (%9, %10)
+    }
+  %11 : Dynamic = prim::AutogradAdd(%1, %7)
+  %12 : Float(2, 3, 4), %13 : Float(2, 3, 4) = prim::GradOf[name=aten::mul](%11)
+    block0() {
+      %14 : Float(2, 3, 4) = aten::mul(%11, %3)
+      %15 : Float(2, 3, 4) = aten::mul(%11, %2)
+      -> (%14, %15)
+    }
+  %16 : Dynamic = prim::AutogradAdd(%8, %12)
+  %17 : Dynamic = prim::AutogradAdd(%6, %13)
+  return (%16, %17);
 }

 testDifferentiateWithRequiresGrad
@@ -116,14 +124,22 @@ graph(%0 : Float(2, 3, 4)
       %1 : Float(2, 3, 4)
       %2 : Float(2, 3, 4)
       %3 : Float(2, 3, 4)) {
-  %4 : Float(2, 3, 4!) = prim::Constant[value=<Tensor>, is_zero=1]()
-  %5 : Dynamic = prim::ReplaceIfUndef(%0, %4)
-  %6 : Float(2, 3, 4!) = prim::Constant[value=<Tensor>, is_zero=1]()
-  %7 : Dynamic = prim::ReplaceIfUndef(%1, %6)
-  %8 : Dynamic = aten::mul(%5, %2)
-  %9 : Dynamic = aten::add[alpha={1}](%7, %8)
-  %10 : Dynamic = aten::mul(%5, %3)
-  %11 : Dynamic = aten::add[alpha={1}](%10, %9)
-  return (%11);
+  %4 : Float(2, 3, 4), %5 : Float(2, 3, 4) = prim::GradOf[name=aten::add](%0)
+    block0() {
+      -> (%0, %0)
+    }
+  %6 : Float(2, 3, 4), %7 : Float(2, 3, 4) = prim::GradOf[name=aten::mul](%4)
+    block0() {
+      %8 : Float(2, 3, 4) = aten::mul(%4, %2)
+      %9 : Float(2, 3, 4) = aten::mul(%4, %3)
+      -> (%8, %9)
+    }
+  %10 : Dynamic = prim::AutogradAdd(%1, %6)
+  %11 : Float(2, 3, 4), %12 : Float(2, 3, 4) = prim::GradOf[name=aten::add](%10)
+    block0() {
+      -> (%10, %10)
+    }
+  %13 : Dynamic = prim::AutogradAdd(%7, %12)
+  return (%13);
 }

test/test_jit.py
Lines changed: 1 addition & 1 deletion

@@ -221,7 +221,7 @@ def allSum(vs):
         for g2, g2_ge in zip(grads2, grads2_ge):
             if g2 is None and g2_ge is None:
                 continue
-            self.assertTrue(torch.allclose(g2, g2_ge, atol=5e-4, rtol=1e-4))
+            self.assertTrue(torch.allclose(g2, g2_ge, atol=7e-4, rtol=1e-4))

         return ge

torch/CMakeLists.txt
Lines changed: 2 additions & 0 deletions

@@ -218,12 +218,14 @@ set(TORCH_SRCS
   ${TORCH_SRC_DIR}/csrc/jit/passes/dead_code_elimination.cpp
   ${TORCH_SRC_DIR}/csrc/jit/passes/erase_number_types.cpp
   ${TORCH_SRC_DIR}/csrc/jit/passes/lower_tuples.cpp
+  ${TORCH_SRC_DIR}/csrc/jit/passes/lower_grad_of.cpp
   ${TORCH_SRC_DIR}/csrc/jit/passes/peephole.cpp
   ${TORCH_SRC_DIR}/csrc/jit/passes/inplace_check.cpp
   ${TORCH_SRC_DIR}/csrc/jit/passes/batch_mm.cpp
   ${TORCH_SRC_DIR}/csrc/jit/passes/create_autodiff_subgraphs.cpp
   ${TORCH_SRC_DIR}/csrc/jit/passes/remove_expands.cpp
   ${TORCH_SRC_DIR}/csrc/jit/passes/decompose_addmm.cpp
+  ${TORCH_SRC_DIR}/csrc/jit/passes/specialize_undef.cpp
   ${TORCH_SRC_DIR}/csrc/jit/passes/loop_unrolling.cpp
   ${TORCH_SRC_DIR}/csrc/jit/interned_strings.cpp
   ${TORCH_SRC_DIR}/csrc/jit/script/compiler.cpp
torch/csrc/jit/autodiff.cpp
Lines changed: 55 additions & 81 deletions

@@ -20,7 +20,7 @@ bool isDifferentiable(Node * n) {
     aten::add, aten::sub, aten::mul, prim::Constant, prim::ReplaceIfUndef,
     aten::sigmoid, aten::tanh, aten::mm, aten::chunk, aten::split, aten::t, aten::neg,
     aten::unsqueeze, aten::expand, aten::addmm, aten::gt, aten::lt, aten::eq, aten::ne, aten::ge, aten::le, aten::type_as,
-    aten::relu, aten::exp
+    aten::relu, aten::exp, prim::AutogradAdd
   };
   // TODO: check this more generally via schema
   // This check ensures that the `alpha` and `beta` attributes on this addmm
@@ -34,6 +34,17 @@ bool isDifferentiable(Node * n) {
   if (n->kind() == aten::type_as && !n->inputs().at(1)->isTensor()) {
     return false;
   }
+
+  // linear blocks may appear as inputs to graph executors, but they are removed
+  // before differentiation occurs
+  if (n->kind() == prim::GradOf) {
+    auto body = n->blocks().at(0);
+    return std::all_of(
+        body->nodes().begin(),
+        body->nodes().end(),
+        static_cast<bool (*)(Node*)>(isDifferentiable));
+  }
+
   return differentiable_kinds.count(n->kind()) > 0;
 }

@@ -194,17 +205,10 @@ static std::vector<Value*> gradientForNode(Node* node, ArrayRef<Value*> grad_val
     throw std::runtime_error(std::string("don't support differentiation of `") +
                              node->kind().toDisplayString() + "`");
   };
-  const auto has_tensor_type = [](Value *v) { return v->isTensor(); };
   if (!isDifferentiable(node)) {
     throw std::runtime_error(std::string("differentiation of ") + node->kind().toDisplayString() + " "
                              "is not supported, or it is missing necessary type information");
   }
-  if (!std::all_of(node->inputs().begin(), node->inputs().end(), has_tensor_type) ||
-      !std::all_of(node->outputs().begin(), node->outputs().end(), has_tensor_type)) {
-    throw std::runtime_error("differentiate should be called with a graph where every value "
-                             "has a type registered");
-
-  }
   auto sym_grads = build_sym_grad(fmap<SymbolicVariable>(grad_values));
   return fmap(sym_grads, [](const SymbolicVariable &v) { return v.value(); });
 }
@@ -230,30 +234,35 @@ static value_set findAllRequiresGradNodes(
   return requires_grad_set;
 }

-static Value* createZerosLike(Value *v) {
-  JIT_EXPECTM(v->isTensor(), "can't allocate zero gradient for a value without a type");
-  Graph *graph = v->owningGraph();
-  auto type = v->type()->expect<TensorType>();
-  at::DeviceGuard device_guard(type->device());
-
-  auto & at_type = type->device() == -1 ? at::CPU(type->scalarType()) : at::CUDA(type->scalarType());
-  auto zeros = at::zeros({}, at_type).expand(type->sizes());
-  Node *constant = graph->createConstant(zeros)
-                        ->i_(attr::is_zero, 1);
-  graph->insertNode(constant);
-  return constant->output();
-}

-// any vjp input may be undefined, and we need to potentially replace it
-// with a zero tensor of the right size if required.
-// this function inserts a guard into the graph that does this replacement.
-// ReplaceIfUndef(dv,c) replaces dv with c if dv is undef.
-// During Graph specialization these guards will get removed when
-// 'dv' is known to be undef, and the zeros will be propagated if possible.
-static Value* createUndefGuard(Value * dv, Value * alternative) {
-  Graph* graph = dv->owningGraph();
-  Node * n = graph->create(prim::ReplaceIfUndef, {dv, alternative});
-  return graph->insertNode(n)->output();
+// If we have a function y = f(x) with jacobian J, the backwards of f is dx = J^t dy.
+// Note that because the backwards always implements this matrix multiply,
+// we know that it maps an input vector of zeros to an output vector of zero
+// regardless of what operations it choses to do inside to actually implement
+// the matrix multiply (most use some optimized form and never generate J^t).
+// More generally, we know that all of the backward computations are linear and
+// can use this property to do more aggressive optimizations later.
+// It is ok to replace any backward function with known-zero inputs with something
+// that produces known-zero outputs. This function encloses each know-linear
+// backward function in a 'GradOf' sub-block so that we can perform optimizations
+// using this information. In particular, specializeUndef will observe if
+// all the inputs to the linear block are Undef, which the autograd uses to represent
+// zeros, and then propagate the undefs to the outputs of the block.
+static std::vector<Value*> linearGradientForNode(Node* node, ArrayRef<Value*> grad_values) {
+  auto & graph = *node->owningGraph();
+  auto linear = graph.insertNode(graph.create(prim::GradOf, {grad_values}, 0));
+  // to make reading gradient graphs easier, remember the name of the forward op
+  linear->s_(attr::name, node->kind().toDisplayString());
+  auto block = linear->addBlock();
+  {
+    WithInsertPoint guard(block);
+    auto results = gradientForNode(node, grad_values);
+    for(auto r : results) {
+      block->registerOutput(r);
+      linear->addOutput()->copyMetadata(r);
+    }
+  }
+  return linear->outputs();
 }

 struct ReverseDetails {
@@ -267,6 +276,16 @@ struct ReverseDetails {
   Block * reverse_block;
 };

+// AutogradAdd is a special addition function that handles Undef
+// AutogradAdd(a, b) == a + b if defined(a) and defined(b)
+// AutogradAdd(Undef, b) == b
+// AutogradAdd(a, Undef) == a
+// AutogradAdd(Undef, Undef) == Undef
+static Value* createAutogradAdd(Value* a, Value* b) {
+  auto graph = a->owningGraph();
+  return graph->insertNode(graph->create(prim::AutogradAdd, {a, b}))->output();
+}
+
 // Before:
 //   - grad_desc has field f initialized to the original 0-stage graph
 // After:
@@ -291,13 +310,14 @@ static ReverseDetails addReverseInline(Gradient& grad_desc,
   const auto get_grad = [&](Value* v) -> Value* {
     auto it = grad_map.find(v);
     if (it == grad_map.end()) {
-      std::tie(it, std::ignore) = grad_map.emplace(v, createZerosLike(v));
+      auto undef = graph.insertNode(graph.createUndefined());
+      std::tie(it, std::ignore) = grad_map.emplace(v, undef->output());
     }
     return it->second;
   };
   const auto set_grad = [&](Value *x, Value *dx) {
     if (Value * prev_grad = grad_map[x]) {
-      grad_map[x] = toVar(prev_grad) + toVar(dx);
+      grad_map[x] = createAutogradAdd(prev_grad, dx);
     } else {
       grad_map[x] = dx;
     }
@@ -309,7 +329,6 @@ static ReverseDetails addReverseInline(Gradient& grad_desc,
     if (!requires_grad(output))
       continue;
     Value * output_grad = reverse_block->addInput()->setType(output->type());
-    output_grad = createUndefGuard(output_grad, createZerosLike(output));
     set_grad(output, output_grad);
     grad_desc.df_input_vjps.push_back(i);
   }
@@ -319,7 +338,7 @@ static ReverseDetails addReverseInline(Gradient& grad_desc,
     auto inputs = node->inputs();
     if (!outputRequiresGrad(node, requires_grad)) continue;

-    value_list grad_inputs = gradientForNode(node, fmap(node->outputs(), get_grad));
+    value_list grad_inputs = linearGradientForNode(node, fmap(node->outputs(), get_grad));
     JIT_ASSERT(grad_inputs.size() == node->inputs().size());
     for (size_t i = 0, num_inputs = grad_inputs.size(); i < num_inputs; ++i) {
       set_grad(inputs[i], grad_inputs[i]);
@@ -337,45 +356,6 @@ static ReverseDetails addReverseInline(Gradient& grad_desc,
   return ReverseDetails(std::move(grad_map), std::move(requires_grad_set), reverse_block);
 }

-bool isZero(Value * v) {
-  auto n = v->node();
-  return n->kind() == prim::Constant &&
-    n->hasAttribute(attr::is_zero) &&
-    n->i(attr::is_zero);
-}
-
-// In the case where an input is routed to an output
-// return the (possibly undefined) input rather than
-// the value guarded by replaceIfUndef
-// this ensures that we do not produce a 0 tensor
-// when the autograd would produce None
-// graph(a) {
-//   b = replaceIfUndef(a,0);
-//   c = b + b
-//   return c, b; // will replace 'b' with 'a'
-// }
-// Also replace any known-to-be-zero outputs with Undef
-// for the same reason
-
-static void passthroughUndefs(std::shared_ptr<Graph> graph) {
-  bool changed = false;
-  for(size_t i = 0; i < graph->outputs().size(); i++) {
-    Value * v = graph->outputs()[i];
-    if(v->node()->kind() == prim::ReplaceIfUndef) {
-      graph->return_node()->replaceInput(i, v->node()->inputs()[0]);
-      changed = true;
-    } else if(isZero(v)) {
-      auto undef = graph->insertNode(graph->createUndefined());
-      graph->return_node()->replaceInput(i, undef->output());
-      changed = true;
-    }
-  }
-  // handle cases where replaceIfUndef or constants has become dead
-  if(changed)
-    EliminateDeadCode(graph);
-
-}
-
 // Takes a grad_desc.f returned from `addReverseInline` and splits off the
 // reverse_block into its own graph, storing it in df.
 // All intermediates needed in the second stage are added to
@@ -469,15 +449,10 @@ static void lambdaLiftReverse(Gradient& grad_desc, ReverseDetails& rev_info) {
     if (rev_info.requires_grad_set.count(tmp) == 0) continue;
     Value * tmp_vjp_in = reverse_block->addInput()->setType(tmp->type());
     Value * tmp_vjp_prev = rev_info.grad_map.at(tmp);
-    {
-      WithInsertPoint guard(tmp_vjp_prev->node());
-      auto zeroes = createZerosLike(tmp);
-      tmp_vjp_in = createUndefGuard(tmp_vjp_in, zeroes);
-    }
     // This is quite weird because we can't first make a sum and then replace all uses
     // of tmp_vjp_prev (that would replace its use in the sum too!), so we create an
     // incorrect sum that doesn't use prev vjp, replace uses, and fix the sum.
-    Value * new_vjp = toVar(tmp_vjp_in) + toVar(tmp_vjp_in);
+    Value * new_vjp = createAutogradAdd(tmp_vjp_in, tmp_vjp_in);
     new_vjp->node()->moveAfter(tmp_vjp_prev->node());
     tmp_vjp_prev->replaceAllUsesWith(new_vjp);
     new_vjp->node()->replaceInput(1, tmp_vjp_prev);
@@ -526,7 +501,6 @@ Gradient differentiate(std::shared_ptr<Graph>& _graph, const std::vector<bool>&
   // Fills in f, df, f_real_outputs, df_input_captures,
   // modifies df_input_vjps (new vjps are added for temporaries)
   lambdaLiftReverse(grad_desc, rev_info);
-  passthroughUndefs(grad_desc.df);
   return grad_desc;
 }
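The new passes registered in the build files above (lower_grad_of.cpp, specialize_undef.cpp) are not shown in this excerpt. As a rough sketch of the dynamic guard the commit message describes for LowerLinearBlocks, the idea is: run the GradOf body only if at least one gradient input is defined; otherwise, by linearity, every output is Undef. The helper run_grad_of and its body callback below are illustrative assumptions, not the actual pass.

#include <ATen/ATen.h>
#include <functional>
#include <vector>

// Illustrative only: if every gradient input is Undef (zero), every output of
// the GradOf block is Undef as well; otherwise the recorded backward body runs.
std::vector<at::Tensor> run_grad_of(
    const std::vector<at::Tensor>& grad_inputs,
    size_t num_outputs,
    const std::function<std::vector<at::Tensor>(const std::vector<at::Tensor>&)>& body) {
  bool any_defined = false;
  for (const auto& g : grad_inputs)
    any_defined = any_defined || g.defined();
  if (!any_defined)
    return std::vector<at::Tensor>(num_outputs); // all outputs stay Undef (zero)
  return body(grad_inputs);                      // at least one real gradient: run the body
}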
