
Commit b95f493

[DTensor] implement dist_cat as a sharding prop rule
ghstack-source-id: d4cbf11
Pull Request resolved: #92677
1 parent: f6acd95
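
Context: this change moves torch.cat support for DTensor from a hand-written local implementation to a sharding propagation rule, so cat goes through the same propagation path as other ops. A minimal usage sketch of what the rule handles, assuming a 1-D mesh over 4 ranks and the private torch.distributed._tensor API as of this commit (names may differ in later releases):

# Hedged sketch, not part of the patch: concatenate two row-sharded DTensors.
# Assumes torch.distributed is already initialized with 4 ranks.
import torch
from torch.distributed._tensor import DeviceMesh, DTensor, Shard

mesh = DeviceMesh("cuda", list(range(4)))
placements = [Shard(0)]  # shard dim 0 across the mesh

a = DTensor.from_local(torch.rand(3, 5), mesh, placements)
b = DTensor.from_local(torch.rand(3, 5), mesh, placements)

# Concatenating along a non-sharded dim keeps the inputs' sharding; if the
# cat dim were the sharded dim, cat_rule instead returns a reshard suggestion.
c = torch.cat([a, b], dim=1)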

5 files changed: +109 / -33 lines

test/distributed/_tensor/test_dtensor_ops.py

Lines changed: 0 additions & 1 deletion
@@ -150,7 +150,6 @@ def wrapped(fn):
     xfail("diff"),
     xfail("dist"),
     xfail("dot"),
-    xfail("dstack"),
     xfail("einsum"),
     xfail("empty"),
     xfail("empty_like"),

test/distributed/_tensor/test_tensor_ops.py

Lines changed: 18 additions & 0 deletions
@@ -354,6 +354,24 @@ def test_index(self):
             torch.randint(5, (12, 8, 12)),
         )
 
+    @with_comms
+    def test_sharded_cat(self):
+        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+        torch.manual_seed(self.rank)
+        tensor_1 = torch.rand(3, 5, 6)
+        tensor_2 = torch.rand(3, 5, 6)
+        tensor_3 = torch.rand(3, 5, 6)
+        sharding = [Shard(0)]
+        dt_1 = DTensor.from_local(tensor_1, device_mesh, sharding)
+        dt_2 = DTensor.from_local(tensor_2, device_mesh, sharding)
+        dt_3 = DTensor.from_local(tensor_3, device_mesh, sharding)
+        new_dt = torch.cat([dt_1, dt_2, dt_3])
+        cat_dt = DTensor.from_local(
+            torch.cat([tensor_1, tensor_2, tensor_3]), device_mesh, sharding
+        )
+        self.assertEqual(new_dt.to_local(), cat_dt.to_local())
+        self.assertEqual(new_dt.size(), cat_dt.size())
+
 
 if __name__ == "__main__":
     run_tests()
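
A quick size check for what the new test asserts, assuming the default world_size of 4 used by these DTensor test classes:

# Hedged arithmetic sketch, not part of the patch.
local_shape = (3, 5, 6)
world_size = 4  # assumption; matches the usual DTensor test world size

# Shard(0) means each rank holds a dim-0 slice, so each input DTensor has
# global shape (12, 5, 6).
global_shape = (local_shape[0] * world_size, *local_shape[1:])

# Concatenating three such DTensors along dim 0 triples the global dim-0
# size, matching cat_dt built from the locally concatenated (9, 5, 6) tensor.
cat_global_shape = (3 * global_shape[0], *global_shape[1:])
print(global_shape, cat_global_shape)  # (12, 5, 6) (36, 5, 6)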

test/distributed/_tensor/test_tp_sharding_ops.py

Lines changed: 0 additions & 18 deletions
@@ -69,24 +69,6 @@ def test_replicated_permute(self):
         self.assertEqual(new_dt.to_local(), tensor.permute(1, 0, 2))
         self.assertEqual(new_dt.stride(), tensor.permute(1, 0, 2).stride())
 
-    @with_comms
-    def test_sharded_cat(self):
-        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
-        torch.manual_seed(self.rank)
-        tensor_1 = torch.rand(3, 5, 6)
-        tensor_2 = torch.rand(3, 5, 6)
-        tensor_3 = torch.rand(3, 5, 6)
-        sharding = [Shard(0)]
-        dt_1 = DTensor.from_local(tensor_1, device_mesh, sharding)
-        dt_2 = DTensor.from_local(tensor_2, device_mesh, sharding)
-        dt_3 = DTensor.from_local(tensor_3, device_mesh, sharding)
-        new_dt = torch.cat([dt_1, dt_2, dt_3])
-        cat_dt = DTensor.from_local(
-            torch.cat([tensor_1, tensor_2, tensor_3]), device_mesh, sharding
-        )
-        self.assertEqual(new_dt.to_local(), cat_dt.to_local())
-        self.assertEqual(new_dt.size(), cat_dt.size())
-
     @with_comms
     def test_sharded_split(self):
         device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))

torch/distributed/_tensor/ops/tensor_ops.py

Lines changed: 91 additions & 1 deletion
@@ -11,7 +11,7 @@
     Shard,
 )
 from torch.distributed._tensor.dispatch import OpSchema, OutputSharding
-from torch.distributed._tensor.ops.common_rules import pointwise_rule
+from torch.distributed._tensor.ops.common_rules import einop_rule, pointwise_rule
 from torch.distributed._tensor.ops.utils import register_prop_rule
 
 
@@ -472,3 +472,93 @@ def place(vp: Placement, ip: Placement) -> Placement:
         ],
     )
     return result
+
+
+@register_prop_rule("aten.cat.default")
+def cat_rule(op_schema: OpSchema) -> OutputSharding:
+    dim = 0  # default dim = 0
+    tensor_list_specs = cast(List[DTensorSpec], op_schema.args_schema[0])
+    if (len(op_schema.args_schema) > 1):
+        dim = cast(int, op_schema.args_schema[1])
+    # normalize arguments
+    if dim < 0:
+        dim += tensor_list_specs[0].ndim
+
+    # check concat dim
+    needs_reshard_on_cat_dim = False
+    for spec in tensor_list_specs:
+        if dim < len(spec.placements) and spec.placements[dim].is_shard():
+            needs_reshard_on_cat_dim = True
+            spec.placements = unshard_tensor_dim(spec.placements, dim=dim)
+    if needs_reshard_on_cat_dim:
+        args_schema = (tensor_list_specs,) + op_schema.args_schema[1:]
+        suggested_schema = OpSchema(
+            func_schema=op_schema.func_schema,
+            args_schema=args_schema,
+            kwargs_schema=op_schema.kwargs_schema,
+        )
+        return OutputSharding(
+            None,
+            schema_suggestions=[suggested_schema],
+            failed_reason="All tensors in concat must have no sharding on cat dim, need to reshard!",
+        )
+    alphabet = "abcdefghijklmnopqrstuvwxyz"
+    einop_equation = ""
+    for spec in tensor_list_specs:
+        einop_equation += alphabet[:spec.ndim]
+        einop_equation += ','
+    einop_equation = einop_equation[:-1] + "->" + alphabet[:tensor_list_specs[0].ndim]
+    output_sharding = einop_rule(
+        einop_equation,
+        OpSchema(
+            func_schema=op_schema.func_schema,
+            args_schema=tuple(tensor_list_specs),
+            kwargs_schema={},
+        ),
+        linearity=False
+    )
+
+    if output_sharding.output_spec is None:
+        if output_sharding.schema_suggestions is not None:
+            return _update_schema_suggestion_for_cat(
+                output_sharding,
+                op_schema,
+                dim,
+            )
+        else:
+            return OutputSharding(None)
+    # change output shape
+    new_size = 0
+    for spec in tensor_list_specs:
+        new_size += spec.shape[dim]
+    assert isinstance(output_sharding.output_spec, DTensorSpec)
+    output_sharding.output_spec.shape = torch.Size(
+        tuple(output_sharding.output_spec.shape[:dim])
+        + (new_size,)
+        + tuple(output_sharding.output_spec.shape[dim + 1 :])
+    )
+    return output_sharding
+
+
+def _update_schema_suggestion_for_cat(
+    output_sharding: OutputSharding,
+    op_schema: OpSchema,
+    dim: int,
+) -> OutputSharding:
+    assert output_sharding.schema_suggestions is not None
+    suggestion_specs = output_sharding.schema_suggestions[0].args_spec
+
+    # check concat dim
+    for spec in suggestion_specs:
+        if dim < len(spec.placements) and spec.placements[dim].is_shard():
+            spec.placements = unshard_tensor_dim(spec.placements, dim=dim)
+    args_schema = (suggestion_specs,) + op_schema.args_schema[1:]
+
+    output_sharding.schema_suggestions = [
+        OpSchema(
+            func_schema=op_schema.func_schema,
+            args_schema=args_schema,
+            kwargs_schema=op_schema.kwargs_schema,
+        )
+    ]
+    return output_sharding
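
To make the equation-building step in cat_rule concrete: each input contributes the same subscript group, so for three rank-3 inputs the rule hands einop_rule the equation "abc,abc,abc->abc", and the cat-dim size of the resulting spec is then patched to the sum of the inputs' cat-dim sizes. A standalone sketch of just the string construction (the ndims helper is hypothetical, for illustration only):

# Hedged sketch mirroring the equation construction above.
alphabet = "abcdefghijklmnopqrstuvwxyz"

def cat_einop_equation(ndims):
    # One identical subscript group per input, e.g. [3, 3, 3] -> "abc,abc,abc->abc".
    lhs = ",".join(alphabet[:n] for n in ndims)
    return lhs + "->" + alphabet[:ndims[0]]

print(cat_einop_equation([3, 3, 3]))  # abc,abc,abc->abc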

torch/distributed/_tensor/ops/tp_sharding_ops.py

Lines changed: 0 additions & 13 deletions
@@ -2,7 +2,6 @@
 # implement matrix related ops for distributed tensor
 from typing import List
 
-import torch
 import torch.utils._pytree as pytree
 from torch.distributed._tensor.api import DTensor
 from torch.distributed._tensor.ops.utils import register_impl, unwrap_single_placement
@@ -16,18 +15,6 @@
 """
 
 
-@register_impl("aten.cat.default")
-def dist_cat(tensor_list: List[DTensor], dim: int = 0) -> DTensor:
-    local_inputs = pytree.tree_map(unwrap_local_tensor, tensor_list)
-    local_tensor = torch.ops.aten.concat(local_inputs, dim=dim)
-    return DTensor.from_local(
-        local_tensor,
-        tensor_list[0].device_mesh,
-        tensor_list[0].placements,
-        run_check=False,
-    )
-
-
 @register_impl("aten.split.Tensor")
 # pyre-fixme[2]: Parameter must be annotated.
 def dist_split(self: DTensor, split_size_or_sections, dim=0) -> List[DTensor]:
