Commit 0d90fb8 (parent 0e92bbe)

[DTensor] implement dist_cat as a sharding prop rule

ghstack-source-id: 4c2f629
Pull Request resolved: #92677

4 files changed: +130 −39 lines

test/distributed/_tensor/test_dtensor_ops.py (0 additions, 6 deletions)

@@ -118,7 +118,6 @@ def wrapped(fn):
     xfail("bernoulli"),
     xfail("block_diag"),
     xfail("broadcast_shapes"),
-    xfail("cat"),
     xfail("cartesian_prod"),
     xfail("cdist"),
     xfail("cholesky"),
@@ -128,7 +127,6 @@ def wrapped(fn):
     xfail("clamp"),
     xfail("clamp_max"),
     xfail("clamp_min"),
-    xfail("column_stack"),
     xfail("combinations"),
     xfail("complex"),
     xfail("constant_pad_nd"),
@@ -147,10 +145,8 @@ def wrapped(fn):
     xfail("diagonal"),
     xfail("diagonal_copy"),
     xfail("diagonal_scatter"),
-    xfail("diff"),
     xfail("dist"),
     xfail("dot"),
-    xfail("dstack"),
     xfail("einsum"),
     xfail("empty"),
     xfail("empty_like"),
@@ -188,7 +184,6 @@ def wrapped(fn):
     xfail("histc"),
     xfail("histogram"),
     xfail("histogramdd"),
-    xfail("hstack"),
     xfail("index_add"),
     xfail("index_copy"),
     xfail("index_fill"),
@@ -507,7 +502,6 @@ def wrapped(fn):
     xfail("vdot"),
     xfail("view_copy"),
     xfail("view_as_complex"),
-    xfail("vstack"),
     xfail("where"),
     xfail("zeros"),
     # ops inside this might even fail without dtensor

test/distributed/_tensor/test_tp_sharding_ops.py (0 additions, 18 deletions)

@@ -69,24 +69,6 @@ def test_replicated_permute(self):
         self.assertEqual(new_dt.to_local(), tensor.permute(1, 0, 2))
         self.assertEqual(new_dt.stride(), tensor.permute(1, 0, 2).stride())

-    @with_comms
-    def test_sharded_cat(self):
-        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
-        torch.manual_seed(self.rank)
-        tensor_1 = torch.rand(3, 5, 6)
-        tensor_2 = torch.rand(3, 5, 6)
-        tensor_3 = torch.rand(3, 5, 6)
-        sharding = [Shard(0)]
-        dt_1 = DTensor.from_local(tensor_1, device_mesh, sharding)
-        dt_2 = DTensor.from_local(tensor_2, device_mesh, sharding)
-        dt_3 = DTensor.from_local(tensor_3, device_mesh, sharding)
-        new_dt = torch.cat([dt_1, dt_2, dt_3])
-        cat_dt = DTensor.from_local(
-            torch.cat([tensor_1, tensor_2, tensor_3]), device_mesh, sharding
-        )
-        self.assertEqual(new_dt.to_local(), cat_dt.to_local())
-        self.assertEqual(new_dt.size(), cat_dt.size())
-
     @with_comms
     def test_sharded_split(self):
         device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
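For context on what replaces the removed test's eager path: with this commit, torch.cat on DTensor inputs goes through the new sharding propagation rule (cat_rule, added below) instead of a hand-written local implementation. The following is a minimal, hypothetical usage sketch, not part of the commit; it assumes a process group is already initialized (e.g. via torchrun) and that DeviceMesh, DTensor, and Shard are importable from torch.distributed._tensor as in the test files. The inputs are sharded on dim 1, so the cat dim (0) is unsharded and the placement can propagate to the output without redistribution.

import torch
import torch.distributed as dist
from torch.distributed._tensor import DTensor, DeviceMesh, Shard

# hypothetical sketch: assumes dist.init_process_group() has already run
mesh = DeviceMesh("cpu", list(range(dist.get_world_size())))
sharding = [Shard(1)]  # shard on dim 1, which is not the cat dim

# each rank contributes a local shard of two logical tensors
dt_1 = DTensor.from_local(torch.rand(3, 5), mesh, sharding)
dt_2 = DTensor.from_local(torch.rand(3, 5), mesh, sharding)

# dispatches to aten.cat.default; cat_rule propagates Shard(1) to the output
new_dt = torch.cat([dt_1, dt_2], dim=0)
print(new_dt.size(), new_dt.placements)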

torch/distributed/_tensor/ops/tensor_ops.py (130 additions, 2 deletions)

@@ -11,8 +11,8 @@
     Shard,
 )
 from torch.distributed._tensor.dispatch import OpSchema, OutputSharding
-from torch.distributed._tensor.ops.common_rules import pointwise_rule
-from torch.distributed._tensor.ops.utils import register_prop_rule
+from torch.distributed._tensor.ops.common_rules import einop_rule, pointwise_rule
+from torch.distributed._tensor.ops.utils import register_prop_rule, normalize_dim


 # NOTE: the default propagation rule should apply for
@@ -160,6 +160,13 @@ def unshard_tensor_dim(
     )


+def is_tensor_dim_sharded(
+    spec: DTensorSpec, dim: int
+) -> bool:
+    """Return True if tensor dim is sharded"""
+    return (dim < spec.ndim) and spec.dim_map[dim] >= 0
+
+
 def _prop_all_but_dim(
     op_schema: OpSchema, dim: int, out_shape: torch.Size
 ) -> OutputSharding:
@@ -472,3 +479,124 @@ def place(vp: Placement, ip: Placement) -> Placement:
         ],
     )
     return result
+
+
+@register_prop_rule("aten.cat.default")
+def cat_rule(op_schema: OpSchema) -> OutputSharding:
+    # the first arg is a list of the input tensors' specs
+    tensor_list_specs = cast(List[DTensorSpec], op_schema.args_schema[0])
+    # the result's ndim equals the largest input ndim
+    ndim = 1
+    for spec in tensor_list_specs:
+        ndim = max(ndim, spec.ndim)
+
+    dim = 0  # default dim = 0
+    if (len(op_schema.args_schema) > 1):
+        dim = cast(int, op_schema.args_schema[1])
+    dim = normalize_dim(dim, ndim)
+
+    # Unshard all input tensors on the cat dim before running the einop rule
+    # to avoid a _Partial placement in the result.
+    need_reshard = False
+    tensor_list_specs_after = []
+    for spec in tensor_list_specs:
+        if is_tensor_dim_sharded(spec, dim=dim):
+            need_reshard = True
+            tensor_list_specs_after.append(
+                DTensorSpec(
+                    mesh=spec.mesh,
+                    placements=unshard_tensor_dim(spec.placements, dim=dim),
+                    shape=spec.shape,
+                    ndim=spec.ndim,
+                )
+            )
+        else:
+            tensor_list_specs_after.append(spec)
+    tensor_list_specs = tensor_list_specs_after
+
+    # TODO: the einop rule currently requires that every character
+    # in the result notation also appear in the inputs,
+    # so we temporarily write the cat notation as
+    # "aij,bij->aij". Once we lift this requirement,
+    # we can switch to the more logically reasonable notation
+    # "aij,bij->cij".
+    alphabet = "abcdefghijklmnopqrstuvwxyz"
+    einop_notation_list = []
+
+    l = len(tensor_list_specs)
+    free_dim = alphabet[l:l + ndim - 1]
+    for i, spec in enumerate(tensor_list_specs):
+        if spec.ndim == ndim:
+            # rewrite the concat dim
+            dim_word = free_dim[:dim] + alphabet[i] + free_dim[dim:]
+            einop_notation_list.append(dim_word)
+        else:
+            einop_notation_list.append(alphabet[i])
+
+    cat_dim_char = alphabet[0]
+    dim_word = free_dim[:dim] + cat_dim_char + free_dim[dim:]
+    einop_equation = f"{','.join(einop_notation_list)}->{dim_word}"
+    output_sharding = einop_rule(
+        einop_equation,
+        OpSchema(
+            func_schema=op_schema.func_schema,
+            args_schema=tuple(tensor_list_specs),
+            kwargs_schema={},
+        ),
+        linearity=False
+    )
+
+    if (
+        (output_sharding.output_spec is not None) and
+        need_reshard
+    ):
+        output_sharding.output_spec = None
+        output_sharding.schema_suggestions = [
+            OpSchema(
+                func_schema=op_schema.func_schema,
+                args_schema=tuple(tensor_list_specs),
+                kwargs_schema={},
+            ),
+        ]
+
+    if output_sharding.output_spec is None:
+        if output_sharding.schema_suggestions is not None:
+            # convert args_schema from a tuple of DTensorSpecs into a list
+            return _update_schema_suggestion_for_cat(
+                output_sharding,
+                op_schema,
+            )
+        else:
+            return output_sharding
+
+    # adjust the output shape along the cat dim
+    new_size = 0
+    for spec in tensor_list_specs:
+        if dim < spec.ndim:
+            new_size += spec.shape[dim]
+    assert isinstance(output_sharding.output_spec, DTensorSpec)
+    output_sharding.output_spec.shape = torch.Size(
+        tuple(output_sharding.output_spec.shape[:dim])
+        + (new_size,)
+        + tuple(output_sharding.output_spec.shape[dim + 1 :])
+    )
+    return output_sharding
+
+
+def _update_schema_suggestion_for_cat(
+    output_sharding: OutputSharding,
+    op_schema: OpSchema,
+) -> OutputSharding:
+    assert output_sharding.schema_suggestions is not None
+    suggestion_specs = output_sharding.schema_suggestions[0].args_spec
+
+    args_schema = (suggestion_specs,) + op_schema.args_schema[1:]
+
+    output_sharding.schema_suggestions = [
+        OpSchema(
+            func_schema=op_schema.func_schema,
+            args_schema=args_schema,
+            kwargs_schema=op_schema.kwargs_schema,
+        )
+    ]
+    return output_sharding
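To make the einop notation built by cat_rule concrete, here is a small standalone sketch (plain Python, not part of the commit) that reproduces the equation-string construction for an assumed example of three rank-3 inputs concatenated on dim 1. Variable names mirror cat_rule, with num_inputs standing in for its l.

alphabet = "abcdefghijklmnopqrstuvwxyz"
num_inputs, ndim, dim = 3, 3, 1  # assumed example: 3 inputs, all with ndim == 3, cat on dim 1

# letters for the non-cat ("free") dims start after the per-input cat-dim letters
free_dim = alphabet[num_inputs:num_inputs + ndim - 1]  # "de"

# each input gets its own letter on the cat dim: "dae", "dbe", "dce"
inputs = [free_dim[:dim] + alphabet[i] + free_dim[dim:] for i in range(num_inputs)]

# the output reuses the first input's cat-dim letter ("a"), because einop_rule
# currently requires every output character to appear in some input
output = free_dim[:dim] + alphabet[0] + free_dim[dim:]  # "dae"

print(",".join(inputs) + "->" + output)  # prints "dae,dbe,dce->dae"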

torch/distributed/_tensor/ops/tp_sharding_ops.py (0 additions, 13 deletions)

@@ -2,7 +2,6 @@
 # implement matrix related ops for distributed tensor
 from typing import List

-import torch
 import torch.utils._pytree as pytree
 from torch.distributed._tensor.api import DTensor
 from torch.distributed._tensor.ops.utils import register_impl, unwrap_single_placement
@@ -16,18 +15,6 @@
 """


-@register_impl("aten.cat.default")
-def dist_cat(tensor_list: List[DTensor], dim: int = 0) -> DTensor:
-    local_inputs = pytree.tree_map(unwrap_local_tensor, tensor_list)
-    local_tensor = torch.ops.aten.concat(local_inputs, dim=dim)
-    return DTensor.from_local(
-        local_tensor,
-        tensor_list[0].device_mesh,
-        tensor_list[0].placements,
-        run_check=False,
-    )
-
-
 @register_impl("aten.split.Tensor")
 # pyre-fixme[2]: Parameter must be annotated.
 def dist_split(self: DTensor, split_size_or_sections, dim=0) -> List[DTensor]:
