@@ -228,12 +228,16 @@ void Reducer::mark_variable_ready(
     }
   }
 
-  // Queue function to finalize once the final bucket was marked ready.
+  // Run finalizer function once the final bucket was marked ready.
   if (next_bucket_ == buckets_.size()) {
-    // Autograd callbacks can only be registered while the engine is running.
-    AT_ASSERT(called_from_autograd);
-    torch::autograd::Engine::get_default_engine().queue_callback(
-        [=] { this->finalize_backward(); });
+    if (called_from_autograd) {
+      torch::autograd::Engine::get_default_engine().queue_callback([=] {
+        std::lock_guard<std::mutex> lock(this->mutex_);
+        this->finalize_backward();
+      });
+    } else {
+      finalize_backward();
+    }
   }
 }
 
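The first hunk changes how finalization is triggered: when the last bucket is marked ready from inside the autograd engine, `finalize_backward()` is deferred via `queue_callback()` and the queued lambda takes the reducer's mutex itself (since `finalize_backward()` no longer locks it, see the last hunk); when not called from autograd, finalization runs inline. Below is a minimal standalone sketch of that pattern, using only the standard library; `Engine`, `Reducer`, and `mark_last_bucket_ready` are illustrative stand-ins, not the real torch::autograd or c10d types.

```cpp
#include <functional>
#include <iostream>
#include <mutex>
#include <vector>

// Toy stand-in for the autograd engine's callback queue.
struct Engine {
  std::vector<std::function<void()>> callbacks;
  void queue_callback(std::function<void()> cb) {
    callbacks.push_back(std::move(cb));
  }
  // In the real engine, queued callbacks run once the backward pass finishes.
  void drain() {
    for (auto& cb : callbacks) {
      cb();
    }
    callbacks.clear();
  }
};

struct Reducer {
  explicit Reducer(Engine& e) : engine_(e) {}

  // Mirrors the branch added above: defer finalization when called from the
  // engine, otherwise finalize inline. (In the real code the inline path runs
  // under a lock the caller already holds; that detail is omitted here.)
  void mark_last_bucket_ready(bool called_from_autograd) {
    if (called_from_autograd) {
      engine_.queue_callback([this] {
        std::lock_guard<std::mutex> lock(mutex_);
        finalize_backward();
      });
    } else {
      finalize_backward();
    }
  }

  void finalize_backward() {
    std::cout << "finalize_backward\n";
  }

 private:
  std::mutex mutex_;
  Engine& engine_;
};

int main() {
  Engine engine;
  Reducer reducer(engine);

  reducer.mark_last_bucket_ready(/*called_from_autograd=*/true);
  engine.drain();  // deferred finalizer runs, as the engine would after backward

  reducer.mark_last_bucket_ready(/*called_from_autograd=*/false);  // inline path
}
```

The toy `drain()` models the engine running its queued callbacks only after the backward pass completes, which is why the deferred path re-acquires the lock rather than assuming the caller still holds it.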
@@ -375,6 +379,28 @@ void Reducer::prepare_for_backward(
   std::unordered_set<torch::autograd::Function*> seen;
   std::vector<torch::autograd::Function*> queue;
 
+  // Check that any prior reduction has finished.
+  // The variable `expect_autograd_hooks` is true until gradients for all
+  // parameters have been received and all buckets are ready.
+  if (expect_autograd_hooks_) {
+    AT_ERROR(
+        "Expected to have finished reduction in the prior iteration before ",
+        "starting a new one. ",
+        "",
+        "This error indicates that your module has parameters that were ",
+        "not used in producing its output (the return value of `forward`). ",
+        "",
+        "You can enable unused parameter detection by passing the keyword "
+        "argument `find_unused_parameters=True` to ",
+        "`torch.nn.parallel.DistributedDataParallel`. ",
+        "",
+        "If you already have this argument set, then the distributed data ",
+        "parallel module wasn't able to locate the output tensors in the ",
+        "return value of your module's `forward` function. ",
+        "Please include the structure of the return value of `forward` of ",
+        "your module when reporting this issue (e.g. list, dict, iterable).");
+  }
+
   // Reset accounting.
   has_marked_unused_parameters_ = true;
   expect_autograd_hooks_ = true;
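This new check moves the unused-parameter diagnostic from `finalize_backward()` (see the last hunk) to the start of the next iteration: `expect_autograd_hooks_` stays set until finalization runs, so finding it still set in `prepare_for_backward()` means the previous backward pass never produced gradients for every parameter. A minimal sketch of that guard, assuming only the standard library; `ReducerState` and the shortened error text are illustrative, not the actual implementation.

```cpp
#include <iostream>
#include <stdexcept>

struct ReducerState {
  bool expect_autograd_hooks = false;

  // Refuse to start a new iteration while the previous one is unfinished.
  void prepare_for_backward() {
    if (expect_autograd_hooks) {
      throw std::runtime_error(
          "Expected to have finished reduction in the prior iteration "
          "before starting a new one.");
    }
    expect_autograd_hooks = true;
  }

  // Called once gradients for all parameters arrived and buckets are reduced.
  void finalize_backward() {
    expect_autograd_hooks = false;
  }
};

int main() {
  ReducerState state;
  state.prepare_for_backward();
  state.finalize_backward();    // previous iteration completed
  state.prepare_for_backward(); // fine: flag was cleared

  try {
    state.prepare_for_backward(); // previous iteration never finalized
  } catch (const std::exception& e) {
    std::cout << e.what() << "\n";
  }
}
```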
@@ -433,34 +459,12 @@ void Reducer::prepare_for_backward(
 }
 
 void Reducer::finalize_backward() {
-  std::lock_guard<std::mutex> lock(mutex_);
-
   // No longer expect autograd hooks to fire after this function returns.
   AT_ASSERT(expect_autograd_hooks_);
   expect_autograd_hooks_ = false;
 
   // Check that all buckets were completed and had their work kicked off.
-  if (next_bucket_ < buckets_.size()) {
-    // If the reducer marked unused parameters and we STILL didn't get
-    // gradients for all module parameters, something is seriously wrong.
-    AT_ASSERT(!has_marked_unused_parameters_);
-    AT_ERROR(
-        "Expected to have gradients for all module parameters upon returning ",
-        "from the call to `torch.autograd.backward`. ",
-        "",
-        "This error indicates that your module has parameters that were ",
-        "not used in producing its output (the return value of `forward`). ",
-        "",
-        "You can enable unused parameter detection by passing the keyword "
-        "argument `find_unused_parameters=True` to ",
-        "`torch.nn.parallel.DistributedDataParallel`. ",
-        "",
-        "If you already have this argument set, then the distributed data ",
-        "parallel module wasn't able to locate the output tensors in the ",
-        "return value of your module's `forward` function. ",
-        "Please include the structure of the return value of `forward` of ",
-        "your module when reporting this issue (e.g. list, dict, iterable).");
-  }
+  AT_ASSERT(next_bucket_ == buckets_.size());
 
   // Wait for asynchronous reduction to complete and unflatten contents.
   for (auto& bucket : buckets_) {