Commit 4ee9dd5

V3: Initial commit
Differential Revision: D15113272
Differential Version: 80846101
1 parent 5a83a74 commit 4ee9dd5

File tree

2 files changed: 83 additions, 28 deletions


test/test_c10d.py

Lines changed: 51 additions & 0 deletions
@@ -2257,6 +2257,57 @@ def forward(self, x):
         loss2 = criterion(output2, target)
         loss2.backward()
 
+    @skip_if_not_nccl
+    @skip_if_not_multigpu
+    def test_no_used_parameters(self):
+        """
+        Note: this test can be sped up by only running it on a CPU module
+        once DistributedDataParallel supports them.
+        """
+        store = c10d.FileStore(self.file.name, self.world_size)
+        process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size)
+
+        class NoUsedParameters(nn.Module):
+            def __init__(self):
+                super(NoUsedParameters, self).__init__()
+
+                # Make sure this module has some parameters, only to then decide
+                # to never use them from the `forward` function.
+                self.fc1 = nn.Linear(2, 10, bias=False)
+                self.fc2 = nn.Linear(10, 4, bias=False)
+                self.fc3 = nn.Linear(4, 4, bias=False)
+                self.relu = nn.ReLU()
+
+            def forward(self, x):
+                return x * 0.0
+
+        device_id = gpus_for_rank(self.world_size)[self.rank][0]
+        model = DistributedDataParallel(
+            NoUsedParameters().float().to(device_id),
+            device_ids=[device_id],
+            process_group=process_group,
+        )
+
+        batch_size = 4
+        input = torch.rand([batch_size, 2], dtype=torch.float)
+
+        # After initialization, no parameter has their gradient set.
+        for p in model.parameters():
+            self.assertTrue(p.requires_grad)
+            self.assertIsNone(p.grad)
+
+        # Run `forward` function.
+        model(input)
+
+        # Because none of the parameters were used, we expect reduction for
+        # all parameters will be executed right when initializing the reducer.
+        # Once `forward` returns, all the parameter's gradients must be set.
+        for p in model.parameters():
+            self.assertTrue(p.requires_grad)
+            self.assertIsNotNone(p.grad)
+            self.assertTrue(torch.is_tensor(p.grad))
+            self.assertEqual(p.size(), p.grad.size())
+
 
 class ReducerModule(nn.Module):
     def __init__(self):
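
The new test wires DistributedDataParallel directly to a hand-built c10d FileStore and ProcessGroupNCCL instead of going through torch.distributed.init_process_group. The sketch below mirrors that wiring as a standalone script for anyone who wants to poke at the behavior outside the test harness; it is not part of this commit, and the RANK/WORLD_SIZE environment variables, store path, and one-GPU-per-rank layout are assumptions.

import os
import tempfile

import torch
import torch.distributed as c10d
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel

# Assumed launch convention: one process per GPU, with RANK and WORLD_SIZE
# provided by whatever starts the processes.
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])

# Same rendezvous as the test: a FileStore shared by all ranks, with a NCCL
# process group built directly on top of it (the store path is arbitrary).
store = c10d.FileStore(os.path.join(tempfile.gettempdir(), "ddp_store"), world_size)
process_group = c10d.ProcessGroupNCCL(store, rank, world_size)


# A module whose forward never touches its parameters, as in the test.
class NoUsedParameters(nn.Module):
    def __init__(self):
        super(NoUsedParameters, self).__init__()
        self.fc1 = nn.Linear(2, 10, bias=False)

    def forward(self, x):
        return x * 0.0


model = DistributedDataParallel(
    NoUsedParameters().float().to(rank),
    device_ids=[rank],
    process_group=process_group,
)

# Per the assertions in the test, a single forward pass is enough for every
# parameter to have its gradient set, because the reducer can reduce all of
# the (unused) parameters right away.
model(torch.rand(4, 2))
assert all(p.grad is not None for p in model.parameters())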

torch/csrc/distributed/c10d/reducer.cpp

Lines changed: 32 additions & 28 deletions
@@ -228,12 +228,16 @@ void Reducer::mark_variable_ready(
     }
   }
 
-  // Queue function to finalize once the final bucket was marked ready.
+  // Run finalizer function once the final bucket was marked ready.
   if (next_bucket_ == buckets_.size()) {
-    // Autograd callbacks can only be registered while the engine is running.
-    AT_ASSERT(called_from_autograd);
-    torch::autograd::Engine::get_default_engine().queue_callback(
-        [=] { this->finalize_backward(); });
+    if (called_from_autograd) {
+      torch::autograd::Engine::get_default_engine().queue_callback([=] {
+        std::lock_guard<std::mutex> lock(this->mutex_);
+        this->finalize_backward();
+      });
+    } else {
+      finalize_backward();
+    }
   }
 }
@@ -375,6 +379,28 @@ void Reducer::prepare_for_backward(
   std::unordered_set<torch::autograd::Function*> seen;
   std::vector<torch::autograd::Function*> queue;
 
+  // Check that any prior reduction has finished.
+  // The variable `expect_autograd_hooks` is true until gradients for all
+  // parameters have been received and all buckets are ready.
+  if (expect_autograd_hooks_) {
+    AT_ERROR(
+        "Expected to have finished reduction in the prior iteration before ",
+        "starting a new one. ",
+        "",
+        "This error indicates that your module has parameters that were ",
+        "not used in producing its output (the return value of `forward`). ",
+        "",
+        "You can enable unused parameter detection by passing the keyword "
+        "argument `find_unused_parameters=True` to ",
+        "`torch.nn.parallel.DistributedDataParallel`. ",
+        "",
+        "If you already have this argument set, then the distributed data ",
+        "parallel module wasn't able to locate the output tensors in the ",
+        "return value of your module's `forward` function. ",
+        "Please include the structure of the return value of `forward` of ",
+        "your module when reporting this issue (e.g. list, dict, iterable).");
+  }
+
   // Reset accounting.
   has_marked_unused_parameters_ = true;
   expect_autograd_hooks_ = true;
@@ -433,34 +459,12 @@ void Reducer::prepare_for_backward(
 }
 
 void Reducer::finalize_backward() {
-  std::lock_guard<std::mutex> lock(mutex_);
-
   // No longer expect autograd hooks to fire after this function returns.
   AT_ASSERT(expect_autograd_hooks_);
   expect_autograd_hooks_ = false;
 
   // Check that all buckets were completed and had their work kicked off.
-  if (next_bucket_ < buckets_.size()) {
-    // If the reducer marked unused parameters and we STILL didn't get
-    // gradients for all module parameters, something is seriously wrong.
-    AT_ASSERT(!has_marked_unused_parameters_);
-    AT_ERROR(
-        "Expected to have gradients for all module parameters upon returning ",
-        "from the call to `torch.autograd.backward`. ",
-        "",
-        "This error indicates that your module has parameters that were ",
-        "not used in producing its output (the return value of `forward`). ",
-        "",
-        "You can enable unused parameter detection by passing the keyword "
-        "argument `find_unused_parameters=True` to ",
-        "`torch.nn.parallel.DistributedDataParallel`. ",
-        "",
-        "If you already have this argument set, then the distributed data ",
-        "parallel module wasn't able to locate the output tensors in the ",
-        "return value of your module's `forward` function. ",
-        "Please include the structure of the return value of `forward` of ",
-        "your module when reporting this issue (e.g. list, dict, iterable).");
-  }
+  AT_ASSERT(next_bucket_ == buckets_.size());
 
   // Wait for asynchronous reduction to complete and unflatten contents.
   for (auto& bucket : buckets_) {
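
The error message added to prepare_for_backward tells users to pass find_unused_parameters=True to torch.nn.parallel.DistributedDataParallel. As a rough illustration of the situation it describes: the module, sizes, and training loop below are made up for this sketch, and a process group is assumed to already exist (for example the one built in the sketch above, or a default group from torch.distributed.init_process_group).

import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel


# Illustrative module: one parameter is used by forward, the other never is.
class PartiallyUsed(nn.Module):
    def __init__(self):
        super(PartiallyUsed, self).__init__()
        self.used = nn.Linear(2, 2, bias=False)
        self.unused = nn.Linear(2, 2, bias=False)  # never touched by forward

    def forward(self, x):
        return self.used(x)


device_id = torch.cuda.current_device()
model = DistributedDataParallel(
    PartiallyUsed().to(device_id),
    device_ids=[device_id],
    process_group=process_group,  # assumed to exist, e.g. from the sketch above
    # Without this flag, no gradient ever arrives for `self.unused`, the
    # buckets never finish reducing, and the second forward pass hits the
    # "Expected to have finished reduction in the prior iteration" error
    # introduced by this commit.
    find_unused_parameters=True,
)

for _ in range(2):
    out = model(torch.rand(4, 2))
    out.sum().backward()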
