
Conversation


@wanchaol wanchaol commented Feb 28, 2022

Stack from ghstack (oldest at bottom):

Add ReplicatedTensor. A ReplicatedTensor is a type of tensor that has the same value on all ranks across the world_size.

ReplicatedTensor is a :class:`~torch.Tensor` subclass, and it can be used together with ShardedTensor/Tensor to express different types of computation. The inter-op rules are defined as follows (using torch.add as an example op):
    ReplicatedTensor + ReplicatedTensor = ReplicatedTensor
    ReplicatedTensor + torch.Tensor = torch.Tensor
    ReplicatedTensor + ShardedTensor = ShardedTensor

We also added a `validate()` API to help users check whether a replicated tensor on a certain process_group is truly replicated or not.

TODO: the next PR will add ShardedTensor/PartialTensor logic to handle ReplicatedTensor.

Differential Revision: [D34529374](https://our.internmc.facebook.com/intern/diff/D34529374/)

**NOTE FOR REVIEWERS**: This PR has internal Facebook-specific changes or comments; please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D34529374/)!
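
To make the rules above concrete, here is a minimal usage sketch (illustrative only; the import path and the already-initialized process group are assumptions, not verbatim from this PR or its tests):

```python
import torch
import torch.distributed as dist
# assumed import path for illustration; adjust to wherever ReplicatedTensor lands
from torch.distributed._shard.replicated_tensor import ReplicatedTensor

assert dist.is_initialized()  # the sketch assumes init_process_group() ran on every rank

local = torch.ones(3, 3) * 4
replica = ReplicatedTensor(local)

# ReplicatedTensor + ReplicatedTensor -> ReplicatedTensor
out1 = replica + ReplicatedTensor(torch.ones(3, 3))
assert isinstance(out1, ReplicatedTensor)

# ReplicatedTensor + torch.Tensor -> plain torch.Tensor
out2 = replica + torch.ones(3, 3)
assert isinstance(out2, torch.Tensor) and not isinstance(out2, ReplicatedTensor)

# check that the value really is identical on every rank
assert replica.validate()
```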


facebook-github-bot commented Feb 28, 2022


💊 CI failures summary and remediations

As of commit 64f70b1 (more details on the Dr. CI page):


  • 1/1 failures introduced in this PR

🕵️ 1 new failure recognized by patterns

The following CI failures do not appear to be due to upstream breakages:

See GitHub Actions build pull / pytorch-xla-linux-bionic-py3.7-clang8 / test (xla, 1, 1, linux.2xlarge) (1/1)

Step: "Test" (full log | diagnosis details | 🔁 rerun)

2022-03-24T00:21:40.0074595Z RuntimeError: /var...or_util.cpp:1109 : Type not supported: ComplexHalf
2022-03-24T00:21:40.0069735Z ----------------------------------------------------------------------
2022-03-24T00:21:40.0070114Z Traceback (most recent call last):
2022-03-24T00:21:40.0070759Z   File "/opt/conda/lib/python3.7/site-packages/torch/testing/_internal/common_device_type.py", line 389, in instantiated_test
2022-03-24T00:21:40.0071275Z     raise rte
2022-03-24T00:21:40.0071867Z   File "/opt/conda/lib/python3.7/site-packages/torch/testing/_internal/common_device_type.py", line 376, in instantiated_test
2022-03-24T00:21:40.0072302Z     result = test(self, **param_kwargs)
2022-03-24T00:21:40.0072686Z   File "/var/lib/jenkins/workspace/xla/test/../../test/test_torch.py", line 5198, in test_copy_
2022-03-24T00:21:40.0073073Z     src = make_tensor_wrapper((50,), dtype=src_dtype)
2022-03-24T00:21:40.0073621Z   File "/var/lib/jenkins/workspace/xla/test/../../test/test_torch.py", line 5193, in make_tensor_wrapper
2022-03-24T00:21:40.0074066Z     return torch.randn(shape, device=device, dtype=dtype)
2022-03-24T00:21:40.0074595Z RuntimeError: /var/lib/jenkins/workspace/xla/torch_xla/csrc/tensor_util.cpp:1109 : Type not supported: ComplexHalf
2022-03-24T00:21:40.0074957Z 
2022-03-24T00:21:40.1030718Z ----------------------------------------------------------------------
2022-03-24T00:21:40.1031172Z Ran 594 tests in 558.694s
2022-03-24T00:21:40.1031293Z 
2022-03-24T00:21:40.1031416Z FAILED (errors=9, skipped=367, expected failures=27)
2022-03-24T00:21:40.1031565Z 
2022-03-24T00:21:40.1031651Z Generating XML reports...
2022-03-24T00:21:40.1032121Z Generated XML report: test-reports/python-unittest/test.......test.test_torch/TEST-TestTorchDeviceTypeXLA-20220324001221.xml
2022-03-24T00:21:40.4660003Z + cleanup
2022-03-24T00:21:40.4660338Z + retcode=1

This comment was automatically generated by Dr. CI.

Please report bugs/suggestions to the (internal) Dr. CI Users group.



pytorch-bot bot commented Feb 28, 2022

CI Flow Status

⚛️ CI Flow

Ruleset - Version: v1
Ruleset - File: https://github.com/pytorch/pytorch/blob/89bbf1737460e72ceaf04c13f6d958d3244863e8/.github/generated-ciflow-ruleset.json
PR ciflow labels: ciflow/default
Add ciflow labels to this PR to trigger more builds:

Workflows Labels (bold enabled) Status
Triggered Workflows
linux-binary-conda ciflow/binaries, ciflow/binaries_conda, ciflow/default ✅ triggered
linux-binary-libtorch-cxx11-abi ciflow/binaries, ciflow/binaries_libtorch, ciflow/default ✅ triggered
linux-binary-libtorch-pre-cxx11 ciflow/binaries, ciflow/binaries_libtorch, ciflow/default ✅ triggered
linux-binary-manywheel ciflow/binaries, ciflow/binaries_wheel, ciflow/default ✅ triggered
linux-bionic-py3.7-clang9 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/noarch, ciflow/trunk ✅ triggered
linux-bionic-rocm4.5-py3.7 ciflow/all, ciflow/default, ciflow/linux, ciflow/rocm, ciflow/trunk ✅ triggered
linux-docs ciflow/all, ciflow/cpu, ciflow/default, ciflow/docs, ciflow/linux, ciflow/trunk ✅ triggered
linux-vulkan-bionic-py3.7-clang9 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk, ciflow/vulkan ✅ triggered
linux-xenial-cuda11.3-py3.7-gcc7 ciflow/all, ciflow/cuda, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
linux-xenial-cuda11.3-py3.7-gcc7-bazel-test ciflow/all, ciflow/bazel, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
linux-xenial-py3-clang5-mobile-build ciflow/all, ciflow/default, ciflow/linux, ciflow/mobile, ciflow/trunk ✅ triggered
linux-xenial-py3-clang5-mobile-custom-build-static ciflow/all, ciflow/default, ciflow/linux, ciflow/mobile, ciflow/trunk ✅ triggered
linux-xenial-py3.7-clang7-asan ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/sanitizers, ciflow/trunk ✅ triggered
linux-xenial-py3.7-clang7-onnx ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/onnx, ciflow/trunk ✅ triggered
linux-xenial-py3.7-gcc5.4 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
linux-xenial-py3.7-gcc7 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
linux-xenial-py3.7-gcc7-no-ops ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
macos-arm64-binary-conda ciflow/binaries, ciflow/binaries_conda, ciflow/default ✅ triggered
macos-arm64-binary-wheel ciflow/binaries, ciflow/binaries_wheel, ciflow/default ✅ triggered
macos-binary-conda ciflow/binaries, ciflow/binaries_conda, ciflow/default ✅ triggered
macos-binary-libtorch-cxx11-abi ciflow/binaries, ciflow/binaries_libtorch, ciflow/default ✅ triggered
macos-binary-libtorch-pre-cxx11 ciflow/binaries, ciflow/binaries_libtorch, ciflow/default ✅ triggered
macos-binary-wheel ciflow/binaries, ciflow/binaries_wheel, ciflow/default ✅ triggered
pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single ciflow/all, ciflow/android, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single-full-jit ciflow/all, ciflow/android, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
win-vs2019-cpu-py3 ciflow/all, ciflow/cpu, ciflow/default, ciflow/trunk, ciflow/win ✅ triggered
win-vs2019-cuda11.3-py3 ciflow/all, ciflow/cuda, ciflow/default, ciflow/trunk, ciflow/win ✅ triggered
windows-binary-libtorch-cxx11-abi ciflow/binaries, ciflow/binaries_libtorch, ciflow/default ✅ triggered
windows-binary-libtorch-pre-cxx11 ciflow/binaries, ciflow/binaries_libtorch, ciflow/default ✅ triggered
windows-binary-wheel ciflow/binaries, ciflow/binaries_wheel, ciflow/default ✅ triggered
Skipped Workflows
caffe2-linux-xenial-py3.7-gcc5.4 ciflow/all, ciflow/cpu, ciflow/linux, ciflow/trunk 🚫 skipped
docker-builds ciflow/all, ciflow/trunk 🚫 skipped
ios-12-5-1-arm64 ciflow/all, ciflow/ios, ciflow/macos, ciflow/scheduled 🚫 skipped
ios-12-5-1-arm64-coreml ciflow/all, ciflow/ios, ciflow/macos, ciflow/scheduled 🚫 skipped
ios-12-5-1-arm64-custom-ops ciflow/all, ciflow/ios, ciflow/macos, ciflow/scheduled 🚫 skipped
ios-12-5-1-arm64-metal ciflow/all, ciflow/ios, ciflow/macos, ciflow/scheduled 🚫 skipped
ios-12-5-1-x86-64 ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk 🚫 skipped
ios-12-5-1-x86-64-coreml ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk 🚫 skipped
libtorch-linux-xenial-cuda10.2-py3.7-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/trunk 🚫 skipped
libtorch-linux-xenial-cuda11.3-py3.7-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/trunk 🚫 skipped
linux-bionic-cuda10.2-py3.9-gcc7 ciflow/all, ciflow/cuda, ciflow/linux, ciflow/slow, ciflow/trunk 🚫 skipped
linux-docs-push ciflow/all, ciflow/cpu, ciflow/linux, ciflow/scheduled 🚫 skipped
linux-xenial-cuda11.3-py3.7-gcc7-no-ops ciflow/all, ciflow/cuda, ciflow/linux, ciflow/trunk 🚫 skipped
macos-10-15-py3-arm64 ciflow/all, ciflow/macos, ciflow/trunk 🚫 skipped
macos-10-15-py3-lite-interpreter-x86-64 ciflow/all, ciflow/macos, ciflow/trunk 🚫 skipped
macos-11-py3-x86-64 ciflow/all, ciflow/macos, ciflow/trunk 🚫 skipped
parallelnative-linux-xenial-py3.7-gcc5.4 ciflow/all, ciflow/cpu, ciflow/linux, ciflow/trunk 🚫 skipped
periodic-libtorch-linux-bionic-cuda11.5-py3.7-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/scheduled 🚫 skipped
periodic-linux-bionic-cuda11.5-py3.7-gcc7 ciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled 🚫 skipped
periodic-linux-xenial-cuda10.2-py3-gcc7-slow-gradcheck ciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled, ciflow/slow, ciflow/slow-gradcheck 🚫 skipped
periodic-linux-xenial-cuda11.3-py3.7-gcc7-debug ciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled 🚫 skipped
periodic-win-vs2019-cuda11.5-py3 ciflow/all, ciflow/cuda, ciflow/scheduled, ciflow/win 🚫 skipped
pytorch-linux-xenial-py3-clang5-android-ndk-r19c-build ciflow/all, ciflow/android, ciflow/cpu, ciflow/linux, ciflow/trunk 🚫 skipped
pytorch-xla-linux-bionic-py3.7-clang8 ciflow/all, ciflow/cpu, ciflow/linux, ciflow/trunk, ciflow/xla 🚫 skipped

@facebook-github-bot facebook-github-bot added the cla signed and oncall: distributed (Add this issue/PR to distributed oncall triage queue) labels Feb 28, 2022
wanchaol pushed a commit that referenced this pull request Feb 28, 2022
@wanchaol wanchaol requested a review from fduwjj February 28, 2022 20:04

@fduwjj fduwjj left a comment

It looks great and thanks for working on this!


class ReplicatedTensor(torch.Tensor):
    """
    ReplicatedTensor represents a tensor which is replicated across the world_size and
Contributor

Nit: world_size?

inter-op rules defined as (using torch.add as an example op):
ReplicatedTensor + ReplicatedTensor = ReplicatedTensor
ReplicatedTensor + torch.Tensor = torch.Tensor
ReplicatedTensor + ShardedTensor = ShardedTensor
Contributor

Can we also add one for _PartialTensor?

Collaborator Author

I thought about it when adding the comment, but I decided to leave _PartialTensor out because:

  1. It's not a public API yet.
  2. Ideally PartialTensor is not something users need to be aware of or worry about; it's more like an intermediate result handled by our internal system. So I'm worried that if we add the comment here, users might be a bit confused and feel they need to learn what a PartialTensor is.

Let me know if that makes sense or not :)

@fduwjj fduwjj Mar 16, 2022

Got it, makes sense. We can always change it later on.

return f"ReplicatedTensor({super(ReplicatedTensor, self).__repr__()})"

@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
Contributor

Just one n00b question here: with this override, we can enable the comparison in the tests, right? (Like self.assertEqual?)

Collaborator Author

Do you mean comparison between ReplicatedTensor and Tensor, or ShardedTensor? Since ReplicatedTensor is a subclass of Tensor, it should work well with Tensor; I have to check whether it works with ShardedTensor and will try adding some tests there.

Collaborator Author

Just checked: there's already assertEqual support between Tensor and ReplicatedTensor and it works, but ReplicatedTensor vs. ShardedTensor is not yet supported; we'd need to add handling logic to binary_cmp. In theory, I guess a ShardedTensor will never equal a ReplicatedTensor since they have different topologies? We might need to define the rule for this here. cc @pritamdamania

    # based on the inter-op rules we defined.
    with torch._C.DisableTorchFunction():
        rs = func(*new_args, **new_kwargs)
    if func in get_default_nowrap_functions():
Contributor

For learning purposes: does this typically mean situations like field access, e.g. t.grad?

Collaborator Author

Yeah, it's more related to field access. This was copied from Tensor.__torch_function__; we don't want to go into that default handling because we manage our own output types through the rules we defined.
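
For context, a condensed sketch of the __torch_function__ pattern being discussed (a hypothetical minimal subclass, not this PR's exact code; the real ReplicatedTensor applies its inter-op rules where the comment indicates):

```python
import torch
from torch.overrides import get_default_nowrap_functions

class _SketchTensor(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # run the op with torch_function handling disabled so we don't recurse
        with torch._C.DisableTorchFunction():
            rs = func(*args, **kwargs)
        if func in get_default_nowrap_functions():
            # attribute-style accesses such as Tensor.grad: return the raw
            # result instead of rewrapping it into the subclass
            return rs
        # ...the real ReplicatedTensor decides the output type here based on
        # its inter-op rules (Replicated/Sharded/plain Tensor)...
        return rs
```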


def _replicate_tensor(tensor: torch.Tensor) -> ReplicatedTensor:
    """
    Given a :class:`torch.Tensor`, mark it as a ReplicatedTensore where all
Contributor

nit: s/ReplicatedTensore/ReplicatedTensor

setattr(module, param_name, st)


def _replicate_tensor(tensor: torch.Tensor) -> ReplicatedTensor:
Contributor

Do we really need an API like this? Users can just call ReplicatedTensor(tensor)?

@wanchaol wanchaol Mar 16, 2022

My original intention in introducing this API was to make sure we provide consistent APIs to the user. We now have:

  1. shard_module(module,plan)
  2. shard_parameter(module, param_name, spec)
  3. _shard_tensor(tensor, spec)

All three of these APIs are used to mark a param or tensor as a ShardedTensor, so I feel we should have a similar API to mark a tensor as a ReplicatedTensor; it makes the API surface more consistent from the user's perspective. Let me know if this makes sense, or whether we could just use ReplicatedTensor(tensor).

Comment on lines +21 to +25
NOTE: We do not guarantee equal content of ReplicatedTensor across nodes after its
construction. Although we defined proper inter-op rules to make sure ReplicatedTensor
stays the same, there's no enforcement on it (i.e. if you manually modify content on
some ranks, the modified value will not automatically get synced to other nodes). If
you wish to manually validate tensors are the same across ranks, use `validate()`.
Contributor

Not for this PR, but I think we need two modes for ReplicatedTensor. The first mode is what we have here where it is just a tag to help sharded computations, but probably this should not be the default mode.

I think the default mode should be similar to DDP, where ReplicatedTensor broadcasts the torch.Tensor from rank 0 (the source rank could probably also be specified optionally). Then, in the backward pass for this mode, we always allreduce the gradients for the ReplicatedTensor. This means ReplicatedTensor can stand on its own.
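
As a rough illustration of that suggested default mode (not part of this PR; the helper name and signature are made up), the construction-time broadcast could look like:

```python
import torch
import torch.distributed as dist

def _broadcast_then_replicate(tensor: torch.Tensor, src: int = 0, process_group=None) -> torch.Tensor:
    # make every rank hold rank `src`'s value before treating the tensor as replicated
    dist.broadcast(tensor, src=src, group=process_group)
    return tensor  # a real implementation would wrap this in ReplicatedTensor
```

The gradient allreduce in the backward pass would be handled separately, e.g. via autograd hooks, as DDP does.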

    you wish to manually validate tensors are the same across ranks, use `validate()`.
    """
    def __new__(cls, data=None):
Contributor

ReplicatedTensor should take an optional process_group indicating the replication environment.

Collaborator Author

Yeah, I agree we should tie ReplicatedTensor to a replication env (a pg), but I didn't do this initially because it requires the Tensor to hold metadata as a member, and torch.Tensor._make_subclass doesn't work that way. Let me see if I can change it to _make_wrapper_subclass and define __torch_dispatch__ together with __torch_function__.

Collaborator Author

I actually found a way to propagate the field even with _make_subclass, so there's no need to define __torch_dispatch__ yet; just updated the PR.
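
A minimal sketch of that approach (illustrative; class name, signature, and defaulting are assumptions, not this PR's exact code):

```python
import torch

class _ReplicatedTensorSketch(torch.Tensor):
    def __new__(cls, data=None, process_group=None):
        if data is None:
            data = torch.empty(0)
        # _make_subclass reuses `data`'s storage and only changes the Python type
        r = torch.Tensor._make_subclass(cls, data, data.requires_grad)
        # a plain instance attribute is enough to carry the replication environment
        r.process_group = process_group
        return r
```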

    # back to tensor subclasses, where in our case, we need to control the output type
    # based on the inter-op rules we defined.
    with torch._C.DisableTorchFunction():
        rs = func(*new_args, **new_kwargs)
Contributor

Why do we need new_args and new_kwargs here? Can't we just pass in args and kwargs?


return rs

def validate(self, process_group=None) -> bool:
Contributor

process_group should be passed in during construction of ReplicatedTensor.
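
For context, a rough sketch of what such a `validate()` check can do under the hood (assumed, not necessarily this PR's implementation): gather every rank's copy and compare it against the local one.

```python
import torch
import torch.distributed as dist

def validate_replication(tensor: torch.Tensor, process_group=None) -> bool:
    world_size = dist.get_world_size(process_group)
    gathered = [torch.empty_like(tensor) for _ in range(world_size)]
    dist.all_gather(gathered, tensor, group=process_group)
    # truly replicated means every rank's copy matches the local value
    return all(torch.equal(tensor, other) for other in gathered)
```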

        rs = func(*new_args, **new_kwargs)
    if func in get_default_nowrap_functions():
        return rs
    if not has_tensor and isinstance(rs, torch.Tensor) and not isinstance(rs, cls):
Contributor

If we are adding two ReplicatedTensors here, shouldn't we validate they are on the same PG? I feel this check might be more clear if we assert all args are ReplicatedTensor and only in that case we return a ReplicatedTensor using rs.as_subclass. In all other cases, we return rs.

# validate it's a replicated tensor by checking values on all rank
validated = replica_tensor.validate()
self.assertEqual(validated, True)
self.assertEqual(replica_tensor + 2, torch.ones(3, 3) * 6)
Contributor

nit: validate type of replica_tensor + 2

Collaborator Author

I guess the type of replica_tensor + 2 should become a plain Tensor instead of a ReplicatedTensor.

Comment on lines 40 to 46
replica_tensor1 = ReplicatedTensor(local_tensor * 4)
replica_tensor2 = ReplicatedTensor(local_tensor * 6)

new_tensor = replica_tensor1 * replica_tensor2
self.assertTrue(isinstance(new_tensor, ReplicatedTensor))
self.assertEqual(new_tensor, torch.ones(3, 3) * 24)

Contributor

This should work only if the PGs are the same.

Comment on lines +50 to +53
def test_replicated_tensor_inter_op_tensor(self):
local_tensor = torch.ones(3, 3, device=f"cuda:{self.rank}") * 4
replica_tensor = ReplicatedTensor(local_tensor)

Contributor

Are these two interops coming in follow up PRs?

  1. ShardedTensor + ReplicatedTensor
  2. PartialTensor + ReplicatedTensor

Collaborator Author

Yep, ShardedTensor + ReplicatedTensor is coming in the next PR; PartialTensor should be a follow-up PR as well.

@pritamdamania87 pritamdamania87 left a comment

Looks good! Just a minor comment regarding some refactoring.

Comment on lines 75 to 88
    if isinstance(v, ShardedTensor):
        # redispatch to ShardedTensor
        # TODO: handle ShardedTensor/PartialTensor inter-op with ReplicatedTensor
        return v.__torch_function__(func, types, args, kwargs)
    if isinstance(v, ReplicatedTensor):
        if replicated_pg is None:
            replicated_pg = v.process_group
        elif replicated_pg != v.process_group:
            raise RuntimeError(
                f"ReplicatedTensor operands must be in the same process group "
                f"in torch function '{func.__name__}', but found at least two "
                f"ReplicatedTensor operands in different process groups! ")
    else:
        new_kwargs[k] = v
        all_replicated = False
Contributor

nit: repeated code between args and kwargs, maybe create a simple inline helper function and dedup this.
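
A loose sketch of the suggested helper (not this PR's code; the import path is an assumption), so positional and keyword operands can share one code path for the process-group check:

```python
import torch
# assumed import path for illustration
from torch.distributed._shard.replicated_tensor import ReplicatedTensor

def _common_replicated_pg(operands, replicated_pg=None):
    # walk positional and keyword operands once, enforcing that every
    # ReplicatedTensor operand belongs to the same process group
    for v in operands:
        if isinstance(v, ReplicatedTensor):
            if replicated_pg is None:
                replicated_pg = v.process_group
            elif replicated_pg != v.process_group:
                raise RuntimeError(
                    "ReplicatedTensor operands must be in the same process group")
    return replicated_pg

# inside __torch_function__ it could then be called once:
#   replicated_pg = _common_replicated_pg(list(args) + list(kwargs.values()))
```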

wanchaol added 2 commits March 23, 2022 12:14
facebook-github-bot pushed a commit that referenced this pull request Mar 24, 2022
Summary:
Pull Request resolved: #73529

ghstack-source-id: 152064781

Test Plan: test_replicated_tensor

Reviewed By: pritamdamania87, fduwjj

Differential Revision: D34529374

fbshipit-source-id: 16ccb300e9f9c47ac29a17eb6d46d029ab7d60b8
@github-actions

Hey @wanchaol.
You've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.
For changes that are 'topic: not user facing' there is no need for a release notes label.

shahofblah pushed a commit that referenced this pull request Mar 25, 2022
(cherry picked from commit 44f4e11)
@facebook-github-bot facebook-github-bot deleted the gh/wanchaol/204/head branch March 27, 2022 14:17

Labels

cla signed, oncall: distributed (Add this issue/PR to distributed oncall triage queue), release notes: distributed (sharded) (release notes category), sharded_tensor

6 participants