
Commit 4582ceb

IvanKobzarev authored and pytorchmergebot committed
[distributed][sharded_tensor] Move local_shards check from ShardedTensorBase to ShardedTensor (#100197)
Differential Revision: [D45369211](https://our.internmc.facebook.com/intern/diff/D45369211)
Pull Request resolved: #100197
Approved by: https://github.com/fduwjj
1 parent 8556cf2 commit 4582ceb
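In effect, ShardedTensorBase._init_from_local_shards_and_global_metadata no longer performs the per-shard validation (layout, size, device, dtype, pin_memory, requires_grad, contiguity); those checks now live in the ShardedTensor override, so the base class can be assembled without a default process group. A minimal sketch of what this enables, mirroring the new test added below (this snippet is illustrative only and is not part of the commit):

# Build a ShardedTensorBase from local shards + global metadata without any
# process group initialized. Mirrors TestCreateTensorNoProcessGroupMode below.
import torch
from torch.distributed._shard.sharded_tensor import (
    Shard,
    ShardedTensorBase,
    ShardedTensorMetadata,
)
from torch.distributed._shard.sharding_spec import ShardMetadata

st_metadata = ShardedTensorMetadata(
    shards_metadata=[
        ShardMetadata(shard_offsets=[0, 0], shard_sizes=[2, 2], placement="rank:0/cpu"),
        ShardMetadata(shard_offsets=[2, 0], shard_sizes=[2, 2], placement="rank:1/cpu"),
    ],
    size=torch.Size([4, 2]),
)
local_shards = [
    Shard(
        tensor=torch.zeros(md.shard_sizes, device=md.placement.device()),
        metadata=md,
    )
    for md in st_metadata.shards_metadata
]

# Succeeds without torch.distributed being initialized; the stricter per-shard
# checks now run only in ShardedTensor._init_from_local_shards_and_global_metadata.
st_base = ShardedTensorBase._init_from_local_shards_and_global_metadata(
    local_shards=local_shards,
    sharded_tensor_metadata=st_metadata,
)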

File tree (2 files changed: +156, -67 lines):

  test/distributed/_shard/sharded_tensor/test_sharded_tensor.py
  torch/distributed/_shard/sharded_tensor/api.py

test/distributed/_shard/sharded_tensor/test_sharded_tensor.py (61 additions, 0 deletions)
@@ -23,6 +23,8 @@
     pre_load_state_dict_hook,
     state_dict_hook,
     ShardedTensor,
+    ShardedTensorBase,
+    ShardedTensorMetadata,
     Shard
 )
 from torch.distributed._shard.sharding_spec import (
@@ -2776,5 +2778,64 @@ def test_create_shard_with_no_placement(self):
         shard = Shard(torch.zeros(10), md)
         self.assertIsNone(shard.metadata.placement)

+class TestCreateTensorNoProcessGroupMode(TestCase):
+    def test_init_from_local_shards_and_global_metadata(self):
+        st_metadata: ShardedTensorMetadata = ShardedTensorMetadata(
+            shards_metadata=[
+                ShardMetadata(
+                    shard_offsets=[0, 0], shard_sizes=[2, 2], placement="rank:0/cpu"
+                ),
+                ShardMetadata(
+                    shard_offsets=[2, 0], shard_sizes=[2, 2], placement="rank:1/cpu"
+                ),
+            ],
+            size=torch.Size([4, 2]),
+        )
+        st_local_shards: List[Shard] = []
+        for shard_metadata in st_metadata.shards_metadata:
+            st_local_shards.append(
+                Shard(
+                    tensor=torch.zeros(
+                        shard_metadata.shard_sizes,
+                        device=shard_metadata.placement.device(),
+                    ),
+                    metadata=shard_metadata,
+                )
+            )
+
+        ShardedTensorBase._init_from_local_shards_and_global_metadata(
+            local_shards=st_local_shards,
+            sharded_tensor_metadata=st_metadata,
+        )
+
+    def test_non_contiguous_local_shards(self):
+        st_metadata: ShardedTensorMetadata = ShardedTensorMetadata(
+            shards_metadata=[
+                ShardMetadata(
+                    shard_offsets=[0, 0], shard_sizes=[2, 2], placement="rank:0/cpu"
+                ),
+                ShardMetadata(
+                    shard_offsets=[2, 0], shard_sizes=[2, 2], placement="rank:1/cpu"
+                ),
+            ],
+            size=torch.Size([4, 2]),
+        )
+        st_local_shards: List[Shard] = []
+        src = torch.randn(4, 2)
+        for shard_metadata in st_metadata.shards_metadata:
+            offsets = shard_metadata.shard_offsets
+            sizes = shard_metadata.shard_sizes
+            st_local_shards.append(
+                Shard(
+                    tensor=src[offsets[0]:offsets[0] + sizes[0], offsets[1]:offsets[1] + sizes[1]],
+                    metadata=shard_metadata,
+                )
+            )
+
+        ShardedTensorBase._init_from_local_shards_and_global_metadata(
+            local_shards=st_local_shards,
+            sharded_tensor_metadata=st_metadata,
+        )
+
 if __name__ == '__main__':
     run_tests()

torch/distributed/_shard/sharded_tensor/api.py (95 additions, 67 deletions)
@@ -136,7 +136,7 @@ def _init_from_local_shards_and_global_metadata(
         local_shards: List[Shard],
         sharded_tensor_metadata: ShardedTensorMetadata,
         sharding_spec=None,
-    ) -> "ShardedTensor":
+    ) -> "ShardedTensorBase":
         """
         Initialize a ShardedTensorBase with local shards and a global
         ShardedTensorMetadata built on each rank.
@@ -158,7 +158,7 @@ def _init_from_local_shards_and_global_metadata(
         else:
             spec = sharding_spec

-        sharded_tensor_base = ShardedTensor.__new__(
+        sharded_tensor_base = ShardedTensorBase.__new__(
             ShardedTensor,
             spec,
             sharded_tensor_metadata.size,
@@ -168,67 +168,6 @@ def _init_from_local_shards_and_global_metadata(
             requires_grad=tensor_properties.requires_grad,
         )

-        def _raise_if_mismatch(expected, actual, prop_name, rank, is_property=False):
-            tensor_property_or_metadata = (
-                "tensor property" if is_property else "local ShardMetadata"
-            )
-            if expected != actual:
-                raise ValueError(
-                    f"Local shards' tensor {prop_name} property is incompatible with "
-                    f"{tensor_property_or_metadata} on rank {rank}: "
-                    f"{tensor_property_or_metadata} {prop_name}={expected}, "
-                    f"local shard tensor {prop_name}={actual}."
-                )
-
-        for shard in local_shards:
-            shard_meta = shard.metadata
-            local_shard_tensor = shard.tensor
-            placement = shard_meta.placement
-            assert placement is not None, "Must specify placement for `Shard`!"
-            rank = placement.rank()
-            local_device = placement.device()
-
-            _raise_if_mismatch(
-                tensor_properties.layout,
-                local_shard_tensor.layout,
-                "layout",
-                rank,
-                True,
-            )
-            if not local_shard_tensor.is_contiguous():
-                raise ValueError(
-                    "Only torch.contiguous_format memory_format is currently supported"
-                )
-
-            _raise_if_mismatch(
-                shard_meta.shard_sizes,
-                list(local_shard_tensor.size()),
-                "size",
-                rank,
-            )
-            _raise_if_mismatch(
-                tensor_properties.pin_memory,
-                local_shard_tensor.is_pinned(),
-                "pin_memory",
-                rank,
-                True,
-            )
-            _raise_if_mismatch(local_device, local_shard_tensor.device, "device", rank)
-            _raise_if_mismatch(
-                tensor_properties.dtype,
-                local_shard_tensor.dtype,
-                "dtype",
-                rank,
-                True,
-            )
-            _raise_if_mismatch(
-                tensor_properties.requires_grad,
-                local_shard_tensor.requires_grad,
-                "requires_grad",
-                rank,
-                True,
-            )
-
         # check if shards_metadata have overlap shards
         validate_non_overlapping_shards_metadata(shards_metadata)

@@ -925,11 +864,100 @@ def _init_from_local_shards_and_global_metadata( # type: ignore[override]
                 f'shards metadata in sharded_tensor_metadata ({len(local_shard_metadatas)}) '
                 f'on rank ({current_rank}) '
             )
-        sharded_tensor = super(
-            ShardedTensor, cls
-        )._init_from_local_shards_and_global_metadata(
-            local_shards, sharded_tensor_metadata, sharding_spec=sharding_spec
+
+        shards_metadata = sharded_tensor_metadata.shards_metadata
+        tensor_properties = sharded_tensor_metadata.tensor_properties
+
+        if len(shards_metadata) == 0:
+            raise ValueError("shards_metadata must not be empty!")
+
+        if tensor_properties.layout != torch.strided:
+            raise ValueError("Only torch.strided layout is currently supported")
+
+        if sharding_spec is None:
+            spec = shard_spec._infer_sharding_spec_from_shards_metadata(shards_metadata)
+        else:
+            spec = sharding_spec
+
+        sharded_tensor = ShardedTensor.__new__(
+            ShardedTensor,
+            spec,
+            sharded_tensor_metadata.size,
+            dtype=tensor_properties.dtype,
+            layout=tensor_properties.layout,
+            pin_memory=tensor_properties.pin_memory,
+            requires_grad=tensor_properties.requires_grad,
         )
+
+        def _raise_if_mismatch(expected, actual, prop_name, rank, is_property=False):
+            tensor_property_or_metadata = (
+                "tensor property" if is_property else "local ShardMetadata"
+            )
+            if expected != actual:
+                raise ValueError(
+                    f"Local shards' tensor {prop_name} property is incompatible with "
+                    f"{tensor_property_or_metadata} on rank {rank}: "
+                    f"{tensor_property_or_metadata} {prop_name}={expected}, "
+                    f"local shard tensor {prop_name}={actual}."
+                )
+
+        for shard in local_shards:
+            shard_meta = shard.metadata
+            local_shard_tensor = shard.tensor
+            placement = shard_meta.placement
+            assert placement is not None, "Must specify placement for `Shard`!"
+            rank = placement.rank()
+            local_device = placement.device()
+
+            _raise_if_mismatch(
+                tensor_properties.layout,
+                local_shard_tensor.layout,
+                "layout",
+                rank,
+                True,
+            )
+            if not local_shard_tensor.is_contiguous():
+                raise ValueError(
+                    "Only torch.contiguous_format memory_format is currently supported"
+                )
+
+            _raise_if_mismatch(
+                shard_meta.shard_sizes,
+                list(local_shard_tensor.size()),
+                "size",
+                rank,
+            )
+            _raise_if_mismatch(
+                tensor_properties.pin_memory,
+                local_shard_tensor.is_pinned(),
+                "pin_memory",
+                rank,
+                True,
+            )
+            _raise_if_mismatch(local_device, local_shard_tensor.device, "device", rank)
+            _raise_if_mismatch(
+                tensor_properties.dtype,
+                local_shard_tensor.dtype,
+                "dtype",
+                rank,
+                True,
+            )
+            _raise_if_mismatch(
+                tensor_properties.requires_grad,
+                local_shard_tensor.requires_grad,
+                "requires_grad",
+                rank,
+                True,
+            )
+
+        # check if shards_metadata have overlap shards
+        validate_non_overlapping_shards_metadata(shards_metadata)
+
+        # check if the shards_metadata is compatible with overall size of the sharded tensor.
+        check_tensor(shards_metadata, list(sharded_tensor_metadata.size))
+
+        # done validation, add local_shards
+        sharded_tensor._local_shards = local_shards
         sharded_tensor._prepare_init(process_group=process_group, init_rrefs=init_rrefs)

         # run post initialization, i.e. map registration, rpc initialization