
Commit 39d4121

mrshenli authored and facebook-github-bot committed
Fix ProcessGroupGloo allgather for tensors with shared storage (#21490)
Summary:
Fix #20421

`ProcessGroupGloo` only requires input/output tensors to be contiguous. Contiguous tensors might not start from the beginning of the underlying storage, e.g., `chunk(..., dim=0)[1]`. The current implementation passes the `tensor.storage().data()` ptr to the gloo buffer. This leads to wrong results if the tensor has a non-zero storage offset. The proposed solution is to use `tensor.data_ptr()` instead. Let's see if this breaks any tests.

cc qijianan777

Pull Request resolved: #21490

Differential Revision: D15768907

Pulled By: mrshenli

fbshipit-source-id: 9d7d1e9baf0461b31187c7d21a4a53b1fbb07397
1 parent ad73ea2 commit 39d4121
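
For context only (not part of this commit), here is a minimal Python sketch of the situation the summary describes: a chunk(..., dim=0) view can be contiguous while still starting at a non-zero storage offset, so the storage base pointer is not the right address to hand to gloo.

    import torch

    full = torch.arange(4).reshape(2, 2)        # shared storage for both chunks
    second = torch.chunk(full, 2, dim=0)[1]     # contiguous view into the same storage

    assert second.is_contiguous()               # contiguity check passes...
    assert second.storage_offset() == 2         # ...but the view starts 2 elements in

    # The old code handed gloo the storage base address; the fix hands it
    # data_ptr(), which already accounts for the storage offset.
    assert second.data_ptr() != second.storage().data_ptr()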

File tree

2 files changed: +28 -3 lines changed


test/test_c10d_spawn.py

Lines changed: 21 additions & 1 deletion
@@ -72,7 +72,7 @@ def _test_multiprocess(self, f, shared_tensors, init_pg, n_output):
                 expected,
                 result,
                 (
-                    "Expect rank {} to broadcast result {} but got {}."
+                    "Expect rank {} to receive tensor {} but got {}."
                 ).format(pid, expected, result)
             )
 
@@ -187,6 +187,26 @@ def test_shared_allgather_nccl(self):
             ProcessGroupShareTensorTest._init_pg_nccl,
             self.world_size)
 
+    @classmethod
+    def _test_allgather_chunk_process(
+            cls, rank, filename, shared_tensor, world_size, init_pg, c2p, p2c):
+        pg = init_pg(rank, filename, world_size)
+        chunks = torch.chunk(shared_tensor, world_size, dim=0)
+        x = chunks[rank]
+        ys = [torch.zeros_like(x) for _ in range(world_size)]
+        pg.allgather(ys, x).wait()
+        c2p.put((rank, chunks[0].to("cpu"), ys[0].to("cpu")))
+        c2p.put((rank, chunks[1].to("cpu"), ys[1].to("cpu")))
+        p2c.get()
+
+    @unittest.skipIf(not TEST_MULTIGPU, "At least 2 CUDA GPUS needed")
+    def test_shared_allgather_chunk_gloo(self):
+        self._test_multiprocess(
+            ProcessGroupShareTensorTest._test_allgather_chunk_process,
+            torch.tensor(range(4)).reshape(2, 2),
+            ProcessGroupShareTensorTest._init_pg_gloo,
+            self.world_size)
+
 
 if __name__ == '__main__':
     run_tests()

torch/lib/c10d/Utils.hpp

Lines changed: 7 additions & 2 deletions
@@ -272,8 +272,13 @@ inline std::vector<int> getDevices(const std::vector<at::Tensor>& tensors) {
 
 template <typename T>
 inline T* getDataPointer(const at::Tensor& tensor) {
-  // NB: This does NOT respect storage_offset from the tensor
-  return static_cast<T*>(tensor.storage().data());
+  // This method is only used in ProcessGroupGloo for now. Call sites must make
+  // sure that the input tensor is contiguous. It is OK if the tensor does not
+  // start from the beginning of the storage. For example, it could come from
+  // chunk(..., dim=0)[1]. Hence, we need to use data_ptr() instead of
+  // tensor.storage().data()
+  // NB: not using tensor.data<T>() because tensor is not aware of gloo::TYPE
+  return static_cast<T*>(tensor.data_ptr());
 }
 
 template <typename T>
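
As a hedged sanity check (not part of the diff), the relationship the new return value relies on can be observed from Python: `data_ptr()` already folds the storage offset into the address, which is what a contiguous-but-offset view needs.

    import torch

    t = torch.arange(4).reshape(2, 2)
    view = torch.chunk(t, 2, dim=0)[1]

    # data_ptr() = storage base address + storage_offset * element size,
    # whereas the old storage().data() call returned only the base address.
    assert view.data_ptr() == (
        view.storage().data_ptr()
        + view.storage_offset() * view.element_size())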
