@@ -20,7 +20,7 @@ bool isDifferentiable(Node * n) {
     aten::add, aten::sub, aten::mul, prim::Constant, prim::ReplaceIfUndef,
     aten::sigmoid, aten::tanh, aten::mm, aten::chunk, aten::split, aten::t, aten::neg,
     aten::unsqueeze, aten::expand, aten::addmm, aten::gt, aten::lt, aten::eq, aten::ne, aten::ge, aten::le, aten::type_as,
-    aten::relu, aten::exp
+    aten::relu, aten::exp, prim::AutogradAdd
   };
   // TODO: check this more generally via schema
   // This check ensures that the `alpha` and `beta` attributes on this addmm
@@ -34,6 +34,17 @@ bool isDifferentiable(Node * n) {
   if (n->kind() == aten::type_as && !n->inputs().at(1)->isTensor()) {
     return false;
   }
+
+  // linear blocks may appear as inputs to graph executors, but they are removed
+  // before differentiation occurs
+  if (n->kind() == prim::GradOf) {
+    auto body = n->blocks().at(0);
+    return std::all_of(
+        body->nodes().begin(),
+        body->nodes().end(),
+        static_cast<bool (*)(Node*)>(isDifferentiable));
+  }
+
   return differentiable_kinds.count(n->kind()) > 0;
 }
 
@@ -194,17 +205,10 @@ static std::vector<Value*> gradientForNode(Node* node, ArrayRef<Value*> grad_val
     throw std::runtime_error(std::string("don't support differentiation of `") +
                              node->kind().toDisplayString() + "`");
   };
-  const auto has_tensor_type = [](Value *v) { return v->isTensor(); };
   if (!isDifferentiable(node)) {
     throw std::runtime_error(std::string("differentiation of ") + node->kind().toDisplayString() + " "
                              "is not supported, or it is missing necessary type information");
   }
-  if (!std::all_of(node->inputs().begin(), node->inputs().end(), has_tensor_type) ||
-      !std::all_of(node->outputs().begin(), node->outputs().end(), has_tensor_type)) {
-    throw std::runtime_error("differentiate should be called with a graph where every value "
-                             "has a type registered");
-
-  }
   auto sym_grads = build_sym_grad(fmap<SymbolicVariable>(grad_values));
   return fmap(sym_grads, [](const SymbolicVariable &v) { return v.value(); });
 }
@@ -230,30 +234,35 @@ static value_set findAllRequiresGradNodes(
   return requires_grad_set;
 }
 
-static Value* createZerosLike(Value *v) {
-  JIT_EXPECTM(v->isTensor(), "can't allocate zero gradient for a value without a type");
-  Graph *graph = v->owningGraph();
-  auto type = v->type()->expect<TensorType>();
-  at::DeviceGuard device_guard(type->device());
-
-  auto & at_type = type->device() == -1 ? at::CPU(type->scalarType()) : at::CUDA(type->scalarType());
-  auto zeros = at::zeros({}, at_type).expand(type->sizes());
-  Node *constant = graph->createConstant(zeros)
-                        ->i_(attr::is_zero, 1);
-  graph->insertNode(constant);
-  return constant->output();
-}
 
-// any vjp input may be undefined, and we need to potentially replace it
-// with a zero tensor of the right size if required.
-// this function inserts a guard into the graph that does this replacement.
-// ReplaceIfUndef(dv,c) replaces dv with c if dv is undef.
-// During Graph specialization these guards will get removed when
-// 'dv' is known to be undef, and the zeros will be propagated if possible.
-static Value* createUndefGuard(Value * dv, Value * alternative) {
-  Graph* graph = dv->owningGraph();
-  Node * n = graph->create(prim::ReplaceIfUndef, {dv, alternative});
-  return graph->insertNode(n)->output();
+// If we have a function y = f(x) with jacobian J, the backwards of f is dx = J^t dy.
+// Note that because the backwards always implements this matrix multiply,
+// we know that it maps an input vector of zeros to an output vector of zeros
+// regardless of what operations it chooses to do inside to actually implement
+// the matrix multiply (most use some optimized form and never generate J^t).
+// More generally, we know that all of the backward computations are linear and
+// can use this property to do more aggressive optimizations later.
+// It is ok to replace any backward function with known-zero inputs with something
+// that produces known-zero outputs. This function encloses each known-linear
+// backward function in a 'GradOf' sub-block so that we can perform optimizations
+// using this information. In particular, specializeUndef will observe if
+// all the inputs to the linear block are Undef, which the autograd uses to represent
+// zeros, and then propagate the Undefs to the outputs of the block.
+static std::vector<Value*> linearGradientForNode(Node* node, ArrayRef<Value*> grad_values) {
+  auto & graph = *node->owningGraph();
+  auto linear = graph.insertNode(graph.create(prim::GradOf, {grad_values}, 0));
+  // to make reading gradient graphs easier, remember the name of the forward op
+  linear->s_(attr::name, node->kind().toDisplayString());
+  auto block = linear->addBlock();
+  {
+    WithInsertPoint guard(block);
+    auto results = gradientForNode(node, grad_values);
+    for (auto r : results) {
+      block->registerOutput(r);
+      linear->addOutput()->copyMetadata(r);
+    }
+  }
+  return linear->outputs();
 }
 
 struct ReverseDetails {
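Note, for intuition (an editorial sketch, not part of this patch): the linearity that the GradOf wrapper records is easy to see on a concrete op. For y = a * b the backward computes da = dy * b and db = dy * a, so a zero dy (which autograd represents as an Undef tensor) must produce zero da and db. A later pass such as specializeUndef therefore only needs to inspect the block's inputs; a minimal helper, assuming the prim::Undefined node kind produced by createUndefined(), might look like:

// Hypothetical helper (illustrative only, not in the patch): a GradOf block
// whose inputs are all Undef can have its outputs replaced with Undefs,
// because the wrapped backward is linear in the incoming gradients.
static bool allInputsUndefined(Node* grad_of) {
  auto inputs = grad_of->inputs();
  return std::all_of(inputs.begin(), inputs.end(), [](Value* v) {
    return v->node()->kind() == prim::Undefined;
  });
}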
@@ -267,6 +276,16 @@ struct ReverseDetails {
   Block * reverse_block;
 };
 
+// AutogradAdd is a special addition function that handles Undef
+// AutogradAdd(a, b) == a + b if defined(a) and defined(b)
+// AutogradAdd(Undef, b) == b
+// AutogradAdd(a, Undef) == a
+// AutogradAdd(Undef, Undef) == Undef
+static Value* createAutogradAdd(Value* a, Value* b) {
+  auto graph = a->owningGraph();
+  return graph->insertNode(graph->create(prim::AutogradAdd, {a, b}))->output();
+}
+
 // Before:
 // - grad_desc has field f initialized to the original 0-stage graph
 // After:
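For intuition, the same rule expressed as plain ATen code (an illustrative sketch only, not the interpreter's actual implementation; an undefined at::Tensor stands in for Undef):

// Illustrative only: mirrors the AutogradAdd table above.
static at::Tensor autograd_add(const at::Tensor& a, const at::Tensor& b) {
  if (!a.defined()) return b;  // AutogradAdd(Undef, b) == b (covers Undef, Undef too)
  if (!b.defined()) return a;  // AutogradAdd(a, Undef) == a
  return a + b;                // both defined: ordinary addition
}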
@@ -291,13 +310,14 @@ static ReverseDetails addReverseInline(Gradient& grad_desc,
   const auto get_grad = [&](Value* v) -> Value* {
     auto it = grad_map.find(v);
     if (it == grad_map.end()) {
-      std::tie(it, std::ignore) = grad_map.emplace(v, createZerosLike(v));
+      auto undef = graph.insertNode(graph.createUndefined());
+      std::tie(it, std::ignore) = grad_map.emplace(v, undef->output());
     }
     return it->second;
   };
   const auto set_grad = [&](Value *x, Value *dx) {
     if (Value * prev_grad = grad_map[x]) {
-      grad_map[x] = toVar(prev_grad) + toVar(dx);
+      grad_map[x] = createAutogradAdd(prev_grad, dx);
     } else {
       grad_map[x] = dx;
     }
@@ -309,7 +329,6 @@ static ReverseDetails addReverseInline(Gradient& grad_desc,
     if (!requires_grad(output))
       continue;
     Value * output_grad = reverse_block->addInput()->setType(output->type());
-    output_grad = createUndefGuard(output_grad, createZerosLike(output));
     set_grad(output, output_grad);
     grad_desc.df_input_vjps.push_back(i);
   }
@@ -319,7 +338,7 @@ static ReverseDetails addReverseInline(Gradient& grad_desc,
     auto inputs = node->inputs();
     if (!outputRequiresGrad(node, requires_grad)) continue;
 
-    value_list grad_inputs = gradientForNode(node, fmap(node->outputs(), get_grad));
+    value_list grad_inputs = linearGradientForNode(node, fmap(node->outputs(), get_grad));
     JIT_ASSERT(grad_inputs.size() == node->inputs().size());
     for (size_t i = 0, num_inputs = grad_inputs.size(); i < num_inputs; ++i) {
       set_grad(inputs[i], grad_inputs[i]);
@@ -337,45 +356,6 @@ static ReverseDetails addReverseInline(Gradient& grad_desc,
   return ReverseDetails(std::move(grad_map), std::move(requires_grad_set), reverse_block);
 }
 
-bool isZero(Value * v) {
-  auto n = v->node();
-  return n->kind() == prim::Constant &&
-         n->hasAttribute(attr::is_zero) &&
-         n->i(attr::is_zero);
-}
-
-// In the case where an input is routed to an output
-// return the (possibly undefined) input rather than
-// the value guarded by replaceIfUndef
-// this ensures that we do not produce a 0 tensor
-// when the autograd would produce None
-// graph(a) {
-//   b = replaceIfUndef(a,0);
-//   c = b + b
-//   return c, b; // will replace 'b' with 'a'
-// }
-// Also replace any known-to-be-zero outputs with Undef
-// for the same reason
-
-static void passthroughUndefs(std::shared_ptr<Graph> graph) {
-  bool changed = false;
-  for (size_t i = 0; i < graph->outputs().size(); i++) {
-    Value * v = graph->outputs()[i];
-    if (v->node()->kind() == prim::ReplaceIfUndef) {
-      graph->return_node()->replaceInput(i, v->node()->inputs()[0]);
-      changed = true;
-    } else if (isZero(v)) {
-      auto undef = graph->insertNode(graph->createUndefined());
-      graph->return_node()->replaceInput(i, undef->output());
-      changed = true;
-    }
-  }
-  // handle cases where replaceIfUndef or constants has become dead
-  if (changed)
-    EliminateDeadCode(graph);
-
-}
-
 // Takes a grad_desc.f returned from `addReverseInline` and splits off the
 // reverse_block into its own graph, storing it in df.
 // All intermediates needed in the second stage are added to
@@ -469,15 +449,10 @@ static void lambdaLiftReverse(Gradient& grad_desc, ReverseDetails& rev_info) {
     if (rev_info.requires_grad_set.count(tmp) == 0) continue;
     Value * tmp_vjp_in = reverse_block->addInput()->setType(tmp->type());
     Value * tmp_vjp_prev = rev_info.grad_map.at(tmp);
-    {
-      WithInsertPoint guard(tmp_vjp_prev->node());
-      auto zeroes = createZerosLike(tmp);
-      tmp_vjp_in = createUndefGuard(tmp_vjp_in, zeroes);
-    }
     // This is quite weird because we can't first make a sum and then replace all uses
     // of tmp_vjp_prev (that would replace its use in the sum too!), so we create an
     // incorrect sum that doesn't use prev vjp, replace uses, and fix the sum.
-    Value * new_vjp = toVar(tmp_vjp_in) + toVar(tmp_vjp_in);
+    Value * new_vjp = createAutogradAdd(tmp_vjp_in, tmp_vjp_in);
     new_vjp->node()->moveAfter(tmp_vjp_prev->node());
     tmp_vjp_prev->replaceAllUsesWith(new_vjp);
     new_vjp->node()->replaceInput(1, tmp_vjp_prev);
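The ordering above matters. A sketch of the naive (broken) alternative, for contrast only:

// Broken ordering (sketch, not in the patch): building the correct sum first
// means the replaceAllUsesWith call also rewrites the sum's own second operand,
// so new_vjp would end up as an input of itself.
Value * new_vjp = createAutogradAdd(tmp_vjp_in, tmp_vjp_prev);
tmp_vjp_prev->replaceAllUsesWith(new_vjp);

Hence the patch builds a placeholder sum of tmp_vjp_in with itself, redirects all uses of tmp_vjp_prev, and only then swaps tmp_vjp_prev back in as the second input.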
@@ -526,7 +501,6 @@ Gradient differentiate(std::shared_ptr<Graph>& _graph, const std::vector<bool>&
   // Fills in f, df, f_real_outputs, df_input_captures,
   // modifies df_input_vjps (new vjps are added for temporaries)
   lambdaLiftReverse(grad_desc, rev_info);
-  passthroughUndefs(grad_desc.df);
   return grad_desc;
 }
 