@@ -1609,6 +1609,60 @@ def forward(self, input):
16091609 self .assertEqual (out .get_device (), 0 )
16101610 self .assertEqual (out .data , expected_out )
16111611
1612+ @unittest .skipIf (not TEST_CUDA , "CUDA unavailable" )
1613+ def test_data_parallel_module_kwargs_only_empty_list (self ):
1614+ class Net (nn .Module ):
1615+ def __init__ (self ):
1616+ super (Net , self ).__init__ ()
1617+ self .l = l
1618+
1619+ def forward (self , input ):
1620+ return self .l (input ['data' ])
1621+
1622+ l = nn .Linear (10 , 5 ).float ().cuda ()
1623+ i = Variable (torch .randn (20 , 10 ).float ().cuda ())
1624+ expected_out = l (i ).data
1625+ n = nn .DataParallel (Net ())
1626+ out = n (input = {'data' : i , 'unused' : []})
1627+ self .assertEqual (out .get_device (), 0 )
1628+ self .assertEqual (out .data , expected_out )
1629+
1630+ @unittest .skipIf (not TEST_CUDA , "CUDA unavailable" )
1631+ def test_data_parallel_module_kwargs_only_empty_dict (self ):
1632+ class Net (nn .Module ):
1633+ def __init__ (self ):
1634+ super (Net , self ).__init__ ()
1635+ self .l = l
1636+
1637+ def forward (self , input ):
1638+ return self .l (input ['data' ])
1639+
1640+ l = nn .Linear (10 , 5 ).float ().cuda ()
1641+ i = Variable (torch .randn (20 , 10 ).float ().cuda ())
1642+ expected_out = l (i ).data
1643+ n = nn .DataParallel (Net ())
1644+ out = n (input = {'data' : i , 'unused' : {}})
1645+ self .assertEqual (out .get_device (), 0 )
1646+ self .assertEqual (out .data , expected_out )
1647+
1648+ @unittest .skipIf (not TEST_CUDA , "CUDA unavailable" )
1649+ def test_data_parallel_module_kwargs_only_empty_tuple (self ):
1650+ class Net (nn .Module ):
1651+ def __init__ (self ):
1652+ super (Net , self ).__init__ ()
1653+ self .l = l
1654+
1655+ def forward (self , input ):
1656+ return self .l (input ['data' ])
1657+
1658+ l = nn .Linear (10 , 5 ).float ().cuda ()
1659+ i = Variable (torch .randn (20 , 10 ).float ().cuda ())
1660+ expected_out = l (i ).data
1661+ n = nn .DataParallel (Net ())
1662+ out = n (input = {'data' : i , 'unused' : ()})
1663+ self .assertEqual (out .get_device (), 0 )
1664+ self .assertEqual (out .data , expected_out )
1665+
16121666 def test_state_dict (self ):
16131667 l = nn .Linear (5 , 5 )
16141668 block = nn .Module ()
0 commit comments