Add torch.nn.init.uniform_ operator to ShardedTensor. #63997
Closed
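This PR makes `torch.nn.init.uniform_` work on a `ShardedTensor`: the call is intercepted via the `__torch_function__` mechanism (per the "use torch function to extend" commit) and every local shard is filled in place. A minimal usage sketch, assuming a 4-rank NCCL process group is already initialized; the spec mirrors the one used in the test below:

```python
import torch
from torch.distributed import _sharded_tensor
from torch.distributed._sharding_spec import ChunkShardingSpec

# Chunk the tensor along dim 0 across four GPUs (one shard per rank).
spec = ChunkShardingSpec(
    dim=0,
    placements=[
        "rank:0/cuda:0",
        "rank:1/cuda:1",
        "rank:2/cuda:2",
        "rank:3/cuda:3",
    ],
)
st = _sharded_tensor.empty(spec, 8, 2)

# After this PR, the stock initializer dispatches to the sharded handler
# and fills each rank's local shard with values drawn from U(10, 20).
torch.nn.init.uniform_(st, a=10, b=20)
```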
Commits (24, all by bowangbj):
- 4d9e704 Introduce _sharded_tensor.[normal_, uniform_, kaiming_uniform_] utils…
- 2de1837 revert nn.init on "Introduce _sharded_tensor.[normal_, uniform_, kaim…
- 549c2c9 revert nn.init for real on "Introduce _sharded_tensor.[normal_, unifo…
- 7c2f831 test on "Introduce _sharded_tensor.[normal_, uniform_, kaiming_unifor…
- 7828a25 test on "Introduce _sharded_tensor.[normal_, uniform_, kaiming_unifor…
- fb5d785 use torch function to extend on "Add torch.nn.init.uniform_ operator …
- ab1b249 SYNC head and resolve conflict on "Add torch.nn.init.uniform_ operato…
- 591901d lint on "Add torch.nn.init.uniform_ operator to ShardedTensor."
- 5a8fc94 fix test on "Add torch.nn.init.uniform_ operator to ShardedTensor."
- 6cbcc7f resolve pritam comment on "Add torch.nn.init.uniform_ operator to Sha…
- 7473742 revert test_sharded_tensor on "Add torch.nn.init.uniform_ operator to…
- 830dd80 extra empty line on "Add torch.nn.init.uniform_ operator to ShardedTe…
- f0172ea Update on "Add torch.nn.init.uniform_ operator to ShardedTensor."
- eebb704 block test_init for windows on "Add torch.nn.init.uniform_ operator t…
- 1351095 block test_init for windows on "Add torch.nn.init.uniform_ operator t…
- 9a8993a resolve pritam comment on "Add torch.nn.init.uniform_ operator to Sha…
- 853402d lint error on "Add torch.nn.init.uniform_ operator to ShardedTensor."
- 6f3f7ea rename validate_params on "Add torch.nn.init.uniform_ operator to Sha…
- 44e92ac lint on "Add torch.nn.init.uniform_ operator to ShardedTensor."
- 320000c lint on "Add torch.nn.init.uniform_ operator to ShardedTensor."
- e1c17e4 final update on "Add torch.nn.init.uniform_ operator to ShardedTensor."
- cfb8301 final update on "Add torch.nn.init.uniform_ operator to ShardedTensor."
- f18df46 resolve conflict on "Add torch.nn.init.uniform_ operator to ShardedTe…
- a570b13 fix linear on "Add torch.nn.init.uniform_ operator to ShardedTensor."
New test file (65 added lines; the commit messages suggest it is the `test_init` module):

```python
import sys
import torch

from torch.distributed import _sharded_tensor
from torch.distributed._sharding_spec import (
    ChunkShardingSpec,
)
from torch.testing._internal.common_distributed import (
    requires_nccl,
    skip_if_lt_x_gpu,
)
from torch.testing._internal.distributed._sharded_tensor import (
    ShardedTensorTestBase,
    with_comms,
)
from torch.testing._internal.common_utils import (
    TEST_WITH_DEV_DBG_ASAN,
    run_tests,
)


if TEST_WITH_DEV_DBG_ASAN:
    print("Skip dev-asan as torch + multiprocessing spawn have known issues", file=sys.stderr)
    sys.exit(0)


class TestShardedTensorNNInit(ShardedTensorTestBase):
    """ Testing torch.nn.init functions for ShardedTensor """

    @with_comms
    @skip_if_lt_x_gpu(4)
    @requires_nccl()
    def test_init_sharded_tensor_with_uniform(self):
        """ Test torch.nn.init.uniform_(ShardedTensor, a, b) """

        spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0/cuda:0",
                "rank:1/cuda:1",
                "rank:2/cuda:2",
                "rank:3/cuda:3",
            ],
        )
        h, w = 8, 2
        expected_h = 2
        expected_device = torch.device(f"cuda:{self.rank}")
        a, b = 10, 20

        seed = 1234
        dtype = torch.double

        sharded_tensor = _sharded_tensor.empty(spec, h, w, dtype=dtype)
        self.assertEqual(1, len(sharded_tensor.local_shards()))

        # Clone local tensor to ensure torch.nn.init starts from the same input
        local_tensor_clone = torch.clone(sharded_tensor.local_shards()[0].tensor)
        torch.manual_seed(seed)
        torch.nn.init.uniform_(sharded_tensor, a=a, b=b)

        torch.manual_seed(seed)
        torch.nn.init.uniform_(local_tensor_clone, a=a, b=b)
        self.assertEqual(local_tensor_clone, sharded_tensor.local_shards()[0].tensor)


if __name__ == '__main__':
    run_tests()
```
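The test's correctness check is determinism-based: both initializations start from the same RNG seed, so the values the sharded path writes into each rank's 2×2 local shard must equal the values a plain dense `uniform_` writes into an identically sized clone of that shard. This works because each rank holds exactly one shard, so both paths consume the same number of random draws. The test requires 4 GPUs and NCCL, and the module exits early under dev-asan because of the known torch + multiprocessing spawn issues.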
The ops package `__init__.py` grows from 2 lines to 3 (`@@ -1,2 +1,3 @@`); `from .init import uniform_` is the added export:

```python
from .init import uniform_
from .linear import sharded_linear
from .embedding import sharded_embedding
```
New file (26 added lines): the `init.py` module that the `__init__.py` above imports `uniform_` from. It applies the standard dense initializer shard by shard:

```python
import torch


def validate_param(param, param_name):
    if param is None:
        raise ValueError(f"param: {param_name} shouldn't be None!")


def uniform_(types, args=(), kwargs=None):
    r"""
    Fills the Tensor in sharded_tensor.local_shards with values drawn from the uniform
    distribution :math:`\mathcal{U}(a, b)`.

    Args:
        sharded_tensor: tensor sharded across devices
        a: the lower bound of the uniform distribution
        b: the upper bound of the uniform distribution
    """
    validate_param(kwargs, "kwargs")
    sharded_tensor = kwargs["tensor"]
    validate_param(sharded_tensor, "sharded_tensor")
    a = kwargs['a']
    validate_param(a, "a")
    b = kwargs['b']
    validate_param(b, "b")

    for shard in sharded_tensor.local_shards():
        torch.nn.init.uniform_(shard.tensor, a=a, b=b)
    return sharded_tensor
```
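Note the handler's signature is `(types, args, kwargs)` and it reads the tensor itself out of `kwargs["tensor"]`, which implies the `torch.nn.init.uniform_` entry point forwards all of its arguments by keyword through the `__torch_function__` protocol. A hedged sketch of that plumbing; the registry name `_SHARDED_OPS` and the class skeleton are illustrative stand-ins, not the PR's actual identifiers:

```python
import torch

# Hypothetical registry mapping an overridable torch function to its sharded
# implementation -- here, the uniform_ handler defined in init.py above.
_SHARDED_OPS = {torch.nn.init.uniform_: uniform_}


class ShardedTensor:
    # ... construction, metadata, local_shards() elided ...

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if func in _SHARDED_OPS:
            # Hand the call to the sharded implementation. Since the handler
            # pulls "tensor", "a" and "b" out of kwargs, the caller side must
            # forward them by keyword, roughly:
            #   torch.overrides.handle_torch_function(
            #       uniform_, (tensor,), tensor=tensor, a=a, b=b)
            return _SHARDED_OPS[func](types, args, kwargs)
        raise RuntimeError(
            f"torch function '{func.__name__}' not supported for ShardedTensor!")
```

Dispatching through a dict keyed on the torch function keeps the per-op handlers (linear, embedding, and now init) decoupled from the tensor class: supporting another initializer would just mean another registry entry.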