@@ -451,8 +451,6 @@ def wrapped(fn):
451451 xfail ("special.spherical_bessel_j0" ),
452452 xfail ("special.xlog1py" ),
453453 xfail ("special.zeta" ),
454- xfail ("split" , "list_args" ),
455- xfail ("split_with_sizes" ),
456454 xfail ("squeeze" , "multiple" ),
457455 xfail ("signal.windows.bartlett" ),
458456 xfail ("signal.windows.blackman" ),
@@ -617,13 +615,21 @@ def assert_ref_dtensor_equal(self, dtensor_rs, rs):
617615 def run_dtensor_crossref (self , func , args , kwargs ):
618616 to_dtensor = DTensorConverter (self .mesh , args , kwargs )
619617
def concat_res_if_necessary(func, res: object) -> object:
    """Re-join multi-tensor outputs of split-like ops into one tensor.

    Ops such as ``split`` / ``split_with_sizes`` return a tuple of
    tensors; concatenating them back along the split dimension lets the
    caller run ``backward()`` on a single tensor.

    Args:
        func: the op under test (resolved via ``resolve_name``).
        res: the op's output; returned unchanged unless ``func`` is a
            split-like op.

    Returns:
        ``torch.cat(res, dim=...)`` for split-like ops, else ``res``.
    """
    # Resolve once instead of twice per call.
    name = resolve_name(func)
    if name is not None and "split" in name:
        # The split dim may arrive positionally (args[2]) or as the
        # ``dim=`` kwarg; fall back to 0 (torch.split's default).
        # NOTE: args/kwargs are closed over from run_dtensor_crossref.
        dim = args[2] if len(args) == 3 else kwargs.get("dim", 0)
        return torch.cat(res, dim=dim)
    return res
620630 # TODO: also handle cases where func raise an exception
621631 rs = func (* args , ** kwargs )
622- if (
623- (resolve_name (func ) is not None )
624- and ("split" in resolve_name (func ))
625- ):
626- rs = torch .cat (rs )
632+ rs = concat_res_if_necessary (func , rs )
627633
628634 def to_replicate (e : object ) -> object :
629635 return (
@@ -664,11 +670,7 @@ def to_replicate(e: object) -> object:
664670
665671 # redistribute/all_gather the results to compare with normal output
666672 dtensor_rs = tree_map (to_replicate , dtensor_rs )
667- if (
668- (resolve_name (func ) is not None )
669- and ("split" in resolve_name (func ))
670- ):
671- dtensor_rs = torch .cat (dtensor_rs )
673+ dtensor_rs = concat_res_if_necessary (func , dtensor_rs )
672674 try :
673675 if resolve_name (func ) not in skip_bw :
674676 if isinstance (dtensor_rs , DTensor ):
0 commit comments