
Commit 3ea6b8c (2 parents: cb506a5 + 19d72c5)

Update on "[DTensor] implement dist_cat as a sharding prop rule"

[ghstack-poisoned]


torch/distributed/_tensor/ops/tensor_ops.py

Lines changed: 6 additions & 15 deletions
```diff
@@ -12,7 +12,7 @@
 )
 from torch.distributed._tensor.dispatch import OpSchema, OutputSharding
 from torch.distributed._tensor.ops.common_rules import einop_rule, pointwise_rule
-from torch.distributed._tensor.ops.utils import register_prop_rule
+from torch.distributed._tensor.ops.utils import register_prop_rule, normalize_dim


 # NOTE: the default propagation rule should apply for
```
```diff
@@ -161,15 +161,10 @@ def unshard_tensor_dim(


 def is_tensor_dim_sharded(
-    placements: Sequence[Placement], dim: int
+    spec: DTensorSpec, dim: int
 ) -> bool:
     """Return True if tensor dim is sharded"""
-    return any(
-        tuple(
-            True if (isinstance(p, Shard) and p.dim == dim) else False
-            for p in placements
-        )
-    )
+    return (dim < spec.ndim) and spec.dim_map[dim] >= 0


 def _prop_all_but_dim(
```
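The rewritten check relies on `DTensorSpec.dim_map`, where entry `i` is the mesh dimension that tensor dim `i` is sharded over, or `-1` if that dim is not sharded. A minimal standalone sketch of the same predicate, using a plain list in place of a real `DTensorSpec` (illustrative only, not the library API):

```python
from typing import List

def is_dim_sharded(dim_map: List[int], dim: int) -> bool:
    # dim_map[i] >= 0  -> tensor dim i is sharded over mesh dim dim_map[i]
    # dim_map[i] == -1 -> tensor dim i is not sharded (replicated)
    return dim < len(dim_map) and dim_map[dim] >= 0

# A 2-D tensor sharded on dim 0 over mesh dim 0 has dim_map == [0, -1]:
assert is_dim_sharded([0, -1], dim=0)
assert not is_dim_sharded([0, -1], dim=1)
```

Since `dim_map` already encodes the per-dim answer, this replaces the old scan over `placements` for matching `Shard` instances.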
```diff
@@ -497,17 +492,15 @@ def cat_rule(op_schema: OpSchema) -> OutputSharding:

     dim = 0  # default dim = 0
     if (len(op_schema.args_schema) > 1):
-        dim = op_schema.args_schema[1]
-        # normalize arguments
-        if dim < 0:
-            dim += ndim
+        dim = cast(int, op_schema.args_schema[1])
+        dim = normalize_dim(dim, ndim)

     # Unshard all input tensors on cat dim before running einop rule
     # to avoid _Partial in result.
     need_reshard = False
     tensor_list_specs_after = []
     for spec in tensor_list_specs:
-        if is_tensor_dim_sharded(spec.placements, dim=dim):
+        if is_tensor_dim_sharded(spec, dim=dim):
             need_reshard = True
             tensor_list_specs_after.append(
                 DTensorSpec(
```
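The open-coded negative-dim handling is replaced by the shared `normalize_dim` helper. Presumably it follows PyTorch's usual dim-wrapping convention; a minimal sketch of the assumed behavior (not the actual `torch.distributed._tensor.ops.utils` implementation):

```python
def normalize_dim(dim: int, ndim: int) -> int:
    # Wrap a negative dim into [0, ndim), mirroring how torch.cat
    # interprets dim=-1 as the last dimension.
    return dim + ndim if dim < 0 else dim

assert normalize_dim(-1, 4) == 3
assert normalize_dim(2, 4) == 2
```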
```diff
@@ -572,7 +565,6 @@ def cat_rule(op_schema: OpSchema) -> OutputSharding:
         return _update_schema_suggestion_for_cat(
             output_sharding,
             op_schema,
-            dim,
         )
     else:
         return output_sharding
```
```diff
@@ -594,7 +586,6 @@ def cat_rule(op_schema: OpSchema) -> OutputSharding:
 def _update_schema_suggestion_for_cat(
     output_sharding: OutputSharding,
     op_schema: OpSchema,
-    dim: int,
 ) -> OutputSharding:
     assert output_sharding.schema_suggestions is not None
     suggestion_specs = output_sharding.schema_suggestions[0].args_spec
```
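With the explicit `dim` parameter dropped from both the call site and the signature, the helper still receives `op_schema`, from which the cat dim is recoverable the same way `cat_rule` reads it: the second positional argument when present, else 0. A hypothetical sketch of that recovery (helper name and details assumed, not taken from this commit):

```python
from typing import cast

def _cat_dim(op_schema) -> int:
    # torch.cat(tensors, dim=0): dim is args_schema[1] when supplied,
    # and defaults to 0 otherwise -- the same convention cat_rule uses.
    if len(op_schema.args_schema) > 1:
        return cast(int, op_schema.args_schema[1])
    return 0
```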
