65 changes: 65 additions & 0 deletions test/test_c10d.py
@@ -2490,6 +2490,71 @@ def forward(self, x):
            loss = criterion(output, target)
            loss.backward()

    @skip_if_not_nccl
    @skip_if_not_multigpu
    def test_failure_recovery(self):
        store = c10d.FileStore(self.file.name, self.world_size)
        process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size)

        # Need to create a separate file for the recovered FileStore, because the
        # original one will be deleted when the first FileStore is destructed.
        recovery_filename = self.file.name + "_recovery"

        if self.rank == 0:
            # The file will be deleted by the recovered FileStore.
            open(recovery_filename, "w").close()

        # Not necessary to run a barrier here, as DDP will synchronize.

        class TestModel(nn.Module):
            def __init__(self):
                super(TestModel, self).__init__()
                self.fc1 = nn.Linear(2, 10, bias=False)
                self.fc2 = nn.Linear(10, 4, bias=False)
                self.relu = nn.ReLU()

            def forward(self, x):
                x = self.relu(self.fc1(x))
                x = self.relu(self.fc2(x))
                return F.softmax(x, dim=1)

        device_id = gpus_for_rank(self.world_size)[self.rank][0]
        model = TestModel().float().to(device_id)
        ddp = DistributedDataParallel(
            model,
            device_ids=[device_id],
            process_group=process_group,
        )

        batch_size = 4
        criterion = nn.CrossEntropyLoss()
        input = torch.rand([batch_size, 2], dtype=torch.float)
        target = torch.LongTensor([random.randrange(4) for _ in range(batch_size)]).to(device_id)

        for _ in range(6):
            output = ddp(input)
            loss = criterion(output, target)
            loss.backward()

        del ddp
        del process_group
        del store  # this will delete self.file

        store = c10d.FileStore(recovery_filename, self.world_size)
        process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size)
        ddp = DistributedDataParallel(
            model,
            device_ids=[device_id],
            process_group=process_group,
        )

        input = torch.rand([batch_size, 2], dtype=torch.float)
        target = torch.LongTensor([random.randrange(4) for _ in range(batch_size)]).to(device_id)
        for _ in range(6):
            output = ddp(input)
            loss = criterion(output, target)
            loss.backward()


class ReducerModule(nn.Module):
    def __init__(self):
16 changes: 15 additions & 1 deletion torch/csrc/autograd/function.h
@@ -247,15 +247,29 @@ struct TORCH_API Function : std::enable_shared_from_this<Function> {
   // Hook API
   //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

-  void add_post_hook(std::unique_ptr<FunctionPostHook>&& post_hook) {
+  uintptr_t add_post_hook(std::unique_ptr<FunctionPostHook>&& post_hook) {
     post_hooks_.push_back(std::move(post_hook));
+    // Use the raw pointer as the unique key to identify this hook. This key
+    // can then be used in del_post_hook(key) to remove this hook.
+    return reinterpret_cast<std::uintptr_t>(post_hooks_.back().get());
   }

   const std::vector<std::unique_ptr<FunctionPostHook>>& post_hooks() const
       noexcept {
     return post_hooks_;
   }

+  // Delete the post hook matching the key; returns false if no such hook exists.
+  bool del_post_hook(const uintptr_t& key) {
+    for (auto it = post_hooks_.begin(); it != post_hooks_.end(); ++it) {
+      if (key == reinterpret_cast<std::uintptr_t>(it->get())) {
+        post_hooks_.erase(it);
+        return true;
+      }
+    }
+    return false;
+  }
+
   std::vector<std::unique_ptr<FunctionPostHook>>& post_hooks() noexcept {
     return post_hooks_;
   }
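The mechanism in this hunk is easy to miss at a glance: add_post_hook now hands back the stored hook's raw pointer, reinterpreted as an integer, and del_post_hook walks the vector until it finds the entry whose address matches that key. The following standalone sketch (not part of the PR; PostHook and HookOwner are invented names, assuming only the C++ standard library) shows the same pattern in a compilable form.

// Standalone illustration only; not part of the diff above.
#include <cstdint>
#include <functional>
#include <iostream>
#include <memory>
#include <vector>

// Minimal stand-in for a post-hook callback.
struct PostHook {
  std::function<void()> fn;
};

class HookOwner {
 public:
  // Register a hook and return a key (its raw pointer) identifying it.
  std::uintptr_t add_post_hook(std::unique_ptr<PostHook>&& hook) {
    hooks_.push_back(std::move(hook));
    return reinterpret_cast<std::uintptr_t>(hooks_.back().get());
  }

  // Remove the hook matching the key; return false if no hook matches.
  bool del_post_hook(std::uintptr_t key) {
    for (auto it = hooks_.begin(); it != hooks_.end(); ++it) {
      if (key == reinterpret_cast<std::uintptr_t>(it->get())) {
        hooks_.erase(it);
        return true;
      }
    }
    return false;
  }

  // Invoke every registered hook in registration order.
  void run_hooks() const {
    for (const auto& hook : hooks_) {
      hook->fn();
    }
  }

 private:
  std::vector<std::unique_ptr<PostHook>> hooks_;
};

int main() {
  HookOwner owner;
  const auto key = owner.add_post_hook(
      std::make_unique<PostHook>(PostHook{[] { std::cout << "hook ran\n"; }}));
  owner.run_hooks();                              // prints "hook ran"
  std::cout << std::boolalpha
            << owner.del_post_hook(key) << "\n";  // true: removed by its key
  owner.run_hooks();                              // prints nothing
  return 0;
}

The key stays valid for the hook's whole lifetime because the vector stores owning unique_ptrs: a reallocation moves the smart pointers, not the hooks they point to, so the stored address never changes until the hook is erased.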
26 changes: 21 additions & 5 deletions torch/csrc/distributed/c10d/reducer.cpp
@@ -105,11 +105,14 @@ Reducer::Reducer(
         auto grad_accumulator = variable.grad_accumulator();

         // Hook to execute after the gradient accumulator has executed.
-        grad_accumulator->add_post_hook(torch::make_unique<LambdaPostHook>([=] {
-          std::lock_guard<std::mutex> lock(this->mutex_);
-          this->mark_variable_ready(
-              replica_index, variable_index, /* called_from_autograd= */ true);
-        }));
+        hooks_[grad_accumulator->add_post_hook(
+            torch::make_unique<LambdaPostHook>([=] {
+              std::lock_guard<std::mutex> lock(this->mutex_);
+              this->mark_variable_ready(
+                  replica_index,
+                  variable_index,
+                  /* called_from_autograd= */ true);
+            }))] = grad_accumulator;

         // Map raw function pointer to replica index and parameter index.
         // This is used later on when the autograd graph is traversed
@@ -138,6 +141,19 @@
   }
 }

+Reducer::~Reducer() noexcept(false) {
+  // Remove all hooks on variables registered by this Reducer. This is necessary
+  // to make DDP failure recoverable. Otherwise, multiple Reducer instances
+  // (from recoveries) will add their hooks to the original model, and those
+  // hooks will try to invoke methods on a deleted Reducer object.
+  for (auto& hook : hooks_) {
+    auto& key = hook.first;
+    auto& grad_accumulator = hook.second;
+    AT_ASSERTM(grad_accumulator->del_post_hook(key),
+               "Reducer attempts to delete a non-existing hook.");
+  }
+}
+
 // Called when the gradient for the specified variable is ready.
 // It can be called from two places:
 // - By an autograd thread after executing a gradient accumulator function.
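The new destructor is what makes recovery possible: every key returned by add_post_hook is remembered together with the grad accumulator it was installed on, and exactly those hooks are removed when the Reducer is torn down, so a replacement Reducer built over the same model does not stack its hooks on top of stale ones. The sketch below illustrates that lifecycle with invented stand-ins (Accumulator and MiniReducer are not the real PyTorch classes) and only the standard library.

// Standalone illustration only; not part of the diff above.
#include <cassert>
#include <cstdint>
#include <functional>
#include <memory>
#include <unordered_map>
#include <vector>

// Stand-in for the autograd function that owns post hooks.
struct Accumulator {
  std::vector<std::unique_ptr<std::function<void()>>> post_hooks;

  std::uintptr_t add_post_hook(std::unique_ptr<std::function<void()>> hook) {
    post_hooks.push_back(std::move(hook));
    return reinterpret_cast<std::uintptr_t>(post_hooks.back().get());
  }

  bool del_post_hook(std::uintptr_t key) {
    for (auto it = post_hooks.begin(); it != post_hooks.end(); ++it) {
      if (key == reinterpret_cast<std::uintptr_t>(it->get())) {
        post_hooks.erase(it);
        return true;
      }
    }
    return false;
  }
};

class MiniReducer {
 public:
  explicit MiniReducer(std::vector<std::shared_ptr<Accumulator>> accumulators) {
    for (auto& acc : accumulators) {
      // Remember key -> accumulator so the destructor can undo the registration.
      hooks_[acc->add_post_hook(std::make_unique<std::function<void()>>(
          [] { /* mark_variable_ready(...) would run here */ }))] = acc;
    }
  }

  ~MiniReducer() {
    // Remove exactly the hooks this instance installed.
    for (auto& kv : hooks_) {
      const bool removed = kv.second->del_post_hook(kv.first);
      assert(removed && "tried to delete a non-existing hook");
      (void)removed;
    }
  }

 private:
  std::unordered_map<std::uintptr_t, std::shared_ptr<Accumulator>> hooks_;
};

int main() {
  auto acc = std::make_shared<Accumulator>();
  {
    MiniReducer first({acc});
    assert(acc->post_hooks.size() == 1);
  }  // first is destroyed here and unregisters its hook
  assert(acc->post_hooks.empty());
  MiniReducer second({acc});  // the "recovered" instance starts clean
  assert(acc->post_hooks.size() == 1);
  return 0;
}

Note that the real destructor is declared noexcept(false): AT_ASSERTM throws when a hook cannot be found, and destructors are implicitly noexcept in C++11, so the exception could not propagate otherwise. The sketch uses a plain assert to stay dependency-free.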
4 changes: 4 additions & 0 deletions torch/csrc/distributed/c10d/reducer.h
@@ -24,6 +24,8 @@ class Reducer {
      std::vector<std::vector<size_t>> bucket_indices,
      std::shared_ptr<c10d::ProcessGroup> process_group);

  ~Reducer() noexcept(false);

  // To (re-)initialize bucket assignment, pass a list of buckets, each
  // of which is specified by a list of indices in the variables list.
  // This function performs validation that the variables within a bucket
@@ -52,6 +54,8 @@
  std::vector<std::vector<std::shared_ptr<torch::autograd::Function>>>
      grad_accumulators_;
  std::unordered_map<torch::autograd::Function*, std::tuple<int, int>> func_;
  std::unordered_map<uintptr_t, std::shared_ptr<torch::autograd::Function>>
      hooks_;

  bool expect_autograd_hooks_;
  bool require_finalize_;