Commit a7ec889

pietern authored and facebook-github-bot committed
Add sparse tensor allreduce (#22036)
Summary:
Pull Request resolved: #22036

Implemented only on ProcessGroupGloo, as an allgather of metadata (sparse_dim, dense_dim, and nnz), followed by an allgather of indices, followed by an allgather of values. Once these operations have finished, all ranks locally compute a reduction over these sparse tensors. Works for both CPU and CUDA tensors.

This surfaced a problem with the existing assumption that collectives only modify the tensors passed at the call site: for sparse tensors we don't know the dimensions of the output tensors before the collective runs. To deal with this unknown, this commit adds a `result` function to the `c10d::ProcessGroup::Work` class that returns a vector of tensors. It is a bit odd that the result has to be retrieved through this function only for operations on sparse tensors. To make this uniform across tensor layouts, a follow-up commit can make the results of all in-place operations accessible through this function as well. This doesn't break any existing contracts, but does have the potential to add interface ambiguity.

This is a resubmission of #19146.

Reviewed By: mrshenli

Differential Revision: D15926384

fbshipit-source-id: b6ee5d81606bfa8ed63c3d63a9e307613491e0ae
1 parent 313960d commit a7ec889
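To illustrate the scheme described in the summary, here is a small single-process sketch. It is not part of the commit: the per-rank lists stand in for the results of the three allgathers, and the helper name local_sparse_allreduce is hypothetical. It only shows how the metadata/indices/values exchanges compose into a local sum; the real implementation lives in ProcessGroupGloo.

import torch

def local_sparse_allreduce(gathered_metadata, gathered_indices, gathered_values, size):
    # Hypothetical helper mirroring the summary: after allgathering metadata
    # (sparse_dim, dense_dim, nnz), indices, and values, every rank rebuilds
    # its peers' sparse tensors and reduces them locally with a sum.
    total = torch.sparse_coo_tensor(
        torch.empty((len(size), 0), dtype=torch.long), torch.empty(0), size)
    for (sparse_dim, dense_dim, nnz), indices, values in zip(
            gathered_metadata, gathered_indices, gathered_values):
        # In the real collective the metadata is what lets each rank size the
        # buffers for the indices/values allgathers; the local sketch does not
        # need it beyond showing what gets exchanged.
        total = total + torch.sparse_coo_tensor(indices, values, size)
    return total.coalesce()

# Simulate world_size == 2 without a process group; the gathered lists play
# the role of the allgather outputs.
inputs = [
    torch.sparse_coo_tensor([[0]], [1.0], size=(4,)),          # "rank 0"
    torch.sparse_coo_tensor([[0, 2]], [1.0, 1.0], size=(4,)),  # "rank 1"
]
metadata = [(t.sparse_dim(), t.dense_dim(), t._nnz()) for t in inputs]
indices = [t._indices() for t in inputs]
values = [t._values() for t in inputs]
print(local_sparse_allreduce(metadata, indices, values, (4,)).to_dense())
# tensor([2., 0., 1., 0.])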

File tree

6 files changed, +474 -9 lines changed

test/test_c10d.py

Lines changed: 92 additions & 1 deletion
@@ -11,7 +11,7 @@
 from datetime import timedelta

 from itertools import groupby
-from functools import wraps
+from functools import partial, reduce, wraps
 from collections import namedtuple

 import torch
@@ -157,6 +157,49 @@ def simple_multi_input_reduce_tests(rank, world_size):
     ]


+def simple_sparse_reduce_tests(rank, world_size, num_inputs=1):
+    """
+    Generate a number of basic test cases for sparse reduction.
+    These cover tensors with a varying number of sparse dimensions and a varying
+    number of dense dimensions. The only reduction operation we support is sum.
+    """
+    def generate(rank, world_size, sparse_dims=1, dense_dims=0):
+        # First sparse dimension is [0..rank].
+        # Subsequent dimensions are always 0, so we know there is
+        # a non-empty intersection between any two sparse tensors.
+        indices = [range(rank + 1)]
+        shape = [world_size] + [2 for _ in range(dense_dims)]
+        for _ in range(sparse_dims - 1):
+            indices.append([0] * (rank + 1))
+            shape.append(world_size)
+        values = torch.ones([rank + 1] + [2 for _ in range(dense_dims)])
+        return torch.sparse_coo_tensor(indices, values, shape)
+
+    def compute_sum(fn, world_size):
+        return reduce(lambda a, b: a + b, [fn(rank, world_size) for rank in range(world_size)])
+
+    return [
+        (
+            [
+                fn(num_inputs * rank + i, num_inputs * world_size)
+                for i in range(num_inputs)
+            ],
+            [
+                compute_sum(fn, num_inputs * world_size)
+                for i in range(num_inputs)
+            ],
+        )
+        for fn in [
+            partial(generate, sparse_dims=1),
+            partial(generate, sparse_dims=2),
+            partial(generate, sparse_dims=3),
+            partial(generate, dense_dims=1),
+            partial(generate, dense_dims=2),
+            partial(generate, dense_dims=3),
+        ]
+    ]
+
+
 class StoreTestBase(object):
     def _create_store(self, i):
         raise RuntimeError("not implemented")
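For a concrete sense of the cases this helper produces, the snippet below (illustrative, not part of the commit) builds the sparse_dims=1, dense_dims=0 inputs for an assumed two-rank world exactly the way generate does, and sums them the way compute_sum does:

import torch

world_size = 2
# Rank r contributes ones at indices 0..r of a length-world_size vector, so
# index 0 is shared by every rank (the non-empty intersection noted above).
rank0 = torch.sparse_coo_tensor([[0]], torch.ones(1), (world_size,))
rank1 = torch.sparse_coo_tensor([[0, 1]], torch.ones(2), (world_size,))

expected = (rank0 + rank1).coalesce()
print(expected.to_dense())  # tensor([2., 1.])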
@@ -788,6 +831,54 @@ def test_allreduce_stress_cuda(self):
         inputs = [torch.Tensor([i + self.rank]).cuda() for i in range(1000)]
         self._test_allreduce_stress(inputs)

+    def test_sparse_allreduce_checks(self):
+        store = c10d.FileStore(self.file.name, self.world_size)
+        pg = c10d.ProcessGroupGloo(store, self.rank, self.world_size, self.opts())
+
+        t1 = torch.zeros([1])
+        t2 = torch.sparse_coo_tensor([[0]], [1], size=(2,))
+        t3 = torch.sparse_coo_tensor([[0]], [1], size=(4,))
+
+        with self.assertRaisesRegex(ValueError, "requires non-empty tensor list"):
+            opts = c10d.AllreduceOptions()
+            pg.allreduce([], opts)
+
+        with self.assertRaisesRegex(ValueError, "invalid tensor layout"):
+            opts = c10d.AllreduceOptions()
+            pg.allreduce([t1, t2], opts)
+
+        with self.assertRaisesRegex(ValueError, "invalid tensor size"):
+            opts = c10d.AllreduceOptions()
+            pg.allreduce([t2, t3], opts)
+
+        # Sparse allreduce only works with c10d.ReduceOp.SUM.
+        for op in [c10d.ReduceOp.PRODUCT, c10d.ReduceOp.MIN, c10d.ReduceOp.MAX]:
+            with self.assertRaisesRegex(ValueError, "unsupported reduction operation"):
+                opts = c10d.AllreduceOptions()
+                opts.reduceOp = op
+                pg.allreduce([t3], opts)
+
+    def _test_sparse_allreduce_basics(self, fn):
+        store = c10d.FileStore(self.file.name, self.world_size)
+        pg = c10d.ProcessGroupGloo(store, self.rank, self.world_size, self.opts())
+
+        for num_inputs_per_rank in [1, 2]:
+            tests = simple_sparse_reduce_tests(
+                self.rank,
+                self.world_size,
+                num_inputs=num_inputs_per_rank)
+            for (inputs, outputs) in tests:
+                work = pg.allreduce([fn(input) for input in inputs])
+                work.wait()
+                self.assertEqual(work.result(), outputs)
+
+    def test_sparse_allreduce_basics(self):
+        self._test_sparse_allreduce_basics(lambda t: t)
+
+    @skip_if_not_multigpu
+    def test_sparse_allreduce_basics_cuda(self):
+        self._test_sparse_allreduce_basics(lambda t: t.clone().cuda())
+
     def test_scatter_checks(self):
         store = c10d.FileStore(self.file.name, self.world_size)
         pg = c10d.ProcessGroupGloo(store, self.rank, self.world_size, self.opts())
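The basics tests above read the reduced tensors back through work.result() rather than from the input tensors. The summary explains why: the size of the output's indices and values is not known before the collective runs. A standalone snippet (illustrative only; two hypothetical ranks) makes the point:

import torch

rank0_input = torch.sparse_coo_tensor([[0]], [1.0], size=(4,))          # nnz == 1
rank1_input = torch.sparse_coo_tensor([[1, 3]], [1.0, 1.0], size=(4,))  # nnz == 2

reduced = (rank0_input + rank1_input).coalesce()
print(rank0_input._nnz(), rank1_input._nnz(), reduced._nnz())  # 1 2 3
# Neither rank can size the reduced tensor's indices/values buffers before
# communicating, so the result is exposed via Work::result() instead of being
# written into the tensors passed at the call site.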

torch/csrc/distributed/c10d/init.cpp

Lines changed: 9 additions & 0 deletions
@@ -477,6 +477,15 @@ They are used in specifying strategies for reduction collectives, e.g.,
       .def("is_success", &::c10d::ProcessGroup::Work::isSuccess)
       .def("exception", &::c10d::ProcessGroup::Work::exception)
       .def("source_rank", &::c10d::ProcessGroup::Work::sourceRank)
+      .def(
+          "result",
+          [](::c10d::ProcessGroup::Work& work) -> std::vector<at::Tensor> {
+            auto tensors = work.result();
+            for (auto& tensor : tensors) {
+              tensor = autograd::make_variable(tensor);
+            }
+            return tensors;
+          })
       .def("synchronize", &::c10d::ProcessGroup::Work::synchronize)
       .def(
           "wait",

torch/lib/c10d/ProcessGroup.cpp

Lines changed: 4 additions & 0 deletions
@@ -27,6 +27,10 @@ int ProcessGroup::Work::sourceRank() const {
       "that correspond to a recv or recv-from-any call.");
 }

+std::vector<at::Tensor> ProcessGroup::Work::result() const {
+  throw std::runtime_error("result() not implemented.");
+}
+
 void ProcessGroup::Work::synchronize() {}

 void ProcessGroup::Work::wait() {

torch/lib/c10d/ProcessGroup.hpp

Lines changed: 3 additions & 0 deletions
@@ -52,6 +52,9 @@ class ProcessGroup {
   // Returns source rank if this objects represents a recv-from-any.
   virtual int sourceRank() const;

+  // Returns result tensors, if applicable.
+  virtual std::vector<at::Tensor> result() const;
+
   // Ensures that operations on the output tensors that are invoked
   // after this function returns are correctly sequenced after the
   // asynchronous completion of this work.
