
Commit 98e3432

[DTensor] implement dist_split as a sharding prop rule

ghstack-source-id: 17951d9
Pull Request resolved: #93306
Parent: 04082fc

File tree: 4 files changed (+90, −3 lines)

test/distributed/_tensor/test_dtensor_ops.py (2 additions, 1 deletion)

@@ -453,7 +453,7 @@ def wrapped(fn):
     xfail("special.spherical_bessel_j0"),
     xfail("special.xlog1py"),
     xfail("special.zeta"),
-    xfail("split"),
+    #xfail("split"),
     xfail("split", "list_args"),
     xfail("split_with_sizes"),
     xfail("squeeze", "multiple"),
@@ -553,6 +553,7 @@ def wrapped(fn):
     "torch.eq",
     "torch.isfinite",
     "torch.isnan",
+    #"torch.functional.split",
 ]

test/distributed/_tensor/test_tp_sharding_ops.py (16 additions, 1 deletion)

@@ -70,7 +70,7 @@ def test_replicated_permute(self):
         self.assertEqual(new_dt.stride(), tensor.permute(1, 0, 2).stride())
 
     @with_comms
-    def test_sharded_split(self):
+    def test_sharded_split_1(self):
         device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
         torch.manual_seed(self.rank)
         tensor = torch.rand(3, 5, 6, device=self.device_type)
@@ -82,6 +82,21 @@ def test_sharded_split(self):
             self.assertTrue(dt.placements[0].is_shard(dim=2))
             self.assertEqual(dt.to_local(), local_tensors[idx])
 
+    @with_comms
+    def test_sharded_split_2(self):
+        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+        torch.manual_seed(0)
+        tensor = torch.rand(4, 4, 4, device=self.device_type, requires_grad=True)
+        sharding = [Replicate()]
+        dist_tensor = distribute_tensor(tensor, device_mesh, sharding)
+        dt_list = dist_tensor.split(dist_tensor.size(0) // 2, dim=0)
+        print(dt_list)
+        local_tensors = tensor.split(2, dim=0)
+        for idx, dt in enumerate(dt_list):
+            #self.assertTrue(dt.placements[0].is_shard(dim=0))
+            self.assertEqual(dt.to_local(), local_tensors[idx])
+        dt_list[0].to_local().sum().backward()
+
 
 if __name__ == "__main__":
     run_tests()
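
For reference, the behavior test_sharded_split_2 exercises can be sketched in a single process. The snippet below is not part of the commit: it assumes a gloo backend is available, stands in a one-rank process group for the @with_comms fixture, and assumes DeviceMesh, Replicate, and distribute_tensor are importable from torch.distributed._tensor. With a Replicate() placement every rank holds the full tensor, so each piece of the DTensor split should match the corresponding chunk of a plain torch.split.

import os

import torch
import torch.distributed as dist
from torch.distributed._tensor import DeviceMesh, Replicate, distribute_tensor

# One-rank gloo group standing in for the @with_comms test fixture.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)
mesh = DeviceMesh("cpu", [0])

torch.manual_seed(0)
tensor = torch.rand(4, 4, 4, requires_grad=True)
dist_tensor = distribute_tensor(tensor, mesh, [Replicate()])

# Splitting the replicated DTensor should line up chunk-for-chunk with a
# local torch.split of the original tensor.
dt_list = dist_tensor.split(dist_tensor.size(0) // 2, dim=0)
local_tensors = tensor.split(2, dim=0)
for dt, local in zip(dt_list, local_tensors):
    assert torch.equal(dt.to_local(), local)

dist.destroy_process_group()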

torch/distributed/_tensor/ops/tensor_ops.py (72 additions, 0 deletions)

@@ -600,3 +600,75 @@ def _update_schema_suggestion_for_cat(
             )
         ]
     return output_sharding
+
+@register_prop_rule("aten.split.Tensor")
+def split_rule(op_schema: OpSchema) -> OutputSharding:
+    """
+    The OutputSpecType of tensor split should be Sequence[DTensorSpec]
+    """
+    print(op_schema)
+    output_spec_list: List[DTensorSpec] = []
+    input_spec = cast(DTensorSpec, op_schema.args_schema[0])
+    ndim = input_spec.ndim
+    split_size_or_sections = op_schema.args_schema[1]
+    dim = 0
+    if len(op_schema.args_schema) > 2:
+        dim = cast(int, op_schema.args_schema[2])
+        dim = normalize_dim(dim, ndim)
+
+    # TODO: just like slice op, split replicates before splitting
+    # on a sharded dimension
+    # TODO: shall we consider partial???
+    # TODO: consider splitting an empty tensor
+    need_reshard = False
+    if is_tensor_dim_sharded(input_spec, dim=dim):
+        need_reshard = True
+        input_spec = DTensorSpec(
+            mesh=input_spec.mesh,
+            placements=unshard_tensor_dim(input_spec.placements, dim=dim),
+            shape=input_spec.shape,
+            ndim=input_spec.ndim,
+        )
+
+    if need_reshard:
+        return OutputSharding(
+            None,
+            schema_suggestions=[
+                OpSchema(
+                    func_schema=op_schema.func_schema,
+                    args_schema=(input_spec,) + op_schema.args_schema[1:],
+                    kwargs_schema={},
+                ),
+            ]
+        )
+
+    def size_split(N, i):
+        # Last chunk will be smaller if the tensor size N
+        # along the given dimension dim is not divisible by i.
+        assert i > 0
+        return [i] * (N // i) + ([N % i] if N % i != 0 else [])
+
+    output_size_list = (
+        size_split(input_spec.shape[dim], split_size_or_sections)
+        if isinstance(split_size_or_sections, int)
+        else split_size_or_sections
+    )
+    output_shape_list = [
+        torch.Size(
+            tuple(input_spec.shape[:dim])
+            + (size,)
+            + tuple(input_spec.shape[dim + 1 :])
+        )
+        for size in output_size_list
+    ]
+    output_spec_list = [
+        DTensorSpec(
+            mesh=input_spec.mesh,
+            placements=input_spec.placements,
+            shape=shape,
+            ndim=input_spec.ndim,
+        )
+        for shape in output_shape_list
+    ]
+    print(output_spec_list)
+    return OutputSharding(output_spec_list)
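
The chunk sizes that split_rule computes are meant to mirror torch.split: full chunks of split_size plus a smaller trailing chunk when the dimension is not evenly divisible. Below is a small standalone sketch (plain torch, no distributed setup, not part of the commit) that checks the size_split helper and the output_shape_list construction against an actual torch.split on a local tensor.

import torch

def size_split(N, i):
    # Same helper as in split_rule: full chunks of size i, plus a smaller
    # trailing chunk when N is not divisible by i.
    assert i > 0
    return [i] * (N // i) + ([N % i] if N % i != 0 else [])

t = torch.rand(10, 4)
dim = 0
chunks = t.split(3, dim=dim)

# Chunk sizes along dim match torch.split: [3, 3, 3, 1] for N=10, split_size=3.
assert [c.size(dim) for c in chunks] == size_split(t.shape[dim], 3)

# Per-output shapes, built the same way split_rule builds output_shape_list:
# only the size along dim changes; every other dimension is carried over.
shapes = [
    torch.Size(tuple(t.shape[:dim]) + (size,) + tuple(t.shape[dim + 1 :]))
    for size in size_split(t.shape[dim], 3)
]
assert shapes == [c.shape for c in chunks]

This is also why each output DTensorSpec reuses the input's mesh and placements unchanged: once the split dimension is guaranteed not to be sharded, only the global shape of each output differs.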

torch/distributed/_tensor/ops/tp_sharding_ops.py (0 additions, 1 deletion)

@@ -15,7 +15,6 @@
 """
 
 
-@register_impl("aten.split.Tensor")
 # pyre-fixme[2]: Parameter must be annotated.
 def dist_split(self: DTensor, split_size_or_sections, dim=0) -> List[DTensor]:
     local_mat = pytree.tree_map(unwrap_local_tensor, self)
