@@ -1503,46 +1503,24 @@ def test_all_gather_coalesced_with_empty(self):
         self._barrier()
 
     # AllToAll
-    def _test_all_to_all_single_equal_split_helper(
-        self,
-        group,
-        group_id,
-        rank,
-        cuda=False,
-        rank_to_GPU=None,
-    ):
+    def _test_all_to_all_single_equal_split_helper(self, group, group_id, rank):
         if group_id is not None:
             size = len(group)
             in_tensor = torch.ones([size, size]) * rank
             expected_tensor = torch.cat([torch.ones([1, size]) * i for i in group])
             out_tensor = torch.ones([size, size]) * -1
-            if cuda:
-                in_tensor = in_tensor.cuda(rank_to_GPU[rank][0])
-                expected_tensor = expected_tensor.cuda(rank_to_GPU[rank][0])
-                out_tensor = out_tensor.cuda(rank_to_GPU[rank][0])
             dist.all_to_all_single(out_tensor, in_tensor, group=group_id)
             self.assertEqual(out_tensor, expected_tensor)
         self._barrier()
 
-    def _test_all_to_all_single_unequal_split_helper(
-        self,
-        group,
-        group_id,
-        rank,
-        cuda=False,
-        rank_to_GPU=None,
-    ):
+    def _test_all_to_all_single_unequal_split_helper(self, group, group_id, rank):
         if group_id is not None:
             size = len(group)
             in_splits = [i + 1 for i in group]
             out_splits = [rank + 1 for _ in group]
             in_tensor = torch.ones([sum(in_splits), size]) * rank
             out_tensor = torch.ones([(rank + 1) * size, size])
             expected_tensor = torch.cat([torch.ones([rank + 1, size]) * i for i in group])
-            if cuda:
-                in_tensor = in_tensor.cuda(rank_to_GPU[rank][0])
-                expected_tensor = expected_tensor.cuda(rank_to_GPU[rank][0])
-                out_tensor = out_tensor.cuda(rank_to_GPU[rank][0])
             dist.all_to_all_single(
                 out_tensor, in_tensor, out_splits, in_splits, group=group_id)
             self.assertEqual(out_tensor, expected_tensor)
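Note on what the equal-split helper above exercises: every rank sends one equal-sized chunk of its input to each peer and receives one chunk back. The standalone sketch below (not part of this patch) shows the same pattern, assuming an MPI-backed process group launched via mpirun, since per the skip decorators in the next hunk only the "mpi" backend supports CPU all_to_all_single in this suite; demo_equal_split is an illustrative name.

import torch
import torch.distributed as dist

def demo_equal_split():
    # Assumes a launch like `mpirun -n 4 python demo.py`.
    dist.init_process_group("mpi")
    rank = dist.get_rank()
    size = dist.get_world_size()
    # Each rank contributes `size` rows stamped with its own rank id;
    # all_to_all_single scatters one row to each peer and gathers one back.
    in_tensor = torch.ones([size, size]) * rank
    out_tensor = torch.ones([size, size]) * -1
    dist.all_to_all_single(out_tensor, in_tensor)
    # Row i of the output is the row that rank i sent.
    expected = torch.cat([torch.ones([1, size]) * i for i in range(size)])
    assert torch.equal(out_tensor, expected)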
@@ -1562,159 +1540,49 @@ def _test_all_to_all_helper(self, group, group_id, rank):
                 self.assertEqual(t1, t2)
         self._barrier()
 
-    @unittest.skipIf(
-        BACKEND != "mpi", "Only MPI supports CPU all_to_all_single"
-    )
+    @unittest.skipIf(BACKEND != "mpi", "Only MPI supports all_to_all_single")
     def test_all_to_all_single_equal_split(self):
         group, group_id, rank = self._init_global_test()
         self._test_all_to_all_single_equal_split_helper(group, group_id, rank)
 
-    @unittest.skipIf(
-        BACKEND != "nccl", "Only Nccl supports CUDA all_to_all_single"
-    )
-    @skip_if_no_gpu
-    @skip_if_rocm
-    def test_all_to_all_single_equal_split_cuda(self):
-        group, group_id, rank = self._init_global_test()
-        rank_to_GPU = self._init_multigpu_helper()
-        self._test_all_to_all_single_equal_split_helper(
-            group,
-            group_id,
-            rank,
-            True,
-            rank_to_GPU,
-        )
-
-    @unittest.skipIf(
-        BACKEND != "mpi", "Only MPI supports CPU all_to_all_single"
-    )
+    @unittest.skipIf(BACKEND != "mpi", "Only MPI supports all_to_all_single")
     def test_all_to_all_single_unequal_split(self):
         group, group_id, rank = self._init_global_test()
         self._test_all_to_all_single_unequal_split_helper(group, group_id, rank)
 
-    @unittest.skipIf(
-        BACKEND != "nccl", "Only Nccl supports CUDA all_to_all_single"
-    )
-    @skip_if_no_gpu
-    @skip_if_rocm
-    def test_all_to_all_single_unequal_split_cuda(self):
-        group, group_id, rank = self._init_global_test()
-        rank_to_GPU = self._init_multigpu_helper()
-        self._test_all_to_all_single_unequal_split_helper(
-            group,
-            group_id,
-            rank,
-            True,
-            rank_to_GPU,
-        )
-
     @unittest.skipIf(BACKEND != "mpi", "Only MPI supports all_to_all")
     def test_all_to_all(self):
         group, group_id, rank = self._init_global_test()
         self._test_all_to_all_helper(group, group_id, rank)
 
-    @unittest.skipIf(
-        BACKEND != "mpi", "Only MPI supports CPU all_to_all_single"
-    )
+    @unittest.skipIf(BACKEND != "mpi", "Only MPI supports all_to_all_single")
     @skip_if_small_worldsize
     def test_all_to_all_single_equal_split_group(self):
         group, group_id, rank = self._init_group_test()
         self._test_all_to_all_single_equal_split_helper(group, group_id, rank)
 
-    @unittest.skipIf(
-        BACKEND != "nccl", "Only Nccl supports CUDA all_to_all_single"
-    )
-    @skip_if_no_gpu
-    @skip_if_rocm
-    @skip_if_small_worldsize
-    def test_all_to_all_single_equal_split_group_cuda(self):
-        group, group_id, rank = self._init_group_test()
-        rank_to_GPU = self._init_multigpu_helper()
-        self._test_all_to_all_single_equal_split_helper(
-            group,
-            group_id,
-            rank,
-            True,
-            rank_to_GPU,
-        )
-
-    @unittest.skipIf(
-        BACKEND != "mpi", "Only MPI supports CPU all_to_all_single"
-    )
+    @unittest.skipIf(BACKEND != "mpi", "Only MPI supports all_to_all_single")
     @skip_if_small_worldsize
     def test_all_to_all_single_unequal_split_group(self):
         group, group_id, rank = self._init_group_test()
         self._test_all_to_all_single_unequal_split_helper(group, group_id, rank)
 
-    @unittest.skipIf(
-        BACKEND != "nccl", "Only Nccl supports CUDA all_to_all_single"
-    )
-    @skip_if_no_gpu
-    @skip_if_rocm
-    @skip_if_small_worldsize
-    def test_all_to_all_single_unequal_split_group_cuda(self):
-        group, group_id, rank = self._init_global_test()
-        rank_to_GPU = self._init_multigpu_helper()
-        self._test_all_to_all_single_unequal_split_helper(
-            group,
-            group_id,
-            rank,
-            True,
-            rank_to_GPU,
-        )
-
     @unittest.skipIf(BACKEND != "mpi", "Only MPI supports all_to_all")
     @skip_if_small_worldsize
     def test_all_to_all_group(self):
         group, group_id, rank = self._init_group_test()
         self._test_all_to_all_helper(group, group_id, rank)
 
-    @unittest.skipIf(
-        BACKEND != "mpi", "Only MPI supports CPU all_to_all_single"
-    )
+    @unittest.skipIf(BACKEND != "mpi", "Only MPI supports all_to_all_single")
     def test_all_to_all_single_equal_split_full_group(self):
         group, group_id, rank = self._init_full_group_test()
         self._test_all_to_all_single_equal_split_helper(group, group_id, rank)
 
-    @unittest.skipIf(
-        BACKEND != "nccl", "Only Nccl supports CUDA all_to_all_single"
-    )
-    @skip_if_no_gpu
-    @skip_if_rocm
-    def test_all_to_all_single_equal_split_full_group_cuda(self):
-        group, group_id, rank = self._init_full_group_test()
-        rank_to_GPU = self._init_multigpu_helper()
-        self._test_all_to_all_single_equal_split_helper(
-            group,
-            group_id,
-            rank,
-            True,
-            rank_to_GPU,
-        )
-
-    @unittest.skipIf(
-        BACKEND != "mpi", "Only MPI supports CPU all_to_all_single"
-    )
+    @unittest.skipIf(BACKEND != "mpi", "Only MPI supports all_to_all_single")
    def test_all_to_all_single_unequal_split_full_group(self):
         group, group_id, rank = self._init_full_group_test()
         self._test_all_to_all_single_unequal_split_helper(group, group_id, rank)
 
-    @unittest.skipIf(
-        BACKEND != "nccl", "Only Nccl supports CUDA all_to_all_single"
-    )
-    @skip_if_no_gpu
-    @skip_if_rocm
-    def test_all_to_all_single_unequal_split_full_group_cuda(self):
-        group, group_id, rank = self._init_full_group_test()
-        rank_to_GPU = self._init_multigpu_helper()
-        self._test_all_to_all_single_unequal_split_helper(
-            group,
-            group_id,
-            rank,
-            True,
-            rank_to_GPU,
-        )
-
     @unittest.skipIf(BACKEND != "mpi", "Only MPI supports all_to_all")
     def test_all_to_all_full_group(self):
         group, group_id, rank = self._init_full_group_test()
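For reference, the unequal-split helper kept by this patch follows the same pattern as the equal-split sketch above, but with explicit per-peer split sizes. Below is a minimal standalone sketch under the same MPI assumption; demo_unequal_split is an illustrative name, not part of the suite.

import torch
import torch.distributed as dist

def demo_unequal_split():
    dist.init_process_group("mpi")
    rank = dist.get_rank()
    size = dist.get_world_size()
    in_splits = [i + 1 for i in range(size)]   # send i + 1 rows to rank i
    out_splits = [rank + 1] * size             # receive rank + 1 rows from every peer
    in_tensor = torch.ones([sum(in_splits), size]) * rank
    out_tensor = torch.ones([(rank + 1) * size, size])
    dist.all_to_all_single(out_tensor, in_tensor, out_splits, in_splits)
    # Output block i (rank + 1 rows) carries the rank id of sender i.
    expected = torch.cat([torch.ones([rank + 1, size]) * i for i in range(size)])
    assert torch.equal(out_tensor, expected)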