6 changes: 0 additions & 6 deletions test/distributed/_tensor/test_dtensor_ops.py
@@ -118,7 +118,6 @@ def wrapped(fn):
xfail("bernoulli"),
xfail("block_diag"),
xfail("broadcast_shapes"),
xfail("cat"),
xfail("cartesian_prod"),
xfail("cdist"),
xfail("cholesky"),
@@ -128,7 +127,6 @@ def wrapped(fn):
xfail("clamp"),
xfail("clamp_max"),
xfail("clamp_min"),
xfail("column_stack"),
xfail("combinations"),
xfail("complex"),
xfail("constant_pad_nd"),
@@ -147,10 +145,8 @@ def wrapped(fn):
xfail("diagonal"),
xfail("diagonal_copy"),
xfail("diagonal_scatter"),
xfail("diff"),
xfail("dist"),
xfail("dot"),
xfail("dstack"),
xfail("einsum"),
xfail("empty"),
xfail("empty_like"),
@@ -188,7 +184,6 @@ def wrapped(fn):
xfail("histc"),
xfail("histogram"),
xfail("histogramdd"),
xfail("hstack"),
xfail("index_add"),
xfail("index_copy"),
xfail("index_fill"),
@@ -507,7 +502,6 @@ def wrapped(fn):
xfail("vdot"),
xfail("view_copy"),
xfail("view_as_complex"),
xfail("vstack"),
xfail("where"),
xfail("zeros"),
# ops inside this might even fail without dtensor
18 changes: 0 additions & 18 deletions test/distributed/_tensor/test_tp_sharding_ops.py
@@ -69,24 +69,6 @@ def test_replicated_permute(self):
self.assertEqual(new_dt.to_local(), tensor.permute(1, 0, 2))
self.assertEqual(new_dt.stride(), tensor.permute(1, 0, 2).stride())

@with_comms
def test_sharded_cat(self):
device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
torch.manual_seed(self.rank)
tensor_1 = torch.rand(3, 5, 6)
tensor_2 = torch.rand(3, 5, 6)
tensor_3 = torch.rand(3, 5, 6)
sharding = [Shard(0)]
dt_1 = DTensor.from_local(tensor_1, device_mesh, sharding)
dt_2 = DTensor.from_local(tensor_2, device_mesh, sharding)
dt_3 = DTensor.from_local(tensor_3, device_mesh, sharding)
new_dt = torch.cat([dt_1, dt_2, dt_3])
cat_dt = DTensor.from_local(
torch.cat([tensor_1, tensor_2, tensor_3]), device_mesh, sharding
)
self.assertEqual(new_dt.to_local(), cat_dt.to_local())
self.assertEqual(new_dt.size(), cat_dt.size())

@with_comms
def test_sharded_split(self):
device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
132 changes: 130 additions & 2 deletions torch/distributed/_tensor/ops/tensor_ops.py
@@ -11,8 +11,8 @@
Shard,
)
from torch.distributed._tensor.dispatch import OpSchema, OutputSharding
from torch.distributed._tensor.ops.common_rules import pointwise_rule
from torch.distributed._tensor.ops.utils import register_prop_rule
from torch.distributed._tensor.ops.common_rules import einop_rule, pointwise_rule
from torch.distributed._tensor.ops.utils import register_prop_rule, normalize_dim


# NOTE: the default propagation rule should apply for
@@ -160,6 +160,13 @@ def unshard_tensor_dim(
)


def is_tensor_dim_sharded(
spec: DTensorSpec, dim: int
) -> bool:
"""Return True if tensor dim is sharded"""
return (dim < spec.ndim) and spec.dim_map[dim] >= 0


def _prop_all_but_dim(
op_schema: OpSchema, dim: int, out_shape: torch.Size
) -> OutputSharding:
@@ -472,3 +479,124 @@ def place(vp: Placement, ip: Placement) -> Placement:
],
)
return result


@register_prop_rule("aten.cat.default")
def cat_rule(op_schema: OpSchema) -> OutputSharding:
# the first arg is a list of input tensors' specs
tensor_list_specs = cast(List[DTensorSpec], op_schema.args_schema[0])
# ndim will also be the result's ndim
ndim = 1
for spec in tensor_list_specs:
ndim = max(ndim, spec.ndim)

dim = 0 # default dim = 0
if (len(op_schema.args_schema) > 1):
dim = cast(int, op_schema.args_schema[1])
dim = normalize_dim(dim, ndim)

# Unshard all input tensors on cat dim before running einop rule
# to avoid _Partial in result.
need_reshard = False
tensor_list_specs_after = []
for spec in tensor_list_specs:
if is_tensor_dim_sharded(spec, dim=dim):
need_reshard = True
tensor_list_specs_after.append(
DTensorSpec(
mesh=spec.mesh,
placements=unshard_tensor_dim(spec.placements, dim=dim),
shape=spec.shape,
ndim=spec.ndim,
)
)
else:
tensor_list_specs_after.append(spec)
tensor_list_specs = tensor_list_specs_after

# TODO: currently einop rule requires every character
# in result notation must have appeared in inputs
# so we temporarily design cat notation as
# "aij,bij->aij". Once we modify this requirement,
# we can switch to the more logically reasonable notation
# "aij,bij->cij"
alphabet = "abcdefghijklmnopqrstuvwxyz"
einop_notation_list = []

l = len(tensor_list_specs)
free_dim = alphabet[l:l + ndim - 1]
for i, spec in enumerate(tensor_list_specs):
if spec.ndim == ndim:
# rewrite concat dim
dim_word = free_dim[:dim] + alphabet[i] + free_dim[dim:]
einop_notation_list.append(dim_word)
else:
einop_notation_list.append(alphabet[i])

Review comment (Collaborator): Is this the empty-tensor annotation, where it has a single char?

Reply (@XilunWu, Contributor Author, Jan 26, 2023): Not only for empty tensors, but for any tensor whose ndim is smaller than the other tensors'. This covers cases like concatenating Tensor([], shape=torch.Size([0])) with Tensor([[1, 2], [3, 4]], shape=torch.Size([2, 2])).

In this case an empty annotation might still work, but we want to ensure that the char used for cat_dim in the output annotation also appears in an input annotation. Giving each input tensor its own char at cat_dim guarantees that.
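
To make the reply above concrete, here is a small standalone Python sketch (an editorial illustration, not part of the diff) that mirrors cat_rule's notation-building for exactly that case; the shapes and variable names are assumptions chosen for the example:

# Mixed-ndim case from the reply: cat a 1-D empty tensor (shape [0]) with a
# 2-D tensor (shape [2, 2]) along dim 0. Mirrors the loop in cat_rule above.
alphabet = "abcdefghijklmnopqrstuvwxyz"
input_ndims = [1, 2]                 # hypothetical inputs: torch.Size([0]), torch.Size([2, 2])
ndim = max(input_ndims)              # result ndim = 2
dim = 0                              # cat dim
l = len(input_ndims)                 # number of inputs = 2
free_dim = alphabet[l:l + ndim - 1]  # "c": shared char for the non-cat dim

notations = []
for i, nd in enumerate(input_ndims):
    if nd == ndim:
        # full-rank input: its own char at the cat dim, shared chars elsewhere
        notations.append(free_dim[:dim] + alphabet[i] + free_dim[dim:])
    else:
        # lower-rank (here, the empty 1-D) input: just its own char
        notations.append(alphabet[i])

out_word = free_dim[:dim] + alphabet[0] + free_dim[dim:]
print(f"{','.join(notations)}->{out_word}")   # prints "a,bc->ac"

The output's cat-dim char ("a") is guaranteed to appear in an input notation, which is what the current einop_rule requirement needs.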


cat_dim_char = alphabet[0]
dim_word = free_dim[:dim] + cat_dim_char + free_dim[dim:]
einop_equation = f"{','.join(einop_notation_list)}->{dim_word}"
output_sharding = einop_rule(
einop_equation,
OpSchema(
func_schema=op_schema.func_schema,
args_schema=tuple(tensor_list_specs),
kwargs_schema={},
),
linearity=False
)

if (
(output_sharding.output_spec is not None) and
need_reshard
):
output_sharding.output_spec = None
output_sharding.schema_suggestions = [
OpSchema(
func_schema=op_schema.func_schema,
args_schema=tuple(tensor_list_specs),
kwargs_schema={},
),
]

if output_sharding.output_spec is None:
if output_sharding.schema_suggestions is not None:
# Convert args_schema from a tuple of DTensorSpec into a list
return _update_schema_suggestion_for_cat(
output_sharding,
op_schema,
)
else:
return output_sharding

# change output shape
new_size = 0
for spec in tensor_list_specs:
if dim < spec.ndim:
new_size += spec.shape[dim]
assert isinstance(output_sharding.output_spec, DTensorSpec)
output_sharding.output_spec.shape = torch.Size(
tuple(output_sharding.output_spec.shape[:dim])
+ (new_size,)
+ tuple(output_sharding.output_spec.shape[dim + 1 :])
)
return output_sharding
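
As a quick check of the "aij,bij->aij"-style notation described in the TODO above and of the output-shape fix at the end of the rule, here is a tiny editorial sketch (not part of the diff); the input shapes are hypothetical:

# Two full-rank 2-D inputs of shapes [2, 3] and [4, 3], concatenated along dim 0.
alphabet = "abcdefghijklmnopqrstuvwxyz"
shapes = [(2, 3), (4, 3)]
dim, l = 0, len(shapes)
ndim = max(len(s) for s in shapes)                        # 2
free_dim = alphabet[l:l + ndim - 1]                       # "c"
words = [free_dim[:dim] + alphabet[i] + free_dim[dim:] for i in range(l)]
equation = f"{','.join(words)}->{free_dim[:dim] + alphabet[0] + free_dim[dim:]}"
new_size = sum(s[dim] for s in shapes if dim < len(s))    # 2 + 4 = 6
print(equation, new_size)                                 # "ac,bc->ac" 6 -> output shape (6, 3)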


def _update_schema_suggestion_for_cat(
Review comment (Collaborator): Can you tell me what exactly this function is doing? It looks like a lot of duplicated logic with the rule itself, and I am not quite sure what it is used for.

Reply (Contributor Author): einop_rule expects the op_schema argument to have its args_schema in the form [DTensorSpec, DTensorSpec, ...], but when it is passed into cat_rule the schema is actually [List[DTensorSpec]]. That is why the args_schema is converted at the beginning of cat_rule (https://github.com/pytorch/pytorch/pull/92677/files#diff-ebc7be1151cf411ce7edf46c4ca1cabb74cd953a2bdf47e04b4cc733c31f6085R492) before feeding it into einop_rule. Thus, we need to convert it back if a schema_suggestion is present here (a short sketch of this round-trip follows this function's diff).

output_sharding: OutputSharding,
op_schema: OpSchema,
) -> OutputSharding:
assert output_sharding.schema_suggestions is not None
suggestion_specs = output_sharding.schema_suggestions[0].args_spec

args_schema = (suggestion_specs,) + op_schema.args_schema[1:]

output_sharding.schema_suggestions = [
OpSchema(
func_schema=op_schema.func_schema,
args_schema=args_schema,
kwargs_schema=op_schema.kwargs_schema,
)
]
return output_sharding
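
To make the reply in the review thread above concrete, here is a short editorial sketch (not part of the diff) of the args_schema round-trip; spec_a, spec_b, and cat_dim are hypothetical placeholders standing in for DTensorSpec objects and an int:

# What the dispatcher hands to cat_rule for aten.cat.default: the input specs are
# packed into a single list argument, followed by the cat dim.
spec_a, spec_b, cat_dim = "spec_a", "spec_b", 0
cat_args_schema = ([spec_a, spec_b], cat_dim)

# What cat_rule feeds to einop_rule: one spec per positional argument.
einop_args_schema = tuple(cat_args_schema[0])              # (spec_a, spec_b)

# If a schema suggestion comes back, its args_spec is still in this flat per-spec
# form, so it is re-wrapped to match aten.cat.default's real signature.
suggested_specs = einop_args_schema                        # stand-in for args_spec
rewrapped = (suggested_specs,) + cat_args_schema[1:]       # ((spec_a, spec_b), 0)
print(rewrapped)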
13 changes: 0 additions & 13 deletions torch/distributed/_tensor/ops/tp_sharding_ops.py
@@ -2,7 +2,6 @@
# implement matrix related ops for distributed tensor
from typing import List

import torch
import torch.utils._pytree as pytree
from torch.distributed._tensor.api import DTensor
from torch.distributed._tensor.ops.utils import register_impl, unwrap_single_placement
@@ -16,18 +15,6 @@
"""


@register_impl("aten.cat.default")
def dist_cat(tensor_list: List[DTensor], dim: int = 0) -> DTensor:
local_inputs = pytree.tree_map(unwrap_local_tensor, tensor_list)
local_tensor = torch.ops.aten.concat(local_inputs, dim=dim)
return DTensor.from_local(
local_tensor,
tensor_list[0].device_mesh,
tensor_list[0].placements,
run_check=False,
)


@register_impl("aten.split.Tensor")
# pyre-fixme[2]: Parameter must be annotated.
def dist_split(self: DTensor, split_size_or_sections, dim=0) -> List[DTensor]: