[shard] use gather_object for gather API #71624
Conversation
Now that we have gather available in the NCCL process group, we can switch our `sharded_tensor.gather` to use gather_object instead of all_gather_object, which will reduce the communication overhead. Differential Revision: [D33688907](https://our.internmc.facebook.com/intern/diff/D33688907/)
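For context, a minimal sketch of what the switch buys us (assuming an initialized process group; `local_shards` and `dst` are illustrative names, not the PR's exact variables):

```python
import torch.distributed as dist

def gather_shards(local_shards, dst=0):
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    # Before: all_gather_object materializes every rank's shards on every rank.
    # all_shards = [None] * world_size
    # dist.all_gather_object(all_shards, local_shards)

    # After: gather_object only materializes the full list on `dst`;
    # every other rank just sends its local shards.
    gathered = [None] * world_size if rank == dst else None
    dist.gather_object(local_shards, gathered, dst=dst)
    return gathered
```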
```python
tensor = shard.tensor
...
out_narrow_view = out
assert out_narrow_view is not None
```
Doesn't _validate_output_tensor_for_gather validate this? Why do we need another assert here?
This is added purely for the mypy linter; it seems mypy couldn't understand the `_validate_output_tensor_for_gather` check, so we have to do this assert here.
Usually I add a mypy ignore for something like this, but it's up to you.
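For reference, the two alternatives being discussed look roughly like this (hypothetical snippet, not the PR's exact code; only the assignment and assert appear in the diff above):

```python
from typing import Optional
import torch

def write_shard(out: Optional[torch.Tensor], tensor: torch.Tensor) -> None:
    # Hypothetical helper, only to contrast the two ways of satisfying mypy.

    # Option 1 (what the assert in the diff does): a runtime assert narrows
    # Optional[torch.Tensor] down to torch.Tensor for the type checker.
    out_narrow_view = out
    assert out_narrow_view is not None
    out_narrow_view.copy_(tensor)

    # Option 2 (the suggestion in this comment): use `out` directly and
    # silence the checker on that line instead.
    # out.copy_(tensor)  # type: ignore[union-attr]
```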
```diff
- # https://github.com/pytorch/pytorch/issues/66187
- dist.all_gather_object(
+ gathered_shards: List[Optional[List[Shard]]] = [None] * world_size if rank == dst else []
+ dist.gather_object(
```
(No need to update the current PR, this can be done in follow-up ones.)

Not sure how performance critical this is, but as we discussed in today's meeting, this indeed looks more expensive than necessary, as there will be additional H2D + D2H copies. I'd assume handling the tensor and non-tensor parts separately would be faster for large ShardedTensors.
pytorch/torch/distributed/distributed_c10d.py, lines 1551 to 1565 in 03f1f0c:
```python
def _object_to_tensor(obj):
    f = io.BytesIO()
    _pickler(f).dump(obj)
    byte_storage = torch.ByteStorage.from_buffer(f.getvalue())  # type: ignore[attr-defined]
    # Do not replace `torch.ByteTensor` or `torch.LongTensor` with torch.tensor and specifying dtype.
    # Otherwise, it will cause 100X slowdown.
    # See: https://github.com/pytorch/pytorch/issues/65696
    byte_tensor = torch.ByteTensor(byte_storage)
    local_size = torch.LongTensor([byte_tensor.numel()])
    return byte_tensor, local_size


def _tensor_to_object(tensor, tensor_size):
    buf = tensor.numpy().tobytes()[:tensor_size]
    return _unpickler(io.BytesIO(buf)).load()
```
BTW, is the non-tensor meta info static? If so, can we cache it?
Yeah, thanks for the suggestion, I will try using two separate gathers to improve the perf. The non-tensor meta info might not be static, I think (i.e. if we do resharding on a ShardedTensor, the Shard.metadata might get changed, to different ranks, or the shard offset/size changes).
I am thinking of a new way that possibly requires only one gather call. It requires us to make Shard a subclass of torch.Tensor (with the metadata as a Python field), so that we can do the gather alone. But I am not sure if our c10d collectives support custom tensor objects? (Maybe not, as we eventually lower the collective to C++, and we might only have the at::Tensor do the real communication, not the metadata.)
> But I am not sure if our c10d collectives support custom tensor objects?

It should. At least it worked for SparseTensor. But I am not sure if that's sufficient for ShardedTensor:
pytorch/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp, lines 1050 to 1134 in 7beb030:
```cpp
class AsyncSparseAllreduceWork : public ProcessGroupGloo::AsyncWork {
 public:
  AsyncSparseAllreduceWork(
      const std::shared_ptr<gloo::Context>& context,
      std::vector<at::Tensor>& inputs,
      uint32_t tag)
      : ProcessGroupGloo::AsyncWork({inputs}, "gloo:sparse_all_reduce", inputs),
        context(context),
        inputs(inputs),
        tag(tag) {}

  std::shared_ptr<gloo::Context> context;
  std::vector<at::Tensor> inputs;
  const uint32_t tag;

  // We share dimensionality about the sparse tensors before collecting
  // their contents. We assume here that the maximum number of sparse
  // and dense dimensions is 4. This is stored in a contiguous piece of
  // memory so that we can easily run allgather on it.
  //
  // The layout of this memory is as follows:
  //
  //   - [0:4]: sparse dims
  //   - [4:8]: dense dims
  //   -   [8]: nnz
  //
  class SparseTensorMetadata {
   public:
    static constexpr auto dim = 9;

    // Construct from an existing metadata tensor to facilitate structured
    // access to metadata from peers, after gathering it.
    explicit SparseTensorMetadata(at::Tensor metadata)
        : metadata_(metadata), data_(metadata_.data_ptr<int64_t>()) {
      AT_ASSERT(metadata.scalar_type() == at::kLong);
      AT_ASSERT(metadata.dim() == 1);
      AT_ASSERT(metadata.size(0) == dim);
    }

    // Populate the metadata.
    void populate_from_sparse_tensor(const at::Tensor& tensor) {
      const auto sparse_dim = tensor.sparse_dim();
      AT_ASSERT(sparse_dim <= 4);
      for (const auto i : c10::irange(4)) {
        if (i < sparse_dim) {
          data_[i] = tensor.size(i);
        }
      }
      const auto dense_dim = tensor.dense_dim();
      AT_ASSERT(dense_dim <= 4);
      for (const auto i : c10::irange(4)) {
        if (i < dense_dim) {
          data_[i + 4] = tensor.size(sparse_dim + i);
        }
      }
      data_[8] = tensor._nnz();
    }

    std::vector<int64_t> sizes() const {
      std::vector<int64_t> sizes;
      // Sparse sizes
      for (const auto i : c10::irange(4)) {
        if (data_[i] <= 0) {
          break;
        }
        sizes.push_back(data_[i]);
      }
      // Dense sizes
      for (const auto i : c10::irange(4, 8)) {
        if (data_[i] <= 0) {
          break;
        }
        sizes.push_back(data_[i]);
      }
      return sizes;
    }

    int64_t nnz() const {
      return data_[8];
    }

   protected:
    at::Tensor metadata_;
    int64_t* data_;
  };
```
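For what it's worth, a rough sketch of the tensor-subclass idea in Python (purely illustrative; this is not how `Shard` is defined today, and whether a c10d collective would carry the extra field is exactly the open question):

```python
import torch

class ShardTensor(torch.Tensor):
    """Hypothetical Shard-as-a-tensor: the shard data is the tensor itself
    and the shard metadata rides along as a plain Python attribute."""

    @staticmethod
    def __new__(cls, data: torch.Tensor, metadata: dict):
        # _make_subclass reuses `data`'s storage but gives it this subclass type.
        instance = torch.Tensor._make_subclass(cls, data)
        instance.metadata = metadata  # e.g. shard offsets/sizes and placement rank
        return instance

# The open question from the discussion: a c10d collective would move the
# tensor bytes (the at::Tensor), but nothing obviously carries the
# Python-level `metadata` attribute over to the receiving rank.
shard = ShardTensor(torch.randn(4, 4), {"shard_offsets": [0, 0], "rank": 0})
```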
I gave using two separate gathers (one for metadata, one for tensors) a try today, to avoid the pickling copies of gather_object. It's trickier than I thought, mainly because:

- We can use `gather_object` for the metadata, but we couldn't simply use `gather` for the local shard tensors: `local_shards()` is a list of tensors, while the input of `gather` is a single tensor, not a list of tensors, so `gather` won't work here.
- We can try `torch.cat` locally on each rank to form a single tensor before the `gather` collective, but we need to split the result afterwards, as the shards might not be adjacent to each other (the local shards on a rank might contain two tensors that are logically far apart in the global position). So we would need to carry additional split points in the first `gather_object` to split the combined tensors from the second `gather` call. This is pretty tricky from my understanding; a rough sketch of the approach is below.

Any suggestions are appreciated :)
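For concreteness, a sketch of that two-collective idea (illustrative only: `local_shards`, `dst`, and the metadata layout are assumptions, and it glosses over `gather`'s requirement that all inputs have the same size):

```python
import torch
import torch.distributed as dist

def gather_sharded(local_shards, dst=0):
    rank, world_size = dist.get_rank(), dist.get_world_size()

    # 1) Small gather_object for the metadata, including each local tensor's
    #    numel so the destination knows where to split later.
    local_meta = [(s.metadata, s.tensor.numel()) for s in local_shards]
    metas = [None] * world_size if rank == dst else None
    dist.gather_object(local_meta, metas, dst=dst)

    # 2) One flat tensor per rank for the tensor gather. This only works if
    #    every rank's flattened payload has the same size (or is padded),
    #    which is one of the complications mentioned above.
    flat = torch.cat([s.tensor.reshape(-1) for s in local_shards])
    gather_list = [torch.empty_like(flat) for _ in range(world_size)] if rank == dst else None
    dist.gather(flat, gather_list, dst=dst)

    if rank != dst:
        return None
    # 3) On dst, split each flat tensor back into shards using the numels
    #    carried by the metadata gather. Reassembling them into the global
    #    tensor layout would be a separate step.
    out = []
    for per_rank_flat, per_rank_meta in zip(gather_list, metas):
        sizes = [numel for _, numel in per_rank_meta]
        pieces = per_rank_flat.split(sizes)
        out.append(list(zip([m for m, _ in per_rank_meta], pieces)))
    return out
```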
For this PR itself, I think we can land it as is, since the code before this PR used all_gather_object anyway. I will think more about how to solve this pickling issue and make a follow-up PR to improve the perf.
@wanchaol Can we create a GH issue to track this follow-up improvement?
fduwjj left a comment:
I have a rookie question here: if this is for perf purposes, do we measure the perf change before and after for a change like this? If so, I am curious how we do the experiment here.
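Not speaking for the author, but one common way to measure this kind of change is a small multi-process micro-benchmark around the two collectives (illustrative sketch; payload sizes and iteration counts are made up, and it assumes an already-initialized process group):

```python
import time
import torch
import torch.distributed as dist

def bench(fn, iters=20, warmup=5):
    # Time an average iteration of `fn`, synchronizing across ranks/devices.
    for _ in range(warmup):
        fn()
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    dist.barrier()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    dist.barrier()
    return (time.perf_counter() - start) / iters

# Made-up payload standing in for a rank's local shards.
payload = [torch.randn(1024, 1024) for _ in range(4)]
rank, world_size = dist.get_rank(), dist.get_world_size()

t_all = bench(lambda: dist.all_gather_object([None] * world_size, payload))
t_one = bench(lambda: dist.gather_object(
    payload, [None] * world_size if rank == 0 else None, dst=0))
if rank == 0:
    print(f"all_gather_object: {t_all * 1e3:.2f} ms  gather_object: {t_one * 1e3:.2f} ms")
```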
Summary: Pull Request resolved: #71624

Now that we have gather available in the NCCL PG, we can switch our `sharded_tensor.gather` to use gather_object instead of all_gather_object, which will reduce the communication overhead.

TODO: to further reduce the comm overhead, we need to figure out a way to avoid using `gather_object`, as `gather_object` and `all_gather_object` incur pickling copies between devices.

ghstack-source-id: 151007578
Test Plan: wait for CI
Reviewed By: pritamdamania87
Differential Revision: D33688907
fbshipit-source-id: 2073c5a46c33a7a2640a9e3599dc795d9e4c0a1e
Stack from ghstack (oldest at bottom):
Now that we have gather available in the NCCL process group, we can switch our `sharded_tensor.gather` to use gather_object instead of all_gather_object, which will reduce the communication overhead.

fixes #66187
Differential Revision: D33688907