11 changes: 10 additions & 1 deletion test/distributed/_tensor/test_dtensor_ops.py
@@ -451,7 +451,6 @@ def wrapped(fn):
xfail("special.spherical_bessel_j0"),
xfail("special.xlog1py"),
xfail("special.zeta"),
xfail("split"),
xfail("split", "list_args"),
xfail("split_with_sizes"),
xfail("squeeze", "multiple"),
@@ -620,6 +619,11 @@ def run_dtensor_crossref(self, func, args, kwargs):

# TODO: also handle cases where func raise an exception
rs = func(*args, **kwargs)
if (
(resolve_name(func) is not None)
and ("split" in resolve_name(func))
):
rs = torch.cat(rs)

def to_replicate(e: object) -> object:
return (
@@ -660,6 +664,11 @@ def to_replicate(e: object) -> object:

# redistribute/all_gather the results to compare with normal output
dtensor_rs = tree_map(to_replicate, dtensor_rs)
if (
(resolve_name(func) is not None)
and ("split" in resolve_name(func))
):
dtensor_rs = torch.cat(dtensor_rs)
try:
if resolve_name(func) not in skip_bw:
if isinstance(dtensor_rs, DTensor):
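As a side note (not part of the diff), here is a minimal standalone sketch of why the crossref harness concatenates split results before comparing: torch.split returns a tuple of tensors rather than a single tensor, and torch.cat folds the pieces back into one tensor that the existing comparison path can consume. The shapes and dim argument below are arbitrary illustration choices.

import torch

x = torch.arange(12).reshape(3, 4)

# torch.split returns a tuple of tensor views, not a single tensor.
chunks = x.split(2, dim=1)
assert isinstance(chunks, tuple) and len(chunks) == 2

# Concatenating along the same dim reassembles the original tensor,
# giving a single object that an elementwise comparison can handle.
assert torch.equal(torch.cat(chunks, dim=1), x)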
17 changes: 8 additions & 9 deletions test/distributed/_tensor/test_tp_sharding_ops.py
@@ -9,6 +9,7 @@
Replicate,
Shard,
)
from torch.distributed._tensor.placement_types import _Partial
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
@@ -70,17 +71,15 @@ def test_replicated_permute(self):
self.assertEqual(new_dt.stride(), tensor.permute(1, 0, 2).stride())

@with_comms
def test_sharded_split(self):
def test_split_partial_tensor(self):
device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
torch.manual_seed(self.rank)
tensor = torch.rand(3, 5, 6, device=self.device_type)
sharding = [Shard(2)]
dist_tensor = DTensor.from_local(tensor, device_mesh, sharding)
dt_list = dist_tensor.split(dist_tensor.size(-1) // 2, dim=-1)
local_tensors = tensor.split(3, dim=-1)
for idx, dt in enumerate(dt_list):
self.assertTrue(dt.placements[0].is_shard(dim=2))
self.assertEqual(dt.to_local(), local_tensors[idx])
dist_tensor = DTensor.from_local(tensor, device_mesh, [_Partial()])
with self.assertRaisesRegex(
RuntimeError,
"_Partial placement is not implemented",
):
dist_tensor = dist_tensor.split(3)


if __name__ == "__main__":
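For context (not part of this PR), a hypothetical companion test sketching the non-partial path: with a replicated input, the new propagation rule keeps the placements, so each output chunk should match the corresponding local torch.split chunk. The test name, seed choice, and assertions are assumptions for illustration, written in the style of the surrounding test file.

    @with_comms
    def test_replicated_split(self):
        # Hypothetical sketch, not from the PR: split a replicated DTensor.
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        torch.manual_seed(0)  # identical tensor on every rank, so Replicate is valid
        tensor = torch.rand(3, 5, 6, device=self.device_type)
        dist_tensor = DTensor.from_local(tensor, device_mesh, [Replicate()])
        dt_list = dist_tensor.split(3, dim=-1)
        local_chunks = tensor.split(3, dim=-1)
        for idx, dt in enumerate(dt_list):
            # Assumed behavior: each piece stays replicated and matches the local chunk.
            self.assertEqual(dt.to_local(), local_chunks[idx])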
1 change: 0 additions & 1 deletion torch/distributed/_tensor/ops/__init__.py
@@ -2,6 +2,5 @@
from .matrix_ops import * # noqa: F403
from .math_ops import * # noqa: F403
from .tensor_ops import * # noqa: F403
from .tp_sharding_ops import * # noqa: F403
from .pointwise_ops import * # noqa: F403
from .view_ops import * # noqa: F403
78 changes: 78 additions & 0 deletions torch/distributed/_tensor/ops/tensor_ops.py
@@ -598,3 +598,81 @@ def _update_schema_suggestion_for_cat(
)
]
return output_sharding


@register_prop_rule(aten.split.Tensor)
def split_rule(op_schema: OpSchema) -> OutputSharding:
output_spec_list: List[DTensorSpec] = []
input_spec = cast(DTensorSpec, op_schema.args_schema[0])
ndim = input_spec.ndim
split_size_or_sections = op_schema.args_schema[1]
dim = (
cast(int, op_schema.args_schema[2])
if len(op_schema.args_schema) > 2
else 0
)
dim = normalize_dim(dim, ndim)

# TODO: tensor to split cannot have _Partial
# in its placements for now. Will need to
# support in future.
if input_spec.sums:
raise NotImplementedError(
f"splitting distributed tensor with "
f"_Partial placement is not implemented!\n"
f"DTensorSpec={input_spec}"
)

# TODO: just like slice op, split replicates before
# splitting on a sharded dimension
need_reshard = False
if is_tensor_dim_sharded(input_spec, dim=dim):
Comment on lines +628 to +629

Contributor: This somehow broke TP's code logic, because a common technique is to keep the DTensor sharded on the last dim and call split on the last dim too. We still want the result to be sharded on dim=-1.

Contributor: Currently, after split, we get a replicated DTensor.

need_reshard = True
input_spec = DTensorSpec(
mesh=input_spec.mesh,
placements=unshard_tensor_dim(input_spec.placements, dim=dim),
shape=input_spec.shape,
ndim=input_spec.ndim,
)
Collaborator: nit: add a check for a partial input_spec and raise NotImplementedError so we know to implement this later?

Contributor Author: sounds good!

if need_reshard:
return OutputSharding(
None,
schema_suggestions=[
OpSchema(
func_schema=op_schema.func_schema,
args_schema=(input_spec,) + op_schema.args_schema[1:],
kwargs_schema=op_schema.kwargs_schema,
),
]
)

def size_split(N, i):
# Last chunk will be smaller if the tensor size N
# along the given dimension dim is not divisible by i.
assert i > 0
return [i] * (N // i) + ([N % i] if N % i != 0 else [])

output_size_list = (
size_split(input_spec.shape[dim], split_size_or_sections)
if isinstance(split_size_or_sections, int)
else split_size_or_sections
)
output_shape_list = [
torch.Size(
tuple(input_spec.shape[:dim])
+ (size,)
+ tuple(input_spec.shape[dim + 1 :])
)
for size in output_size_list
]
output_spec_list = [
DTensorSpec(
mesh=input_spec.mesh,
placements=input_spec.placements,
shape=shape,
ndim=input_spec.ndim,
)
for shape in output_shape_list
]
return OutputSharding(output_spec_list)
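As a quick standalone check (not from the PR) of the chunking arithmetic that size_split implements: for a dimension of size N and split size i there are N // i full chunks, plus one smaller remainder chunk when i does not divide N, which matches what torch.split itself produces.

import torch

def size_split(N, i):
    # Same helper as in split_rule above: full chunks of size i, plus a remainder.
    assert i > 0
    return [i] * (N // i) + ([N % i] if N % i != 0 else [])

# Splitting a dimension of size 10 with split size 3 gives chunks [3, 3, 3, 1],
# matching the sizes torch.split produces for the same arguments.
assert size_split(10, 3) == [3, 3, 3, 1]
assert [t.shape[0] for t in torch.arange(10).split(3)] == [3, 3, 3, 1]

# Exact division leaves no remainder chunk.
assert size_split(6, 3) == [3, 3]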
47 changes: 0 additions & 47 deletions torch/distributed/_tensor/ops/tp_sharding_ops.py

This file was deleted.