@@ -48,8 +48,8 @@ Reducer::Reducer(
4848 expect_sparse_gradients_(std::move(expect_sparse_gradients)),
4949 expect_autograd_hooks_(false ),
5050 require_finalize_(false ),
51- has_marked_unused_parameters_(false ),
5251 next_bucket_(0 ),
52+ has_marked_unused_parameters_(false ),
5353 backward_stats_base_(0 ) {
5454 AT_ASSERTM (replicas_.size () >= 1 , " Expected at least one model replica." );
5555 AT_ASSERTM (replicas_[0 ].size () >= 1 , " Expected at least one parameter." );
@@ -118,6 +118,10 @@ Reducer::Reducer(
118118 for (size_t variable_index = 0 ; variable_index < variable_count;
119119 variable_index++) {
120120 auto & variable = replicas_[replica_index][variable_index];
121+ const auto index = VariableIndex{
122+ replica_index : replica_index,
123+ variable_index : variable_index,
124+ };
121125
122126 // The gradient accumulator function is lazily initialized once.
123127 // Therefore we can use its presence in the autograd graph as
@@ -126,21 +130,14 @@ Reducer::Reducer(
126130
127131 // Hook to execute after the gradient accumulator has executed.
128132 hooks_.emplace_back (
129- grad_accumulator->add_post_hook (
130- torch::make_unique<LambdaPostHook>([=] {
131- std::lock_guard<std::mutex> lock (this ->mutex_ );
132- this ->mark_variable_ready (
133- replica_index,
134- variable_index,
135- /* called_from_autograd= */ true );
136- })),
133+ grad_accumulator->add_post_hook (torch::make_unique<LambdaPostHook>(
134+ [=] { this ->autograd_hook (index); })),
137135 grad_accumulator);
138136
139137 // Map raw function pointer to replica index and parameter index.
140138 // This is used later on when the autograd graph is traversed
141139 // to check for parameters for which no gradient is computed.
142- func_[grad_accumulator.get ()] =
143- std::make_tuple (replica_index, variable_index);
140+ func_[grad_accumulator.get ()] = index;
144141
145142 // The gradient accumulator is stored as weak_ptr in the autograd
146143 // metadata of the variable, so we have to keep it alive here for
@@ -177,9 +174,9 @@ Reducer::~Reducer() noexcept(false) {
177174 }
178175}
179176
180- void Reducer::mark_variable_ready_dense (
181- size_t replica_index,
182- size_t variable_index) {
177+ void Reducer::mark_variable_ready_dense (VariableIndex index) {
178+ const auto replica_index = index. replica_index ;
179+ const auto variable_index = index. variable_index ;
183180 const auto & bucket_index = variable_locators_[variable_index];
184181 auto & bucket = buckets_[bucket_index.bucket_index ];
185182 auto & replica = bucket.replicas [replica_index];
@@ -214,9 +211,9 @@ void Reducer::mark_variable_ready_dense(
214211 }
215212}
216213
217- void Reducer::mark_variable_ready_sparse (
218- size_t replica_index,
219- size_t variable_index) {
214+ void Reducer::mark_variable_ready_sparse (VariableIndex index) {
215+ const auto replica_index = index. replica_index ;
216+ const auto variable_index = index. variable_index ;
220217 const auto & bucket_index = variable_locators_[variable_index];
221218 auto & bucket = buckets_[bucket_index.bucket_index ];
222219 auto & replica = bucket.replicas [replica_index];
@@ -235,22 +232,37 @@ void Reducer::mark_variable_ready_sparse(
235232 replica.contents = grad;
236233}
237234
238- // Called when the gradient for the specified variable is ready.
239- // It can be called from two places:
240- // - By an autograd thread after executing a gradient accumulator function.
241- // - By the `Reducer::prepare_for_backward` function if the variable doesn't
242- // show up in the autograd graph (and it wouldn't be called by autograd).
243- void Reducer::mark_variable_ready (
244- size_t replica_index,
245- size_t variable_index,
246- bool called_from_autograd) {
235+ // Post-hook run after a parameter's gradient has been accumulated: takes
236+ // `mutex_`, flushes any pending `unused_parameters_` as ready (at most once
237+ // per pass), then marks `index` ready. Call only from the autograd thread.
238+ void Reducer::autograd_hook (VariableIndex index) {
239+ std::lock_guard<std::mutex> lock (this ->mutex_ ); // held until this hook returns
240+
247241 // Ignore if we don't expect to be called.
248242 // This may be the case if the user wants to accumulate gradients
249243 // for number of iterations before reducing them.
250244 if (!expect_autograd_hooks_) {
251245 return ;
252246 }
253247
248+ // Parameters that went unused when computing the model output are not part
249+ // of the autograd graph and won't receive gradients. `prepare_for_backward`
250+ // records their indexes in `unused_parameters_`; mark them ready here, once
251+ // (guarded by `has_marked_unused_parameters_`), on the first expected hook.
252+ if (!has_marked_unused_parameters_ && !unused_parameters_.empty ()) {
253+ has_marked_unused_parameters_ = true ;
254+ for (const auto & unused_index : unused_parameters_) {
255+ mark_variable_ready (unused_index);
256+ }
257+ }
258+
259+ // Finally, mark the variable for which this hook was originally called.
260+ mark_variable_ready (index);
261+ }
262+
263+ void Reducer::mark_variable_ready (VariableIndex index) {
264+ const auto replica_index = index.replica_index ;
265+ const auto variable_index = index.variable_index ;
254266 AT_ASSERTM (replica_index < replicas_.size (), " Out of range replica index." );
255267 AT_ASSERTM (
256268 variable_index < variable_locators_.size (),
@@ -293,9 +305,9 @@ void Reducer::mark_variable_ready(
293305 }
294306
295307 if (bucket.expect_sparse_gradient ) {
296- mark_variable_ready_sparse (replica_index, variable_index );
308+ mark_variable_ready_sparse (index );
297309 } else {
298- mark_variable_ready_dense (replica_index, variable_index );
310+ mark_variable_ready_dense (index );
299311 }
300312
301313 // TODO(@pietern): Make this work for both CPU/CUDA tensors.
@@ -316,14 +328,10 @@ void Reducer::mark_variable_ready(
316328
317329 // Run finalizer function once the final bucket was marked ready.
318330 if (next_bucket_ == buckets_.size ()) {
319- if (called_from_autograd) {
320- torch::autograd::Engine::get_default_engine ().queue_callback ([=] {
321- std::lock_guard<std::mutex> lock (this ->mutex_ );
322- this ->finalize_backward ();
323- });
324- } else {
325- finalize_backward ();
326- }
331+ torch::autograd::Engine::get_default_engine ().queue_callback ([=] {
332+ std::lock_guard<std::mutex> lock (this ->mutex_ );
333+ this ->finalize_backward ();
334+ });
327335 }
328336}
329337
@@ -489,8 +497,8 @@ void Reducer::prepare_for_backward(
489497 std::vector<torch::autograd::Function*> queue;
490498
491499 // Check that any prior reduction has finished.
492- // The variable `expect_autograd_hooks ` is true until gradients for all
493- // parameters have been received and all buckets are ready .
500+ // The variable `require_finalize_ ` is true until all gradients
501+ // have been computed and reduction of all buckets has been kicked off .
494502 if (require_finalize_) {
495503 AT_ERROR (
496504 " Expected to have finished reduction in the prior iteration before " ,
@@ -513,7 +521,6 @@ void Reducer::prepare_for_backward(
513521 }
514522
515523 // Reset accounting.
516- has_marked_unused_parameters_ = true ;
517524 expect_autograd_hooks_ = true ;
518525 next_bucket_ = 0 ;
519526 backward_stats_base_ = current_time_in_nanos ();
@@ -524,11 +531,14 @@ void Reducer::prepare_for_backward(
524531 bucket.pending = bucket.replicas .size ();
525532 }
526533
534+ // Reset unused parameter accounting.
535+ has_marked_unused_parameters_ = false ;
536+ unused_parameters_.clear ();
537+
527538 // If no outputs are specified, we assume that autograd hooks for ALL
528539 // variables will be called, and we don't have to search the autograd graph
529540 // for presence of these hooks.
530541 if (outputs.empty ()) {
531- has_marked_unused_parameters_ = false ;
532542 return ;
533543 }
534544
@@ -562,10 +572,7 @@ void Reducer::prepare_for_backward(
562572 continue ;
563573 }
564574
565- size_t replica_index;
566- size_t variable_index;
567- std::tie (replica_index, variable_index) = it.second ;
568- mark_variable_ready (replica_index, variable_index);
575+ unused_parameters_.push_back (it.second );
569576 }
570577}
571578
0 commit comments