
Commit d05ec0e

wanchaol authored and pytorchmergebot committed
[dtensor] add split_with_sizes op (#93957)

Add the split_with_sizes op, sharing the implementation with the split op.

Pull Request resolved: #93957
Approved by: https://github.com/XilunWu
1 parent bfe5e12 · commit d05ec0e

File tree: 2 files changed, +15 −13 lines

test/distributed/_tensor/test_dtensor_ops.py (14 additions, 12 deletions)

@@ -451,8 +451,6 @@ def wrapped(fn):
     xfail("special.spherical_bessel_j0"),
     xfail("special.xlog1py"),
     xfail("special.zeta"),
-    xfail("split", "list_args"),
-    xfail("split_with_sizes"),
     xfail("squeeze", "multiple"),
     xfail("signal.windows.bartlett"),
     xfail("signal.windows.blackman"),
@@ -617,13 +615,21 @@ def assert_ref_dtensor_equal(self, dtensor_rs, rs):
     def run_dtensor_crossref(self, func, args, kwargs):
         to_dtensor = DTensorConverter(self.mesh, args, kwargs)
 
+        def concat_res_if_necessary(func, res: object) -> object:
+            # concat the result on corresponding dim for ops like
+            # split, so that we can call backward on a single tensor
+            if (
+                (resolve_name(func) is not None)
+                and ("split" in resolve_name(func))
+            ):
+                dim = args[2] if len(args) == 3 else 0
+                return torch.cat(res, dim=dim)
+            else:
+                return res
+
         # TODO: also handle cases where func raise an exception
         rs = func(*args, **kwargs)
-        if (
-            (resolve_name(func) is not None)
-            and ("split" in resolve_name(func))
-        ):
-            rs = torch.cat(rs)
+        rs = concat_res_if_necessary(func, rs)
 
         def to_replicate(e: object) -> object:
             return (
@@ -664,11 +670,7 @@ def to_replicate(e: object) -> object:
 
         # redistribute/all_gather the results to compare with normal output
         dtensor_rs = tree_map(to_replicate, dtensor_rs)
-        if (
-            (resolve_name(func) is not None)
-            and ("split" in resolve_name(func))
-        ):
-            dtensor_rs = torch.cat(dtensor_rs)
+        dtensor_rs = concat_res_if_necessary(func, dtensor_rs)
         try:
             if resolve_name(func) not in skip_bw:
                 if isinstance(dtensor_rs, DTensor):
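
As context for the concat_res_if_necessary helper above, a minimal sketch (plain torch, outside the test harness; all names illustrative) of why split results must be re-joined before backward: split-style ops return a tuple of tensors, and the cross-ref test needs a single tensor to call backward on.

import torch

x = torch.arange(8.0, requires_grad=True)
pieces = torch.split(x, [3, 5], dim=0)  # split returns a tuple of tensors
joined = torch.cat(pieces, dim=0)       # re-join on the split dim...
joined.sum().backward()                 # ...so backward runs on one tensor
print(x.grad)                           # tensor of ones, shape (8,)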

torch/distributed/_tensor/ops/tensor_ops.py (1 addition, 1 deletion)

@@ -600,7 +600,7 @@ def _update_schema_suggestion_for_cat(
     return output_sharding
 
 
-@register_prop_rule(aten.split.Tensor)
+@register_prop_rule([aten.split.Tensor, aten.split_with_sizes.default])
 def split_rule(op_schema: OpSchema) -> OutputSharding:
     output_spec_list: List[DTensorSpec] = []
     input_spec = cast(DTensorSpec, op_schema.args_schema[0])
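
The one-line change above reuses the existing split_rule for both overloads. As a hedged illustration of the two ATen ops it now covers (plain eager PyTorch, not DTensor): torch.split with an int chunk size dispatches to aten.split.Tensor, while a list of explicit sizes dispatches to aten.split_with_sizes.

import torch

x = torch.arange(10.0)
# int chunk size -> aten.split.Tensor: equal chunks, last may be smaller
a, b, c = torch.split(x, 4)          # sizes (4, 4, 2)
# explicit sizes -> aten.split_with_sizes: chunks may be unequal
d, e, f = torch.split(x, [2, 3, 5])  # sizes (2, 3, 5)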
