
Conversation

@wanchaol (Collaborator) commented Jan 21, 2022

Stack from ghstack (oldest at bottom):

Now that gather is available in the NCCL process group, we can switch our `sharded_tensor.gather` to use gather_object instead of all_gather_object, which reduces the communication overhead.

fixes #66187

Differential Revision: D33688907
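
For illustration, a minimal sketch of the switch (the local_shards placeholder and the dst choice are hypothetical): all_gather_object delivers every rank's shards to every rank, while gather_object delivers them only to the destination rank.

import torch.distributed as dist

local_shards = ...  # this rank's List[Shard]
world_size = dist.get_world_size()
rank, dst = dist.get_rank(), 0

# Before: every rank allocates for and receives all shards.
gathered = [None] * world_size
dist.all_gather_object(gathered, local_shards)

# After: only the destination rank allocates and receives.
gathered = [None] * world_size if rank == dst else None
dist.gather_object(local_shards, gathered, dst=dst)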

@pytorch-bot (bot) commented Jan 21, 2022

CI Flow Status

⚛️ CI Flow

Ruleset - Version: v1
Ruleset - File: https://github.com/pytorch/pytorch/blob/aea2c196cf43997518669e0b4255b8cd32f16b96/.github/generated-ciflow-ruleset.json
PR ciflow labels: ciflow/default

Workflow / Labels / Status
Triggered Workflows
linux-binary-conda ciflow/binaries, ciflow/binaries_conda, ciflow/default ✅ triggered
linux-binary-libtorch-cxx11-abi ciflow/binaries, ciflow/binaries_libtorch, ciflow/default ✅ triggered
linux-binary-libtorch-pre-cxx11 ciflow/binaries, ciflow/binaries_libtorch, ciflow/default ✅ triggered
linux-binary-manywheel ciflow/binaries, ciflow/binaries_wheel, ciflow/default ✅ triggered
linux-bionic-py3.7-clang9 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/noarch, ciflow/trunk, ciflow/xla ✅ triggered
linux-docs ciflow/all, ciflow/cpu, ciflow/default, ciflow/docs, ciflow/linux, ciflow/trunk ✅ triggered
linux-vulkan-bionic-py3.7-clang9 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk, ciflow/vulkan ✅ triggered
linux-xenial-cuda11.3-py3.7-gcc7 ciflow/all, ciflow/cuda, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
linux-xenial-cuda11.3-py3.7-gcc7-bazel-test ciflow/all, ciflow/bazel, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
linux-xenial-py3-clang5-mobile-build ciflow/all, ciflow/default, ciflow/linux, ciflow/mobile, ciflow/trunk ✅ triggered
linux-xenial-py3-clang5-mobile-custom-build-static ciflow/all, ciflow/default, ciflow/linux, ciflow/mobile, ciflow/trunk ✅ triggered
linux-xenial-py3.7-clang7-asan ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/sanitizers, ciflow/trunk ✅ triggered
linux-xenial-py3.7-clang7-onnx ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/onnx, ciflow/trunk ✅ triggered
linux-xenial-py3.7-gcc5.4 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
linux-xenial-py3.7-gcc7 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
linux-xenial-py3.7-gcc7-no-ops ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single ciflow/all, ciflow/android, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single-full-jit ciflow/all, ciflow/android, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
win-vs2019-cpu-py3 ciflow/all, ciflow/cpu, ciflow/default, ciflow/trunk, ciflow/win ✅ triggered
win-vs2019-cuda11.3-py3 ciflow/all, ciflow/cuda, ciflow/default, ciflow/trunk, ciflow/win ✅ triggered
Skipped Workflows
caffe2-linux-xenial-py3.7-gcc5.4 ciflow/all, ciflow/cpu, ciflow/linux, ciflow/trunk 🚫 skipped
docker-builds ciflow/all, ciflow/trunk 🚫 skipped
ios-12-5-1-arm64 ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk 🚫 skipped
ios-12-5-1-arm64-coreml ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk 🚫 skipped
ios-12-5-1-arm64-custom-ops ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk 🚫 skipped
ios-12-5-1-arm64-full-jit ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk 🚫 skipped
ios-12-5-1-arm64-metal ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk 🚫 skipped
ios-12-5-1-x86-64 ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk 🚫 skipped
ios-12-5-1-x86-64-coreml ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk 🚫 skipped
ios-12-5-1-x86-64-full-jit ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk 🚫 skipped
libtorch-linux-xenial-cuda10.2-py3.7-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/trunk 🚫 skipped
libtorch-linux-xenial-cuda11.3-py3.7-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/trunk 🚫 skipped
linux-bionic-cuda10.2-py3.9-gcc7 ciflow/all, ciflow/cuda, ciflow/linux, ciflow/slow, ciflow/trunk 🚫 skipped
linux-bionic-rocm4.5-py3.7 ciflow/all, ciflow/linux, ciflow/rocm, ciflow/trunk 🚫 skipped
linux-docs-push ciflow/all, ciflow/cpu, ciflow/linux, ciflow/scheduled 🚫 skipped
linux-xenial-cuda11.3-py3.7-gcc7-no-ops ciflow/all, ciflow/cuda, ciflow/linux, ciflow/trunk 🚫 skipped
macos-10-15-py3-arm64 ciflow/all, ciflow/macos, ciflow/trunk 🚫 skipped
macos-10-15-py3-lite-interpreter-x86-64 ciflow/all, ciflow/macos, ciflow/trunk 🚫 skipped
macos-11-py3-x86-64 ciflow/all, ciflow/macos, ciflow/trunk 🚫 skipped
parallelnative-linux-xenial-py3.7-gcc5.4 ciflow/all, ciflow/cpu, ciflow/linux, ciflow/trunk 🚫 skipped
periodic-libtorch-linux-bionic-cuda11.5-py3.7-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/scheduled 🚫 skipped
periodic-libtorch-linux-xenial-cuda11.1-py3.7-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/scheduled 🚫 skipped
periodic-linux-bionic-cuda11.5-py3.7-gcc7 ciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled 🚫 skipped
periodic-linux-xenial-cuda10.2-py3-gcc7-slow-gradcheck ciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled, ciflow/slow, ciflow/slow-gradcheck 🚫 skipped
periodic-linux-xenial-cuda11.1-py3.7-gcc7-debug ciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled 🚫 skipped
periodic-win-vs2019-cuda11.1-py3 ciflow/all, ciflow/cuda, ciflow/scheduled, ciflow/win 🚫 skipped
periodic-win-vs2019-cuda11.5-py3 ciflow/all, ciflow/cuda, ciflow/scheduled, ciflow/win 🚫 skipped
pytorch-linux-xenial-py3-clang5-android-ndk-r19c-build ciflow/all, ciflow/android, ciflow/cpu, ciflow/linux, ciflow/trunk 🚫 skipped

You can add a comment to the PR and tag @pytorchbot with the following commands:
# ciflow rerun, "ciflow/default" will always be added automatically
@pytorchbot ciflow rerun

# ciflow rerun with additional labels "-l <ciflow/label_name>", which is equivalent to adding these labels manually and triggering the rerun
@pytorchbot ciflow rerun -l ciflow/scheduled -l ciflow/slow

For more information, please take a look at the CI Flow Wiki.

@facebook-github-bot (Contributor) commented Jan 21, 2022


💊 CI failures summary and remediations

As of commit 5aa888b (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚


This comment was automatically generated by Dr. CI.

Please report bugs/suggestions to the (internal) Dr. CI Users group.

wanchaol pushed a commit that referenced this pull request Jan 21, 2022
Pull Request resolved: #71624

Now that gather is available in the NCCL process group, we can switch our `sharded_tensor.gather` to use gather_object instead of all_gather_object, which reduces the communication overhead.
ghstack-source-id: 147386510

Differential Revision: [D33688907](https://our.internmc.facebook.com/intern/diff/D33688907/)
@wanchaol wanchaol requested a review from fduwjj January 21, 2022 22:10
tensor = shard.tensor

out_narrow_view = out
assert out_narrow_view is not None
Contributor:

Doesn't _validate_output_tensor_for_gather validate this? Why do we need another assert here?

Collaborator Author:

This is added purely for the mypy linter; it seems mypy can't follow the _validate_output_tensor_for_gather check, so we have to do this assert here.

Contributor:

Usually I add a mypy ignore for something like this, but it's up to you.
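
For reference, a minimal sketch of the two options (the function here is hypothetical; only the assert/ignore pattern is from the PR):

from typing import Optional

import torch

def copy_shard_into(out: Optional[torch.Tensor]) -> None:
    # Suppose _validate_output_tensor_for_gather(out, ...) ran above: it
    # raises if `out` is None on the destination rank, but mypy cannot see
    # through the helper call.

    # Option 1: assert, which narrows Optional[Tensor] to Tensor for mypy.
    out_narrow_view = out
    assert out_narrow_view is not None
    out_narrow_view.narrow(0, 0, 1)

    # Option 2: suppress the error at the call site instead of narrowing.
    # out.narrow(0, 0, 1)  # type: ignore[union-attr]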

- # https://github.com/pytorch/pytorch/issues/66187
- dist.all_gather_object(
+ gathered_shards: List[Optional[List[Shard]]] = [None] * world_size if rank == dst else []
+ dist.gather_object(
Contributor:

(no need to update current PR, can be done in followup ones)

Not sure how performance critical this is, but as we discussed in today's meeting, this indeed looks more expensive than necessary, as there will be additional H2D + D2H copies. I'd assume handling tensor and non-tensor parts separately would be faster for large ShardedTensors.

# From torch/distributed/distributed_c10d.py (`_pickler`/`_unpickler` are
# pickle.Pickler/pickle.Unpickler in that module):
def _object_to_tensor(obj):
    f = io.BytesIO()
    _pickler(f).dump(obj)
    byte_storage = torch.ByteStorage.from_buffer(f.getvalue())  # type: ignore[attr-defined]
    # Do not replace `torch.ByteTensor` or `torch.LongTensor` with torch.tensor
    # and specifying dtype. Otherwise, it will cause a 100X slowdown.
    # See: https://github.com/pytorch/pytorch/issues/65696
    byte_tensor = torch.ByteTensor(byte_storage)
    local_size = torch.LongTensor([byte_tensor.numel()])
    return byte_tensor, local_size

def _tensor_to_object(tensor, tensor_size):
    buf = tensor.numpy().tobytes()[:tensor_size]
    return _unpickler(io.BytesIO(buf)).load()

BTW, is the non-tensor meta info static? If so, can we cache it?
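
As a rough sketch of where those copies come from, tracing the helpers above on a hypothetical CUDA payload (these are private helpers, shown only to make the data movement visible, not a recommended usage):

import torch
from torch.distributed.distributed_c10d import _object_to_tensor, _tensor_to_object

obj = {"shard": torch.randn(4, device="cuda")}  # hypothetical payload

byte_tensor, local_size = _object_to_tensor(obj)  # pickling copies GPU data to host bytes (D2H)
byte_tensor = byte_tensor.to("cuda")              # payload moved back to device so NCCL can send it (H2D)
# ... the collective ships byte_tensor between ranks ...
restored = _tensor_to_object(byte_tensor.cpu(), local_size)  # bytes back to host on dst (D2H)
# Unpickling rebuilds the CUDA tensor inside `restored`: one more H2D copy.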

Collaborator Author:

Yeah, thanks for the suggestion, I will try using two separate gathers to improve the perf. I don't think the non-tensor meta info is static, though (e.g. if we reshard a ShardedTensor, the Shard.metadata can change: shards can move to different ranks, or the shard offset/size can change).

Collaborator Author:

I am thinking of a new approach that might require only one gather call: make Shard a subclass of torch.Tensor (with metadata as a Python-side field), so that we can gather it alone. But I am not sure whether our c10d collectives support custom tensor objects? (Maybe not, since we eventually lower the collective to C++, where only the at::Tensor does the real communication, not the metadata.)
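
For concreteness, a sketch of the kind of subclass being floated (ShardTensor is a hypothetical name): the metadata lives only on the Python object, which is exactly why the C++ lowering is a concern, since only the at::Tensor payload would travel.

import torch

class ShardTensor(torch.Tensor):
    """Hypothetical Shard-as-tensor: metadata is a Python-side field."""

    @staticmethod
    def __new__(cls, data, metadata):
        t = torch.Tensor._make_subclass(cls, data)
        t.metadata = metadata  # plain Python attribute, invisible to C++
        return t

shard = ShardTensor(torch.randn(4, 4), metadata={"rank": 0, "offsets": [0, 0]})
# A collective such as dist.gather(shard, ...) would move the tensor data
# only; nothing in c10d would carry shard.metadata across ranks.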

Contributor:

> But I am not sure if our c10d collectives support custom tensor objects?

It should; at least it worked for SparseTensor. But I am not sure whether that's sufficient for ShardedTensor:

class AsyncSparseAllreduceWork : public ProcessGroupGloo::AsyncWork {
 public:
  AsyncSparseAllreduceWork(
      const std::shared_ptr<gloo::Context>& context,
      std::vector<at::Tensor>& inputs,
      uint32_t tag)
      : ProcessGroupGloo::AsyncWork({inputs}, "gloo:sparse_all_reduce", inputs),
        context(context),
        inputs(inputs),
        tag(tag) {}

  std::shared_ptr<gloo::Context> context;
  std::vector<at::Tensor> inputs;
  const uint32_t tag;

  // We share dimensionality about the sparse tensors before collecting
  // their contents. We assume here that the maximum number of sparse
  // and dense dimensions is 4. This is stored in a contiguous piece of
  // memory so that we can easily run allgather on it.
  //
  // The layout of this memory is as follows:
  //
  //   - [0:4]: sparse dims
  //   - [4:8]: dense dims
  //   - [8]: nnz
  //
  class SparseTensorMetadata {
   public:
    static constexpr auto dim = 9;

    // Construct from an existing metadata tensor to facilitate structured
    // access to metadata from peers, after gathering it.
    explicit SparseTensorMetadata(at::Tensor metadata)
        : metadata_(metadata), data_(metadata_.data_ptr<int64_t>()) {
      AT_ASSERT(metadata.scalar_type() == at::kLong);
      AT_ASSERT(metadata.dim() == 1);
      AT_ASSERT(metadata.size(0) == dim);
    }

    // Populate the metadata.
    void populate_from_sparse_tensor(const at::Tensor& tensor) {
      const auto sparse_dim = tensor.sparse_dim();
      AT_ASSERT(sparse_dim <= 4);
      for (const auto i : c10::irange(4)) {
        if (i < sparse_dim) {
          data_[i] = tensor.size(i);
        }
      }
      const auto dense_dim = tensor.dense_dim();
      AT_ASSERT(dense_dim <= 4);
      for (const auto i : c10::irange(4)) {
        if (i < dense_dim) {
          data_[i + 4] = tensor.size(sparse_dim + i);
        }
      }
      data_[8] = tensor._nnz();
    }

    std::vector<int64_t> sizes() const {
      std::vector<int64_t> sizes;
      // Sparse sizes
      for (const auto i : c10::irange(4)) {
        if (data_[i] <= 0) {
          break;
        }
        sizes.push_back(data_[i]);
      }
      // Dense sizes
      for (const auto i : c10::irange(4, 8)) {
        if (data_[i] <= 0) {
          break;
        }
        sizes.push_back(data_[i]);
      }
      return sizes;
    }

    int64_t nnz() const {
      return data_[8];
    }

   protected:
    at::Tensor metadata_;
    int64_t* data_;
  };

Collaborator Author:

I gave using two separate gathers for the metadata and the tensors a try today, to avoid the pickling copies of gather_object. It's trickier than I thought, mainly because:

  1. We can use gather_object for the metadata, but we can't simply use gather for the local shard tensors: local_shards() is a list of tensors, while the input of gather is a single tensor, not a list, so gather won't work here.
  2. We can torch.cat locally on each rank to form a single tensor before the gather collective, but then we need to split the result apart afterwards, since the shards may not be adjacent to each other (the local shards on a rank can contain two tensors that are logically far apart in the global layout). So we would need to include additional split points in the first gather_object to cut apart the combined tensors from the second gather call (see the sketch below). This is pretty tricky.

Any suggestions are appreciated :)

For this PR itself, I think we can land it as is, since the code before this PR used all_gather_object anyway. I will think more about how to solve this pickling issue and make a follow-up PR to improve the perf.
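
A rough sketch of that two-gather variant (the gather_shards helper is hypothetical; it assumes every shard shares one dtype, elides device placement and process-group arguments, and spends one extra all_reduce because dist.gather expects equal-sized tensors on every rank):

import torch
import torch.distributed as dist

def gather_shards(local_shards, dst: int = 0):
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    # Pack this rank's shards into one flat buffer; remember the split sizes.
    flats = [s.tensor.reshape(-1) for s in local_shards]
    sizes = [f.numel() for f in flats]
    packed = torch.cat(flats) if flats else torch.empty(0)

    # Gather 1 (cheap, pickled): shard metadata plus the split points, so
    # the destination rank knows how to cut each packed buffer back apart.
    meta = [(s.metadata, n) for s, n in zip(local_shards, sizes)]
    gathered_meta = [None] * world_size if rank == dst else None
    dist.gather_object(meta, gathered_meta, dst=dst)

    # dist.gather needs equal-sized tensors everywhere, so pad to the max.
    max_numel = torch.tensor([packed.numel()])
    dist.all_reduce(max_numel, op=dist.ReduceOp.MAX)
    padded = packed.new_zeros(int(max_numel))
    padded[: packed.numel()] = packed

    # Gather 2 (bulk, no pickling): the raw tensor data.
    bufs = [torch.empty_like(padded) for _ in range(world_size)] if rank == dst else None
    dist.gather(padded, bufs, dst=dst)

    if rank != dst:
        return None
    # Re-split each rank's padded buffer using the sizes from gather 1.
    out = []
    for r, rank_meta in enumerate(gathered_meta):
        offset = 0
        for md, numel in rank_meta:
            out.append((md, bufs[r][offset:offset + numel]))
            offset += numel
    return out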

Contributor:

@wanchaol Can we create a GH issue to track this follow-up improvement?


@fduwjj (Contributor) left a comment:

I have a rookie question: since this is for perf purposes, do we measure the perf change before and after for a change like this? If so, I am curious how we would run that experiment.

wanchaol added 2 commits March 9, 2022 16:45
Now that gather is available in the NCCL process group, we can switch our `sharded_tensor.gather` to use gather_object instead of all_gather_object, which reduces the communication overhead.

fixes #66187

Differential Revision: [D33688907](https://our.internmc.facebook.com/intern/diff/D33688907/)

[ghstack-poisoned]
facebook-github-bot pushed a commit that referenced this pull request Mar 10, 2022
Summary:
Pull Request resolved: #71624

Now that gather is available in the NCCL process group, we can switch our `sharded_tensor.gather` to use gather_object instead of all_gather_object, which reduces the communication overhead.

TODO: To further reduce the comm overhead, we need to figure out a way to avoid using `gather_object`, as `gather_object` and `all_gather_object` incur pickling copies between devices.

ghstack-source-id: 151007578

Test Plan: wait for ci

Reviewed By: pritamdamania87

Differential Revision: D33688907

fbshipit-source-id: 2073c5a46c33a7a2640a9e3599dc795d9e4c0a1e
@github-actions (Contributor) commented:

Hey @wanchaol.
You've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.
For changes that are 'topic: not user facing' there is no need for a release notes label.


Labels

cla signed · oncall: distributed (Add this issue/PR to distributed oncall triage queue)
