@@ -173,6 +173,21 @@ c10::intrusive_ptr<Work> alltoall_(
173173 AllToAllOptions{std::chrono::milliseconds (timeout)});
174174}
175175
176+ c10::intrusive_ptr<Work> alltoall_base_ (
177+ at::Tensor& output,
178+ at::Tensor& input,
179+ const c10::intrusive_ptr<ProcessGroup>& process_group,
180+ std::vector<int64_t > output_split_sizes,
181+ std::vector<int64_t > input_split_sizes,
182+ int64_t timeout) {
183+ return process_group->alltoall_base (
184+ output,
185+ input,
186+ output_split_sizes,
187+ input_split_sizes,
188+ AllToAllOptions{std::chrono::milliseconds (timeout)});
189+ }
190+
176191c10::intrusive_ptr<Work> barrier (
177192 const c10::intrusive_ptr<ProcessGroup>& process_group,
178193 const std::vector<int64_t >& device_ids,
@@ -271,6 +286,9 @@ TORCH_LIBRARY(c10d, m) {
271286 m.def (
272287 " alltoall_" ,
273288 dispatch (c10::DispatchKey::CompositeExplicitAutograd, alltoall_));
289+ m.def (
290+ " alltoall_base_" ,
291+ dispatch (c10::DispatchKey::CompositeExplicitAutograd, alltoall_base_));
274292 m.def (
275293 " barrier" ,
276294 dispatch (c10::DispatchKey::CompositeExplicitAutograd, barrier));
@@ -523,6 +541,31 @@ c10::intrusive_ptr<Work> alltoall(
523541 output_tensors, input_tensors, process_group, opts.timeout .count ());
524542}
525543
544+ c10::intrusive_ptr<Work> alltoall_base (
545+ const c10::intrusive_ptr<ProcessGroup>& process_group,
546+ at::Tensor& output,
547+ at::Tensor& input,
548+ std::vector<int64_t > output_split_sizes,
549+ std::vector<int64_t > input_split_sizes,
550+ const AllToAllOptions& opts) {
551+ static auto op = c10::Dispatcher::singleton ()
552+ .findSchemaOrThrow (" c10d::alltoall_base_" , " " )
553+ .typed <c10::intrusive_ptr<::c10d::Work>(
554+ at::Tensor&,
555+ at::Tensor&,
556+ const c10::intrusive_ptr<::c10d::ProcessGroup>&,
557+ std::vector<int64_t >,
558+ std::vector<int64_t >,
559+ int64_t )>();
560+ return op.call (
561+ output,
562+ input,
563+ process_group,
564+ output_split_sizes,
565+ input_split_sizes,
566+ opts.timeout .count ());
567+ }
568+
526569void monitored_barrier (
527570 const c10::intrusive_ptr<ProcessGroup>& process_group,
528571 const BarrierOptions& opts,
0 commit comments