[shard] Add ReplicatedTensor #73529
Conversation
Add ReplicatedTensor. A ReplicatedTensor is a type of tensor that has the same value on all ranks across the world_size.

ReplicatedTensor is a :class:`~torch.Tensor` subclass, and it can be used together with ShardedTensor/Tensor to express different types of computation. The inter-op rules are defined as follows (using torch.add as an example op):

ReplicatedTensor + ReplicatedTensor = ReplicatedTensor
ReplicatedTensor + torch.Tensor = torch.Tensor
ReplicatedTensor + ShardedTensor = ShardedTensor

We also added a `validate()` API to help users verify whether a replicated tensor on a certain process_group is truly replicated.

TODO: the next PR will add ShardedTensor/PartialTensor logic to handle ReplicatedTensor.

Differential Revision: [D34529374](https://our.internmc.facebook.com/intern/diff/D34529374/)

**NOTE FOR REVIEWERS**: This PR has internal Facebook-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D34529374/)!
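A minimal sketch of how these rules behave in practice (the import path is assumed from this PR's directory layout, and the snippet presumes torch.distributed is already initialized on every rank):

```python
import torch
from torch.distributed._shard.replicated_tensor import ReplicatedTensor

# assumes torch.distributed.init_process_group(...) has already run on every rank
local = torch.ones(3, 3)
rt = ReplicatedTensor(local * 4)

# ReplicatedTensor + ReplicatedTensor = ReplicatedTensor
out = rt + ReplicatedTensor(local)
assert isinstance(out, ReplicatedTensor)

# ReplicatedTensor + torch.Tensor = torch.Tensor
out = rt + torch.ones(3, 3)
assert isinstance(out, torch.Tensor) and not isinstance(out, ReplicatedTensor)

# validate() checks the tensor really is identical on every rank
assert rt.validate()
```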
fduwjj left a comment:
It looks great and thanks for working on this!
```python
class ReplicatedTensor(torch.Tensor):
    """
    ReplicatedTensor represents a tensor which is replicated across the world_size and
```
Nit: world_size?
```python
    inter-op rules defined as (using torch.add as an example op):
        ReplicatedTensor + ReplicatedTensor = ReplicatedTensor
        ReplicatedTensor + torch.Tensor = torch.Tensor
        ReplicatedTensor + ShardedTensor = ShardedTensor
```
Can we also add one for _PartialTensor?
I thought about it when adding the comment, but I decided to leave _PartialTensor out because:
- It's not a public API yet.
- Ideally PartialTensor is not something users need to be aware of or worry about; it's more like an intermediate result handled by our internal system. So I'm worried that if we add the comment here, users might be a bit confused and need to learn what a PartialTensor is.

Let me know if that makes sense or not :)
Got it, makes sense. We can always change it later on.
| return f"ReplicatedTensor({super(ReplicatedTensor, self).__repr__()})" | ||
|
|
||
| @classmethod | ||
| def __torch_function__(cls, func, types, args=(), kwargs=None): |
Just one n00b question here: with this override, we can enable comparison in the tests, right? (Like self.assertEqual?)
Do you mean comparison between ReplicatedTensor and Tensor, or ShardedTensor? Since ReplicatedTensor is a subclass of Tensor, it should work well with Tensor; I have to check whether it works with ShardedTensor, and will try adding some tests there.
Just checked: there's already assertEqual between Tensor/ReplicatedTensor and it's working, but ReplicatedTensor/ShardedTensor is not yet supported; we'd need to add handling logic to binary_cmp. In theory, though, I guess a ShardedTensor will never equal a ReplicatedTensor since they have different topologies? We might need to define the rule for this here. cc @pritamdamania
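For reference, a hedged sketch of what that binary_cmp handling could look like (hypothetical, not this PR's code; it assumes ShardMetadata exposes shard_offsets/shard_sizes):

```python
import torch
import torch.distributed as dist

def _shard_view(replicated, md):
    # narrow the replicated value down to the region this shard covers
    view = replicated
    for dim, (off, sz) in enumerate(zip(md.shard_offsets, md.shard_sizes)):
        view = view.narrow(dim, off, sz)
    return view

def sharded_equals_replicated(st, rt, process_group=None):
    # each rank compares its own local shards, then all ranks agree via MIN
    local_ok = all(
        torch.equal(s.tensor, _shard_view(rt, s.metadata))
        for s in st.local_shards()
    )
    # for a NCCL backend this flag tensor would need to live on the rank's GPU
    verdict = torch.tensor([int(local_ok)])
    dist.all_reduce(verdict, op=dist.ReduceOp.MIN, group=process_group)
    return bool(verdict.item())
```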
```python
        # based on the inter-op rules we defined.
        with torch._C.DisableTorchFunction():
            rs = func(*new_args, **new_kwargs)
            if func in get_default_nowrap_functions():
```
For learning purposes: does this often mean the situation where it's a field access like t.grad?
Yeah, it's mostly related to field access. This was copied from Tensor.__torch_function__, because we don't want to go into its default re-wrapping path as we manage our own output type through the rules we defined.
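For reference, the set of "nowrap" functions is mostly field accessors (a quick check, assuming a recent PyTorch; the set's exact contents are version-dependent):

```python
import torch
from torch.overrides import get_default_nowrap_functions

# Accessors like Tensor.grad are "nowrap": their results are returned
# as-is instead of being re-wrapped into the calling subclass.
print(torch.Tensor.grad.__get__ in get_default_nowrap_functions())  # expected: True
```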
torch/distributed/_shard/api.py
Outdated
```python
def _replicate_tensor(tensor: torch.Tensor) -> ReplicatedTensor:
    """
    Given a :class:`torch.Tensor`, mark it as a ReplicatedTensore where all
```
nit: s/ReplicatedTensore/ReplicatedTensor
torch/distributed/_shard/api.py
Outdated
```python
    setattr(module, param_name, st)


def _replicate_tensor(tensor: torch.Tensor) -> ReplicatedTensor:
```
Do we really need an API like this? Users can just call ReplicatedTensor(tensor)?
My original intention in introducing this API was to keep the APIs we provide to users consistent. We now have:
- shard_module(module, plan)
- shard_parameter(module, param_name, spec)
- _shard_tensor(tensor, spec)

All three of these APIs are used to mark a param or tensor as a ShardedTensor, so I feel we should have a similar API to mark a tensor as a ReplicatedTensor; it makes the API surface more consistent from the user's perspective. Let me know if this makes sense, or whether we should just use ReplicatedTensor(tensor).
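A hedged sketch of the resulting API symmetry (_replicate_tensor follows this PR; the sharding spec and import paths are assumptions for illustration):

```python
import torch
from torch.distributed._shard.api import _shard_tensor, _replicate_tensor
from torch.distributed._shard.sharding_spec import ChunkShardingSpec

spec = ChunkShardingSpec(
    dim=0,
    placements=["rank:0/cuda:0", "rank:1/cuda:1"],
)

# both entry points "mark" a plain tensor with its distributed semantics
st = _shard_tensor(torch.ones(8, 4, device="cuda"), spec)  # -> ShardedTensor
rt = _replicate_tensor(torch.ones(8, 4, device="cuda"))    # -> ReplicatedTensor
```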
```python
    NOTE: We do not guarantee equal content of ReplicatedTensor across nodes after its
    construction. Although we defined proper inter-op rules to make sure ReplicatedTensor
    stays the same, there's no enforcement on it (i.e. if you manually modify content on
    some ranks, the modified value will not automatically get synced to other nodes). If
    you wish to manually validate tensors are the same across ranks, use `validate()`.
```
Not for this PR, but I think we need two modes for ReplicatedTensor. The first mode is what we have here, where it is just a tag to help sharded computations, but this probably should not be the default mode.
I think the default mode should be similar to DDP, where ReplicatedTensor broadcasts the torch.Tensor from rank 0 (the source rank could probably also be optionally specified). Then, in the backward pass for this mode, we always allreduce the gradients for the ReplicatedTensor. This means ReplicatedTensor can stand on its own.
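A rough sketch of that proposed default mode (purely illustrative; the helper name is hypothetical and torch.distributed is assumed initialized):

```python
import torch
import torch.distributed as dist

def make_synced_replicated(tensor, src_rank=0, process_group=None):
    # broadcast from src_rank so every rank starts from identical values,
    # mirroring DDP's parameter broadcast at construction time
    with torch.no_grad():
        dist.broadcast(tensor, src=src_rank, group=process_group)

    if tensor.requires_grad:
        def _sync_grad(grad):
            # allreduce + average gradients in backward, as DDP does,
            # so the replicas stay in lockstep after optimizer steps
            dist.all_reduce(grad, group=process_group)
            return grad / dist.get_world_size(group=process_group)

        tensor.register_hook(_sync_grad)
    return tensor
```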
```python
    you wish to manually validate tensors are the same across ranks, use `validate()`.
    """
    def __new__(cls, data=None):
```
ReplicatedTensor should take an optional process_group indicating the replication environment.
Yeah, I agree we should tie ReplicatedTensor to a replication environment (a pg), but I didn't do this initially because it requires the Tensor to hold metadata as a member, and torch.Tensor._make_subclass does not work that way. Let me see if I can change it to _make_wrapper_subclass and define __torch_dispatch__ together with __torch_function__.
I actually found a way to propagate the field even with _make_subclass, so there's no need to define __torch_dispatch__ yet. Just updated the PR.
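A minimal sketch of that approach (assumptions: the field name and the default-group helper; the landed code may differ in details):

```python
import torch
import torch.distributed as dist

class ReplicatedTensor(torch.Tensor):
    def __new__(cls, data=None, process_group=None):
        if data is None:
            data = torch.empty(0)
        # _make_subclass shares data's storage; the extra field lives on
        # the Python-level wrapper, so no __torch_dispatch__ is required yet
        r = torch.Tensor._make_subclass(cls, data, data.requires_grad)
        r.process_group = (
            process_group
            if process_group is not None
            else dist.distributed_c10d._get_default_group()  # internal helper
        )
        return r
```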
```python
        # back to tensor subclasses, where in our case, we need to control the output type
        # based on the inter-op rules we defined.
        with torch._C.DisableTorchFunction():
            rs = func(*new_args, **new_kwargs)
```
Why do we need new_args and new_kwargs here? Can't we just pass in args and kwargs?
```python
            return rs

    def validate(self, process_group=None) -> bool:
```
process_group should be passed in during construction of ReplicatedTensor.
```python
            rs = func(*new_args, **new_kwargs)
            if func in get_default_nowrap_functions():
                return rs
            if not has_tensor and isinstance(rs, torch.Tensor) and not isinstance(rs, cls):
```
If we are adding two ReplicatedTensors here, shouldn't we validate that they are on the same PG? I feel this check might be clearer if we assert that all args are ReplicatedTensors and only in that case return a ReplicatedTensor using rs.as_subclass. In all other cases, we return rs.
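Roughly (a sketch of the suggested shape, written as a fragment of the reviewed __torch_function__, not the PR's code):

```python
            # re-wrap only when every tensor operand was a ReplicatedTensor
            # on the same process group; otherwise hand back the plain result
            if all_replicated and isinstance(rs, torch.Tensor) and not isinstance(rs, cls):
                rs = rs.as_subclass(cls)
            return rs
```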
```python
        # validate it's a replicated tensor by checking values on all ranks
        validated = replica_tensor.validate()
        self.assertEqual(validated, True)
        self.assertEqual(replica_tensor + 2, torch.ones(3, 3) * 6)
```
nit: validate type of replica_tensor + 2
I guess the type of replica_tensor + 2 should become a plain tensor instead of a replicated tensor.
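Something like the following could cover the nit (a hypothetical addition to the quoted test; which output type is correct depends on the scalar-operand rule the PR settles on):

```python
        result = replica_tensor + 2
        self.assertEqual(result, torch.ones(3, 3) * 6)
        # pin down the output type, not just the values
        self.assertFalse(isinstance(result, ReplicatedTensor))
```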
```python
        replica_tensor1 = ReplicatedTensor(local_tensor * 4)
        replica_tensor2 = ReplicatedTensor(local_tensor * 6)

        new_tensor = replica_tensor1 * replica_tensor2
        self.assertTrue(isinstance(new_tensor, ReplicatedTensor))
        self.assertEqual(new_tensor, torch.ones(3, 3) * 24)
```
This should work only if the PGs are the same.
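For instance (a hypothetical test fragment illustrating the point; the new_group calls are assumptions):

```python
        pg_a = dist.new_group(ranks=[0, 1])
        pg_b = dist.new_group(ranks=[0, 1])  # distinct group over the same ranks
        t1 = ReplicatedTensor(local_tensor, process_group=pg_a)
        t2 = ReplicatedTensor(local_tensor, process_group=pg_b)
        with self.assertRaisesRegex(RuntimeError, "same process group"):
            t1 * t2
```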
```python
    def test_replicated_tensor_inter_op_tensor(self):
        local_tensor = torch.ones(3, 3, device=f"cuda:{self.rank}") * 4
        replica_tensor = ReplicatedTensor(local_tensor)
```
Are these two interops coming in follow-up PRs?
- ShardedTensor + ReplicatedTensor
- PartialTensor + ReplicatedTensor
Yep, ShardedTensor + ReplicatedTensor is coming in the next PR; PartialTensor will be a follow-up PR as well.
pritamdamania87 left a comment:
Looks good! Just a minor comment regarding some refactoring.
```python
            if isinstance(v, ShardedTensor):
                # redispatch to ShardedTensor
                # TODO: handle ShardedTensor/PartialTensor inter-op with ReplicatedTensor
                return v.__torch_function__(func, types, args, kwargs)
            if isinstance(v, ReplicatedTensor):
                if replicated_pg is None:
                    replicated_pg = v.process_group
                elif replicated_pg != v.process_group:
                    raise RuntimeError(
                        f"ReplicatedTensor operands must be in the same process group "
                        f"in torch function '{func.__name__}', but found at least two "
                        f"ReplicatedTensor operands in different process groups! ")
            else:
                new_kwargs[k] = v
                all_replicated = False
```
nit: repeated code between args and kwargs, maybe create a simple inline helper function and dedup this.
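One possible shape for that dedup (a hypothetical sketch; all names are illustrative, and the ShardedTensor redispatch is elided):

```python
def _inspect_operand(v, state):
    # shared per-operand logic, callable from both the args and kwargs loops
    if isinstance(v, ReplicatedTensor):
        if state["pg"] is None:
            state["pg"] = v.process_group
        elif state["pg"] != v.process_group:
            raise RuntimeError(
                "ReplicatedTensor operands must be in the same process group")
    else:
        state["all_replicated"] = False
    return v

state = {"pg": None, "all_replicated": True}
new_args = [_inspect_operand(a, state) for a in args]
new_kwargs = {k: _inspect_operand(v, state) for k, v in (kwargs or {}).items()}
```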
Summary: Pull Request resolved: #73529 (see the PR description above).
Test Plan: test_replicated_tensor
Reviewed By: pritamdamania87, fduwjj
Differential Revision: D34529374
ghstack-source-id: 152064781
fbshipit-source-id: 16ccb300e9f9c47ac29a17eb6d46d029ab7d60b8
(cherry picked from commit 44f4e11)
Hey @wanchaol. |