@@ -40,6 +40,19 @@ std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>> allreduce_(
4040 std::move (tensor_vec), work);
4141}
4242
43+ c10::intrusive_ptr<Work> allreduce_coalesced_ (
44+ at::TensorList tensors,
45+ const c10::intrusive_ptr<ProcessGroup>& process_group,
46+ const c10::intrusive_ptr<ReduceOp>& reduce_op,
47+ int64_t timeout) {
48+ auto tensor_vec = tensors.vec ();
49+ AllreduceCoalescedOptions opts = AllreduceCoalescedOptions{};
50+ opts.reduceOp = *reduce_op.get ();
51+ opts.timeout = std::chrono::milliseconds (timeout);
52+
53+ return process_group->allreduce_coalesced (tensor_vec, opts);
54+ }
55+
4356c10::intrusive_ptr<Work> reduce_ (
4457 at::TensorList tensors,
4558 const c10::intrusive_ptr<ProcessGroup>& process_group,
@@ -177,6 +190,10 @@ TORCH_LIBRARY(c10d, m) {
177190 m.def (
178191 " allreduce_" ,
179192 dispatch (c10::DispatchKey::CompositeExplicitAutograd, allreduce_));
193+ m.def (
194+ " allreduce_coalesced_" ,
195+ dispatch (
196+ c10::DispatchKey::CompositeExplicitAutograd, allreduce_coalesced_));
180197 m.def (
181198 " allgather_" ,
182199 dispatch (c10::DispatchKey::CompositeExplicitAutograd, allgather_));
@@ -249,6 +266,25 @@ c10::intrusive_ptr<Work> allreduce(
249266 opts.timeout .count ()));
250267}
251268
269+ c10::intrusive_ptr<Work> allreduce_coalesced (
270+ const c10::intrusive_ptr<ProcessGroup>& process_group,
271+ at::TensorList tensors,
272+ const AllreduceCoalescedOptions& opts) {
273+ static auto op = c10::Dispatcher::singleton ()
274+ .findSchemaOrThrow (" c10d::allreduce_coalesced_" , " " )
275+ .typed <c10::intrusive_ptr<::c10d::Work>(
276+ at::TensorList,
277+ const c10::intrusive_ptr<::c10d::ProcessGroup>&,
278+ const c10::intrusive_ptr<::c10d::ReduceOp>&,
279+ int64_t )>();
280+
281+ return op.call (
282+ tensors,
283+ process_group,
284+ c10::make_intrusive<ReduceOp>(opts.reduceOp ),
285+ opts.timeout .count ());
286+ }
287+
252288c10::intrusive_ptr<Work> allgather (
253289 const c10::intrusive_ptr<ProcessGroup>& process_group,
254290 const std::vector<std::vector<at::Tensor>>& output_tensors,
0 commit comments