1212)
1313from torch .distributed ._tensor .dispatch import OpSchema , OutputSharding
1414from torch .distributed ._tensor .ops .common_rules import einop_rule , pointwise_rule
15- from torch .distributed ._tensor .ops .utils import register_prop_rule
15+ from torch .distributed ._tensor .ops .utils import register_prop_rule , normalize_dim
1616
1717
1818# NOTE: the default propagation rule should apply for
@@ -161,15 +161,10 @@ def unshard_tensor_dim(
161161
162162
163163def is_tensor_dim_sharded (
164- placements : Sequence [ Placement ] , dim : int
164+ spec : DTensorSpec , dim : int
165165) -> bool :
166166 """Return True if tensor dim is sharded"""
167- return any (
168- tuple (
169- True if (isinstance (p , Shard ) and p .dim == dim ) else False
170- for p in placements
171- )
172- )
167+ return (dim < spec .ndim ) and spec .dim_map [dim ] >= 0
173168
174169
175170def _prop_all_but_dim (
@@ -497,17 +492,15 @@ def cat_rule(op_schema: OpSchema) -> OutputSharding:
497492
498493 dim = 0 # default dim = 0
499494 if (len (op_schema .args_schema ) > 1 ):
500- dim = op_schema .args_schema [1 ]
501- # normalize arguments
502- if dim < 0 :
503- dim += ndim
495+ dim = cast (int , op_schema .args_schema [1 ])
496+ dim = normalize_dim (dim , ndim )
504497
505498 # Unshard all input tensors on cat dim before running einop rule
506499 # to avoid _Partial in result.
507500 need_reshard = False
508501 tensor_list_specs_after = []
509502 for spec in tensor_list_specs :
510- if is_tensor_dim_sharded (spec . placements , dim = dim ):
503+ if is_tensor_dim_sharded (spec , dim = dim ):
511504 need_reshard = True
512505 tensor_list_specs_after .append (
513506 DTensorSpec (
@@ -572,7 +565,6 @@ def cat_rule(op_schema: OpSchema) -> OutputSharding:
572565 return _update_schema_suggestion_for_cat (
573566 output_sharding ,
574567 op_schema ,
575- dim ,
576568 )
577569 else :
578570 return output_sharding
@@ -594,7 +586,6 @@ def cat_rule(op_schema: OpSchema) -> OutputSharding:
594586def _update_schema_suggestion_for_cat (
595587 output_sharding : OutputSharding ,
596588 op_schema : OpSchema ,
597- dim : int ,
598589) -> OutputSharding :
599590 assert output_sharding .schema_suggestions is not None
600591 suggestion_specs = output_sharding .schema_suggestions [0 ].args_spec
0 commit comments