
Commit baf88b6

Draft for sub-tasks 1 and 2 of DDP Communication Hook
Pull Request resolved: #40848
Draft for sub-tasks 1 and 2 of [39272](#39272)
ghstack-source-id: 107207660
Differential Revision: [D22328310](https://our.internmc.facebook.com/intern/diff/D22328310/)
1 parent 300a3aa commit baf88b6
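
For orientation, the sketch below shows how the hook API introduced in this draft is meant to be exercised from Python. It mirrors the new test_ddp_comm_hook_future_passing test further down; MyModule, store, rank, and world_size are placeholders for whatever setup the caller already has.

import torch
import torch.distributed as c10d
from torch.nn.parallel import DistributedDataParallel

def simple_hook(state, bucket):
    # A DDP communication hook takes an opaque `state` object plus a GradBucket
    # and returns a future that the reducer waits on instead of running its
    # built-in allreduce on that bucket.
    fut = torch.futures.Future()
    fut.set_result(bucket.get_tensors())
    return fut

# Placeholder wiring, as done in the test below:
# process_group = c10d.ProcessGroupGloo(store, rank, world_size)
# model = DistributedDataParallel(MyModule(), process_group=process_group)
# model.reducer.register_comm_hook(None, simple_hook)  # state is None here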

File tree: 6 files changed, +192 −3 lines

  test/distributed/test_c10d.py
  torch/csrc/distributed/c10d/comm.cpp
  torch/csrc/distributed/c10d/comm.h
  torch/csrc/distributed/c10d/init.cpp
  torch/csrc/distributed/c10d/reducer.cpp
  torch/csrc/distributed/c10d/reducer.h

test/distributed/test_c10d.py

Lines changed: 84 additions & 0 deletions
@@ -2989,6 +2989,54 @@ def test_param_layout_mismatch_error(self):
         with self.assertRaisesRegex(RuntimeError, ".* appears not to match strides of the same param in process 0"):
             m_ddp = DistributedDataParallel(m, device_ids=[dev0], process_group=process_group)
 
+    @requires_gloo()
+    def test_ddp_comm_hook_future_passing(self):
+        """
+        This unit test verifies whether the Future object is passed properly.
+        The callback function creates a Future object and sets a value to it.
+        """
+        class test_ddp_comm_hook(nn.Module):
+            def __init__(self):
+                super(test_ddp_comm_hook, self).__init__()
+                self.t0 = Task()
+
+            def forward(self, x, rank):
+                return self.t0(x + rank)
+
+        def run_and_verify_grad(model):
+            # Run forward
+            output = model(8, self.rank)
+
+            # The grads of all parameters should be None at this point.
+            [self.assertIsNone(p.grad) for p in model.parameters()]
+
+            # Run backward
+            output.mean().backward()
+
+            # Now locally unused parameter should have grad updated on all ranks.
+            [self.assertEqual(p.grad, torch.ones(2, 2)) for p in model.parameters()]
+
+        def simple_hook(state, bucket):
+            fut = torch.futures.Future()
+            fut.set_result([torch.ones(4)])
+
+            def fut_then(fut):
+                # bucket.set_tensors(fut.wait())
+                for bt, ft in zip(bucket.get_tensors(), fut.wait()):
+                    bt.copy_(ft)
+            return fut.then(fut_then)
+
+        store = c10d.FileStore(self.file_name, self.world_size)
+        process_group = c10d.ProcessGroupGloo(store, self.rank, self.world_size)
+
+        # Test on CPU
+        cpu_model = DistributedDataParallel(
+            test_ddp_comm_hook().cpu(),
+            process_group=process_group
+        )
+        cpu_model.reducer.register_comm_hook(None, simple_hook)
+        run_and_verify_grad(cpu_model)
+
 
 class ReducerModule(nn.Module):
     def __init__(self):
@@ -3125,6 +3173,42 @@ def test_forward_backward_optimizer(self):
             output.backward()
             optimizer.step()
 
+    def test_ddp_comm_hook_register_just_once(self):
+        """
+        DDP communication hook can only be registered once. This test validates whether
+        the error is thrown properly when register_comm_hook is called more than once.
+        """
+        model = self._create_mixed_precision_model()
+        reducer = self._create_reducer_for_models([model], find_unused_parameters=True)
+
+        def dummy_hook(state, bucket):
+            fut = torch.futures.Future()
+            fut.set_result(bucket.get_tensors())
+            return fut.then()
+        reducer.register_comm_hook(None, dummy_hook)
+        try:
+            reducer.register_comm_hook(None, dummy_hook)
+        except Exception as e:
+            if "register_comm_hook can only be called once" in str(e):
+                return
+            else:
+                raise e
+
+    def test_ddp_comm_hook_callable(self):
+        """
+        The Python hook must be callable. This unit test checks whether this condition
+        is properly checked inside reducer.
+        """
+        model = self._create_mixed_precision_model()
+        reducer = self._create_reducer_for_models([model], find_unused_parameters=True)
+        try:
+            reducer.register_comm_hook(state=None, hook=1)
+        except Exception as e:
+            if "comm_hook must be callable" in str(e):
+                return
+            else:
+                raise e
+
 
 class ComputeBucketAssignmentTest(TestCase):
     def test_single_limit_single_dtype(self):

torch/csrc/distributed/c10d/comm.cpp

Lines changed: 22 additions & 0 deletions
@@ -4,6 +4,7 @@
 
 #include <ATen/core/functional.h>
 #include <torch/csrc/distributed/c10d/reducer.h>
+#include <torch/csrc/jit/python/pybind_utils.h>
 #include <torch/csrc/utils/tensor_flatten.h>
 
 namespace c10d {
@@ -79,4 +80,25 @@ void broadcast_coalesced(
   }
 }
 
+GradBucket::GradBucket(std::vector<at::Tensor>& tensors) : tensors_(tensors){};
+
+std::vector<at::Tensor> GradBucket::get_tensors() {
+  return tensors_;
+};
+
+void GradBucket::set_tensors(std::vector<at::Tensor>& tensors) {
+  tensors_ = tensors;
+}
+
+PythonCommHook::PythonCommHook(py::object state, py::object hook)
+    : state_(std::move(state)), hook_(std::move(hook)){};
+c10::intrusive_ptr<torch::jit::Future> PythonCommHook::operate(
+    const GradBucket& bucket) {
+  py::gil_scoped_acquire acquire;
+
+  c10::intrusive_ptr<torch::jit::Future> fut;
+  return hook_(state_, bucket)
+      .cast<std::shared_ptr<torch::jit::PythonFutureWrapper>>()
+      ->fut;
+};
 } // namespace c10d
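
Since PythonCommHook::operate casts whatever the Python hook returns into a torch::jit::PythonFutureWrapper, the hook has to hand back a torch.futures.Future; Future.then(...) also returns a Future, which is why the test above can return fut.then(fut_then). A minimal sketch of that contract:

import torch

def minimal_hook(state, bucket):
    # The return value must be a torch.futures.Future so the C++ side can
    # unwrap it; completing it immediately makes the reducer's later wait()
    # on this bucket trivial.
    fut = torch.futures.Future()
    fut.set_result(bucket.get_tensors())
    return fut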

torch/csrc/distributed/c10d/comm.h

Lines changed: 28 additions & 0 deletions
@@ -4,6 +4,7 @@
 
 #include <ATen/ATen.h>
 #include <c10d/ProcessGroup.hpp>
+#include <torch/csrc/utils/pybind.h>
 
 namespace c10d {
 
@@ -13,4 +14,31 @@ void broadcast_coalesced(
     at::TensorList tensors,
     size_t buffer_size);
 
+class GradBucket {
+ public:
+  explicit GradBucket(std::vector<at::Tensor>& tensors);
+  std::vector<at::Tensor> get_tensors();
+  void set_tensors(std::vector<at::Tensor>& tensors);
+
+ private:
+  std::vector<at::Tensor> tensors_;
+};
+
+struct CommHookInterface {
+ public:
+  virtual c10::intrusive_ptr<torch::jit::Future> operate(
+      const GradBucket& bucket) = 0;
+};
+
+class TORCH_API PythonCommHook : public CommHookInterface {
+ public:
+  PythonCommHook(py::object state, py::object hook);
+
+  c10::intrusive_ptr<torch::jit::Future> operate(
+      const GradBucket& bucket) override;
+
+ private:
+  py::object state_;
+  py::object hook_;
+};
 } // namespace c10d

torch/csrc/distributed/c10d/init.cpp

Lines changed: 18 additions & 0 deletions
@@ -115,6 +115,18 @@ PyObject* c10d_init(PyObject* _unused) {
 
   auto module = py::handle(c10d_module).cast<py::module>();
 
+  shared_ptr_class_<::c10d::GradBucket>(module, "GradBucket")
+      .def(py::init<std::vector<Tensor>&>(), py::arg("tensors"))
+      .def(
+          "get_tensors",
+          &::c10d::GradBucket::get_tensors,
+          py::call_guard<py::gil_scoped_release>())
+      .def(
+          "set_tensors",
+          &::c10d::GradBucket::set_tensors,
+          py::arg("tensors"),
+          py::call_guard<py::gil_scoped_release>());
+
   shared_ptr_class_<::c10d::Reducer>(module, "Reducer")
       .def(
           py::init<
@@ -131,6 +143,12 @@ PyObject* c10d_init(PyObject* _unused) {
           py::arg("bucket_bytes_cap") = ::c10d::kDefaultBucketBytesCap,
           py::arg("find_unused_parameters") = false,
           py::call_guard<py::gil_scoped_release>())
+      .def(
+          "register_comm_hook",
+          &::c10d::Reducer::register_comm_hook,
+          py::arg("state"),
+          py::arg("hook"),
+          py::call_guard<py::gil_scoped_release>())
       .def(
           "initialize_buckets",
           &::c10d::Reducer::initialize_buckets,
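
The bindings above expose GradBucket (with get_tensors/set_tensors) and Reducer.register_comm_hook to Python. A hypothetical snippet poking at the GradBucket binding directly, assuming it surfaces next to the other c10d classes in the module the tests import as c10d:

import torch
import torch.distributed as c10d  # assumption: GradBucket lands beside FileStore, Reducer, etc.

tensors = [torch.zeros(2, 2), torch.ones(4)]
bucket = c10d.GradBucket(tensors)                        # py::init<std::vector<Tensor>&>, arg "tensors"
print(bucket.get_tensors())                              # the wrapped tensor list
bucket.set_tensors([torch.ones(2, 2), torch.zeros(4)])   # replace the bucket's tensors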

torch/csrc/distributed/c10d/reducer.cpp

Lines changed: 31 additions & 3 deletions
@@ -14,6 +14,7 @@
 #include <torch/csrc/distributed/c10d/comm.h>
 #include <torch/csrc/utils/hash.h>
 #include <torch/csrc/utils/memory.h>
+#include <torch/csrc/utils/pybind.h>
 
 namespace c10d {
 namespace {
@@ -161,6 +162,8 @@ Reducer::Reducer(
       }
     }
   }
+
+  comm_hook_.reset();
 }
 
 // Note [Skip allreducing local_used_maps_dev]
@@ -575,7 +578,11 @@ void Reducer::mark_bucket_ready(size_t bucket_index) {
       //
       tensors.push_back(replica.contents);
     }
-    bucket.work = process_group_->allreduce(tensors);
+    if (comm_hook_ == nullptr) {
+      bucket.work = process_group_->allreduce(tensors);
+    } else {
+      bucket.future_work = comm_hook_->operate(GradBucket(tensors));
+    }
   }
 }
 
@@ -924,8 +931,13 @@ void Reducer::finalize_backward() {
 
   // Wait for asynchronous reduction to complete and unflatten contents.
   for (auto& bucket : buckets_) {
-    TORCH_INTERNAL_ASSERT(bucket.work);
-    bucket.work->wait();
+    if (comm_hook_ == nullptr) {
+      TORCH_INTERNAL_ASSERT(bucket.work);
+      bucket.work->wait();
+    } else {
+      TORCH_INTERNAL_ASSERT(bucket.future_work);
+      bucket.future_work->wait();
+    }
     if (!bucket.expect_sparse_gradient) {
       // We don't need to finalize the sparse bucket since the sparse grad and
       // the bucket essentially point to the same storage. As a result, once
@@ -1079,6 +1091,22 @@ std::vector<std::vector<size_t>> Reducer::rebuildBuckets() {
   return rebuilt_bucket_indices;
 }
 
+void Reducer::register_comm_hook(py::object state, py::object comm_hook) {
+  TORCH_CHECK(
+      py::isinstance<py::function>(comm_hook), "comm_hook must be callable.");
+
+  Reducer::register_comm_hook_internal(
+      std::make_unique<PythonCommHook>(std::move(state), std::move(comm_hook)));
+}
+
+void Reducer::register_comm_hook_internal(
+    std::unique_ptr<CommHookInterface> iface) {
+  TORCH_CHECK(
+      comm_hook_ == nullptr, "register_comm_hook can only be called once.");
+
+  comm_hook_ = std::move(iface);
+}
+
 namespace {
 
 // Tensors may be coalesced into buckets. Buckets must contain tensors of
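
With a hook registered, mark_bucket_ready takes the comm_hook_ branch instead of calling process_group_->allreduce, and finalize_backward waits on bucket.future_work rather than bucket.work. The sketch below, built only from pieces visible in this diff, also shows that the state object handed to register_comm_hook is forwarded by PythonCommHook to every hook invocation:

import torch

class HookState:
    # Arbitrary Python object passed as `state` to register_comm_hook.
    def __init__(self):
        self.buckets_seen = 0

def counting_hook(state, bucket):
    # `state` is the HookState above; `bucket` is the GradBucket for this call.
    state.buckets_seen += 1
    fut = torch.futures.Future()
    fut.set_result(bucket.get_tensors())
    return fut

# e.g., with a Reducer instance as constructed in the ReducerTest cases above:
# state = HookState()
# reducer.register_comm_hook(state, counting_hook)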

torch/csrc/distributed/c10d/reducer.h

Lines changed: 9 additions & 0 deletions
@@ -11,6 +11,8 @@
 #include <torch/csrc/autograd/function.h>
 #include <torch/csrc/autograd/variable.h>
 #include <torch/csrc/distributed/autograd/context/context.h>
+#include <torch/csrc/distributed/c10d/comm.h>
+#include <torch/csrc/utils/pybind.h>
 
 namespace c10d {
 
@@ -53,6 +55,8 @@ class Reducer {
     return backward_stats_;
   }
 
+  void register_comm_hook(py::object state, py::object comm_hook);
+
  protected:
   // Forward declaration.
   struct Bucket;
@@ -99,6 +103,9 @@ class Reducer {
   // Work handle for allreduce on local_used_maps_
   std::shared_ptr<c10d::ProcessGroup::Work> local_used_work_;
 
+  std::unique_ptr<CommHookInterface> comm_hook_;
+  void register_comm_hook_internal(std::unique_ptr<CommHookInterface> iface);
+
   void verify_replicas_within_process();
 
   void verify_replica0_across_processes();
@@ -197,6 +204,8 @@ class Reducer {
     // Keep work handle around when this set of buckets is being reduced.
     std::shared_ptr<c10d::ProcessGroup::Work> work;
 
+    c10::intrusive_ptr<torch::jit::Future> future_work;
+
     // If this bucket should expect a single sparse gradient.
     // Implies: replicas[i].variables.size() == 1.
    bool expect_sparse_gradient = false;
