@@ -136,7 +136,7 @@ def _init_from_local_shards_and_global_metadata(
         local_shards: List[Shard],
         sharded_tensor_metadata: ShardedTensorMetadata,
         sharding_spec=None,
-    ) -> "ShardedTensor":
+    ) -> "ShardedTensorBase":
         """
         Initialize a ShardedTensorBase with local shards and a global
         ShardedTensorMetadata built on each rank.
@@ -158,7 +158,7 @@ def _init_from_local_shards_and_global_metadata(
         else:
             spec = sharding_spec

-        sharded_tensor_base = ShardedTensor.__new__(
+        sharded_tensor_base = ShardedTensorBase.__new__(
             ShardedTensor,
             spec,
             sharded_tensor_metadata.size,
@@ -168,67 +168,6 @@ def _init_from_local_shards_and_global_metadata(
             requires_grad=tensor_properties.requires_grad,
         )

-        def _raise_if_mismatch(expected, actual, prop_name, rank, is_property=False):
-            tensor_property_or_metadata = (
-                "tensor property" if is_property else "local ShardMetadata"
-            )
-            if expected != actual:
-                raise ValueError(
-                    f"Local shards' tensor {prop_name} property is incompatible with "
-                    f"{tensor_property_or_metadata} on rank {rank}: "
-                    f"{tensor_property_or_metadata} {prop_name}={expected}, "
-                    f"local shard tensor {prop_name}={actual}."
-                )
-
-        for shard in local_shards:
-            shard_meta = shard.metadata
-            local_shard_tensor = shard.tensor
-            placement = shard_meta.placement
-            assert placement is not None, "Must specify placement for `Shard`!"
-            rank = placement.rank()
-            local_device = placement.device()
-
-            _raise_if_mismatch(
-                tensor_properties.layout,
-                local_shard_tensor.layout,
-                "layout",
-                rank,
-                True,
-            )
-            if not local_shard_tensor.is_contiguous():
-                raise ValueError(
-                    "Only torch.contiguous_format memory_format is currently supported"
-                )
-
-            _raise_if_mismatch(
-                shard_meta.shard_sizes,
-                list(local_shard_tensor.size()),
-                "size",
-                rank,
-            )
-            _raise_if_mismatch(
-                tensor_properties.pin_memory,
-                local_shard_tensor.is_pinned(),
-                "pin_memory",
-                rank,
-                True,
-            )
-            _raise_if_mismatch(local_device, local_shard_tensor.device, "device", rank)
-            _raise_if_mismatch(
-                tensor_properties.dtype,
-                local_shard_tensor.dtype,
-                "dtype",
-                rank,
-                True,
-            )
-            _raise_if_mismatch(
-                tensor_properties.requires_grad,
-                local_shard_tensor.requires_grad,
-                "requires_grad",
-                rank,
-                True,
-            )
-
         # check if shards_metadata have overlap shards
         validate_non_overlapping_shards_metadata(shards_metadata)

@@ -925,11 +864,100 @@ def _init_from_local_shards_and_global_metadata( # type: ignore[override]
                 f'shards metadata in sharded_tensor_metadata ({len(local_shard_metadatas)}) '
                 f'on rank ({current_rank}) '
             )
-        sharded_tensor = super(
-            ShardedTensor, cls
-        )._init_from_local_shards_and_global_metadata(
-            local_shards, sharded_tensor_metadata, sharding_spec=sharding_spec
+
+        shards_metadata = sharded_tensor_metadata.shards_metadata
+        tensor_properties = sharded_tensor_metadata.tensor_properties
+
+        if len(shards_metadata) == 0:
+            raise ValueError("shards_metadata must not be empty!")
+
+        if tensor_properties.layout != torch.strided:
+            raise ValueError("Only torch.strided layout is currently supported")
+
+        if sharding_spec is None:
+            spec = shard_spec._infer_sharding_spec_from_shards_metadata(shards_metadata)
+        else:
+            spec = sharding_spec
+
+        sharded_tensor = ShardedTensor.__new__(
+            ShardedTensor,
+            spec,
+            sharded_tensor_metadata.size,
+            dtype=tensor_properties.dtype,
+            layout=tensor_properties.layout,
+            pin_memory=tensor_properties.pin_memory,
+            requires_grad=tensor_properties.requires_grad,
         )
+
+        def _raise_if_mismatch(expected, actual, prop_name, rank, is_property=False):
+            tensor_property_or_metadata = (
+                "tensor property" if is_property else "local ShardMetadata"
+            )
+            if expected != actual:
+                raise ValueError(
+                    f"Local shards' tensor {prop_name} property is incompatible with "
+                    f"{tensor_property_or_metadata} on rank {rank}: "
+                    f"{tensor_property_or_metadata} {prop_name}={expected}, "
+                    f"local shard tensor {prop_name}={actual}."
+                )
+
+        for shard in local_shards:
+            shard_meta = shard.metadata
+            local_shard_tensor = shard.tensor
+            placement = shard_meta.placement
+            assert placement is not None, "Must specify placement for `Shard`!"
+            rank = placement.rank()
+            local_device = placement.device()
+
+            _raise_if_mismatch(
+                tensor_properties.layout,
+                local_shard_tensor.layout,
+                "layout",
+                rank,
+                True,
+            )
+            if not local_shard_tensor.is_contiguous():
+                raise ValueError(
+                    "Only torch.contiguous_format memory_format is currently supported"
+                )
+
+            _raise_if_mismatch(
+                shard_meta.shard_sizes,
+                list(local_shard_tensor.size()),
+                "size",
+                rank,
+            )
+            _raise_if_mismatch(
+                tensor_properties.pin_memory,
+                local_shard_tensor.is_pinned(),
+                "pin_memory",
+                rank,
+                True,
+            )
+            _raise_if_mismatch(local_device, local_shard_tensor.device, "device", rank)
+            _raise_if_mismatch(
+                tensor_properties.dtype,
+                local_shard_tensor.dtype,
+                "dtype",
+                rank,
+                True,
+            )
+            _raise_if_mismatch(
+                tensor_properties.requires_grad,
+                local_shard_tensor.requires_grad,
+                "requires_grad",
+                rank,
+                True,
+            )
+
+        # check if shards_metadata have overlap shards
+        validate_non_overlapping_shards_metadata(shards_metadata)
+
+        # check if the shards_metadata is compatible with overall size of the sharded tensor.
+        check_tensor(shards_metadata, list(sharded_tensor_metadata.size))
+
+        # done validation, add local_shards
+        sharded_tensor._local_shards = local_shards
         sharded_tensor._prepare_init(process_group=process_group, init_rrefs=init_rrefs)

         # run post initialization, i.e. map registration, rpc initialization
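
The validation relocated by this patch enforces, for each local shard, that the shard tensor agrees with its ShardMetadata sizes and with the global TensorProperties (strided layout, contiguous memory format, dtype, pin_memory, requires_grad, device). As a rough standalone illustration of those invariants, here is a minimal sketch; check_shard and its parameters are hypothetical names invented for this example only and are not part of the patch or of the torch.distributed APIs (the device-vs-placement check is omitted, since it needs the distributed placement machinery).

import torch

# Hypothetical standalone illustration of the per-shard checks in the diff above:
# compare a local shard tensor against its expected sizes and the expected global
# tensor properties, raising on the first mismatch.
def check_shard(local_shard_tensor, expected_sizes, expected_dtype,
                expected_requires_grad=False, expected_pin_memory=False):
    if local_shard_tensor.layout != torch.strided:
        raise ValueError("Only torch.strided layout is currently supported")
    if not local_shard_tensor.is_contiguous():
        raise ValueError("Only torch.contiguous_format memory_format is currently supported")
    if list(local_shard_tensor.size()) != list(expected_sizes):
        raise ValueError(
            f"size mismatch: expected {expected_sizes}, got {list(local_shard_tensor.size())}"
        )
    if local_shard_tensor.dtype != expected_dtype:
        raise ValueError(f"dtype mismatch: expected {expected_dtype}, got {local_shard_tensor.dtype}")
    if local_shard_tensor.requires_grad != expected_requires_grad:
        raise ValueError("requires_grad mismatch")
    if local_shard_tensor.is_pinned() != expected_pin_memory:
        raise ValueError("pin_memory mismatch")

# A 4x8 float32 shard that satisfies the default expectations:
check_shard(torch.zeros(4, 8), [4, 8], torch.float32)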