Closed

24 commits
4d9e704
Introduce _sharded_tensor.[normal_, uniform_, kaiming_uniform_] utils…
bowangbj Aug 26, 2021
2de1837
revert nn.init on "Introduce _sharded_tensor.[normal_, uniform_, kaim…
bowangbj Aug 26, 2021
549c2c9
revert nn.init for real on "Introduce _sharded_tensor.[normal_, unifo…
bowangbj Aug 26, 2021
7c2f831
test on "Introduce _sharded_tensor.[normal_, uniform_, kaiming_unifor…
bowangbj Aug 26, 2021
7828a25
test on "Introduce _sharded_tensor.[normal_, uniform_, kaiming_unifor…
bowangbj Aug 26, 2021
fb5d785
use torch function to extend on "Add torch.nn.init.uniform_ operator …
bowangbj Oct 8, 2021
ab1b249
SYNC head and resolve conflict on "Add torch.nn.init.uniform_ operato…
bowangbj Oct 8, 2021
591901d
lint on "Add torch.nn.init.uniform_ operator to ShardedTensor."
bowangbj Oct 8, 2021
5a8fc94
fix test on "Add torch.nn.init.uniform_ operator to ShardedTensor."
bowangbj Oct 8, 2021
6cbcc7f
resolve pritam comment on "Add torch.nn.init.uniform_ operator to Sha…
bowangbj Oct 11, 2021
7473742
revert test_sharded_tensor on "Add torch.nn.init.uniform_ operator to…
bowangbj Oct 11, 2021
830dd80
extra empty line on "Add torch.nn.init.uniform_ operator to ShardedTe…
bowangbj Oct 11, 2021
f0172ea
Update on "Add torch.nn.init.uniform_ operator to ShardedTensor."
bowangbj Oct 11, 2021
eebb704
block test_init for windows on "Add torch.nn.init.uniform_ operator t…
bowangbj Oct 11, 2021
1351095
block test_init for windows on "Add torch.nn.init.uniform_ operator t…
bowangbj Oct 12, 2021
9a8993a
resolve pritam comment on "Add torch.nn.init.uniform_ operator to Sha…
bowangbj Oct 12, 2021
853402d
lint error on "Add torch.nn.init.uniform_ operator to ShardedTensor."
bowangbj Oct 12, 2021
6f3f7ea
rename validate_params on "Add torch.nn.init.uniform_ operator to Sha…
bowangbj Oct 20, 2021
44e92ac
lint on "Add torch.nn.init.uniform_ operator to ShardedTensor."
bowangbj Oct 20, 2021
320000c
lint on "Add torch.nn.init.uniform_ operator to ShardedTensor."
bowangbj Oct 20, 2021
e1c17e4
final update on "Add torch.nn.init.uniform_ operator to ShardedTensor."
bowangbj Oct 21, 2021
cfb8301
final update on "Add torch.nn.init.uniform_ operator to ShardedTensor."
bowangbj Oct 21, 2021
f18df46
resolve conflict on "Add torch.nn.init.uniform_ operator to ShardedTe…
bowangbj Oct 21, 2021
a570b13
fix linear on "Add torch.nn.init.uniform_ operator to ShardedTensor."
bowangbj Oct 21, 2021
65 changes: 65 additions & 0 deletions test/distributed/_sharded_tensor/ops/test_init.py
@@ -0,0 +1,65 @@
import sys
import torch

from torch.distributed import _sharded_tensor
from torch.distributed._sharding_spec import (
    ChunkShardingSpec,
)
from torch.testing._internal.common_distributed import (
    requires_nccl,
    skip_if_lt_x_gpu,
)
from torch.testing._internal.distributed._sharded_tensor import (
    ShardedTensorTestBase,
    with_comms,
)
from torch.testing._internal.common_utils import (
    TEST_WITH_DEV_DBG_ASAN,
    run_tests,
)

if TEST_WITH_DEV_DBG_ASAN:
    print("Skip dev-asan as torch + multiprocessing spawn have known issues", file=sys.stderr)
    sys.exit(0)

class TestShardedTensorNNInit(ShardedTensorTestBase):
    """ Testing torch.nn.init functions for ShardedTensor """

    @with_comms
    @skip_if_lt_x_gpu(4)
    @requires_nccl()
    def test_init_sharded_tensor_with_uniform(self):
        """ Test torch.nn.init.uniform_(ShardedTensor, a, b) """

        spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0/cuda:0",
                "rank:1/cuda:1",
                "rank:2/cuda:2",
                "rank:3/cuda:3",
            ],
        )
        h, w = 8, 2
        expected_h = 2
        expected_device = torch.device(f"cuda:{self.rank}")
        a, b = 10, 20

        seed = 1234
        dtype = torch.double

        sharded_tensor = _sharded_tensor.empty(spec, h, w, dtype=dtype)
        self.assertEqual(1, len(sharded_tensor.local_shards()))

        # Clone local tensor to ensure torch.nn.init starts from the same input
        local_tensor_clone = torch.clone(sharded_tensor.local_shards()[0].tensor)
        torch.manual_seed(seed)
        torch.nn.init.uniform_(sharded_tensor, a=a, b=b)

        torch.manual_seed(seed)
        torch.nn.init.uniform_(local_tensor_clone, a=a, b=b)
        self.assertEqual(local_tensor_clone, sharded_tensor.local_shards()[0].tensor)


if __name__ == '__main__':
    run_tests()
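A note on the comparison above: it relies on torch.manual_seed making two identically-seeded uniform_ fills produce identical values, so the shard initialized through the ShardedTensor path can be checked against a same-shaped dense clone initialized directly. A minimal standalone sketch of that determinism assumption (editor's illustration, not part of this PR):

import torch

t1 = torch.empty(2, 2, dtype=torch.double)
t2 = torch.empty(2, 2, dtype=torch.double)

# Re-seeding before each fill reproduces the same sequence of uniform samples.
torch.manual_seed(1234)
torch.nn.init.uniform_(t1, a=10, b=20)

torch.manual_seed(1234)
torch.nn.init.uniform_(t2, a=10, b=20)

assert torch.equal(t1, t2)
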
4 changes: 4 additions & 0 deletions test/distributed/_sharded_tensor/ops/test_linear.py
@@ -12,6 +12,7 @@
)
from torch.testing._internal.common_utils import (
    TEST_WITH_DEV_DBG_ASAN,
    run_tests,
)
from torch.testing._internal.distributed._sharded_tensor import (
    ShardedTensorTestBase,
@@ -85,3 +86,6 @@ def test_sharded_linear_rowwise(self):
        # Test uneven split.
        self._run_sharded_linear(spec, [5, 19], [19, 11], 1)
        self._run_sharded_linear(spec, [5, 21], [21, 11], 1)

if __name__ == '__main__':
    run_tests()
3 changes: 3 additions & 0 deletions test/run_test.py
@@ -199,6 +199,7 @@ def skip_test_p(name: str) -> bool:
    "distributed/elastic/multiprocessing/api_test",
    "distributed/_sharded_tensor/test_sharded_tensor",
    "distributed/_sharded_tensor/ops/test_embedding",
    "distributed/_sharded_tensor/ops/test_init",
    "distributed/_sharded_tensor/ops/test_linear",
] + FSDP_TEST

@@ -209,6 +210,7 @@ def skip_test_p(name: str) -> bool:
    "distributed/rpc/cuda/test_tensorpipe_agent",
    "distributed/_sharded_tensor/test_sharded_tensor",
    "distributed/_sharded_tensor/ops/test_embedding",
    "distributed/_sharded_tensor/ops/test_init",
    "distributed/_sharded_tensor/ops/test_linear",
    "test_determination",
    "test_multiprocessing",
@@ -345,6 +347,7 @@ def skip_test_p(name: str) -> bool:
    "distributed/_sharding_spec/test_sharding_spec",
    "distributed/_sharded_tensor/test_sharded_tensor",
    "distributed/_sharded_tensor/ops/test_embedding",
    "distributed/_sharded_tensor/ops/test_init",
    "distributed/_sharded_tensor/ops/test_linear",
] + [test for test in TESTS if test.startswith("distributed/fsdp")]

5 changes: 3 additions & 2 deletions torch/distributed/_sharded_tensor/api.py
@@ -27,7 +27,7 @@
    get_chunked_dim_size,
)
from torch.types import Number
from .ops import sharded_embedding, sharded_linear
from .ops import sharded_embedding, sharded_linear, uniform_

# Tracking for sharded tensor objects.
_sharded_tensor_lock = threading.Lock()
@@ -638,7 +638,8 @@ def __torch_function__(self, func, types, args=(), kwargs=None):
            return sharded_linear(types, args, kwargs, self._process_group)
        if func == torch.nn.functional.embedding:
            return sharded_embedding(types, args, kwargs, self._process_group)

        elif func == torch.nn.init.uniform_:
            return uniform_(types, args, kwargs)
        raise RuntimeError(
            f"torch function '{func.__name__}', with args: {args} and "
            f"kwargs: {kwargs} not supported for ShardedTensor!")
1 change: 1 addition & 0 deletions torch/distributed/_sharded_tensor/ops/__init__.py
@@ -1,2 +1,3 @@
from .init import uniform_
from .linear import sharded_linear
from .embedding import sharded_embedding
26 changes: 26 additions & 0 deletions torch/distributed/_sharded_tensor/ops/init.py
@@ -0,0 +1,26 @@
import torch

def validate_param(param, param_name):
    if param is None:
        raise ValueError(f"param: {param_name} shouldn't be None!")

def uniform_(types, args=(), kwargs=None):
    r"""
    Fills the Tensor in sharded_tensor.local_shards with values drawn from the uniform
    distribution :math:`\mathcal{U}(a, b)`.
    Args:
        sharded_tensor: tensor sharded across devices
        a: the lower bound of the uniform distribution
        b: the upper bound of the uniform distribution
    """
    validate_param(kwargs, "kwargs")
    sharded_tensor = kwargs["tensor"]
    validate_param(sharded_tensor, "sharded_tensor")
    a = kwargs['a']
    validate_param(a, "a")
    b = kwargs['b']
    validate_param(b, "b")

    for shard in sharded_tensor.local_shards():
        torch.nn.init.uniform_(shard.tensor, a=a, b=b)
    return sharded_tensor
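Combined with the dispatch hunk in api.py above, this makes initializing a ShardedTensor look like initializing a dense tensor. A hedged usage sketch mirroring the test's setup (assumes a 4-rank NCCL process group with one GPU per rank is already initialized, as ShardedTensorTestBase arranges in test_init.py):

import torch
from torch.distributed import _sharded_tensor
from torch.distributed._sharding_spec import ChunkShardingSpec

# One chunk of dim 0 per rank, as in the test above.
spec = ChunkShardingSpec(
    dim=0,
    placements=[
        "rank:0/cuda:0",
        "rank:1/cuda:1",
        "rank:2/cuda:2",
        "rank:3/cuda:3",
    ],
)

st = _sharded_tensor.empty(spec, 8, 2)
torch.nn.init.uniform_(st, a=10, b=20)  # fills each rank's local shard in place
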
5 changes: 5 additions & 0 deletions torch/nn/init.py
@@ -4,6 +4,9 @@
from torch import Tensor
import torch

from ..overrides import (
    has_torch_function_variadic,
    handle_torch_function)

# These no_grad_* functions are necessary as wrappers around the parts of these
# functions that use `with torch.no_grad()`. The JIT doesn't support context
@@ -132,6 +135,8 @@ def uniform_(tensor: Tensor, a: float = 0., b: float = 1.) -> Tensor:
        >>> w = torch.empty(3, 5)
        >>> nn.init.uniform_(w)
    """
    if has_torch_function_variadic(tensor, a, b):
        return handle_torch_function(uniform_, (tensor, a, b), tensor=tensor, a=a, b=b)
    return _no_grad_uniform_(tensor, a, b)

