@@ -1338,6 +1338,46 @@ def test_vector_to_parameters(self):
         sample = next(model.parameters())[0, 0, 0]
         self.assertTrue(torch.equal(sample.data, vec.data[:5]))
 
+    # We don't want to make propagating NaN a hard requirement on ops, but for
+    # these easy ones, we should make them do so.
+    def _test_nonlinearity_propagate_nan(self, device):
+        nan = float('nan')
+
+        def test(nonlinearity, *args, **kwargs):
+            x = torch.tensor([nan], device=device)
+            fn = getattr(F, nonlinearity)
+            try:
+                self.assertTrue(math.isnan(fn(x, *args, **kwargs).item()))
+            except Exception as e:
+                if 'not implemented' not in str(e):
+                    raise
+
+        test('relu')
+        test('relu', inplace=True)
+        test('relu6')
+        test('elu')
+        test('selu')
+        test('rrelu')
+        test('rrelu', inplace=True)
+        test('hardtanh')
+        test('tanh')
+        test('sigmoid')
+        test('logsigmoid')
+        test('hardshrink')
+        test('tanhshrink')
+        test('softsign')
+        test('softmin', 0)
+        test('softmax', 0)
+        test('log_softmax', 0)
+        test('leaky_relu', 0.2)
+
+    def test_nonlinearity_propagate_nan(self):
+        self._test_nonlinearity_propagate_nan('cpu')
+
+    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
+    def test_nonlinearity_propagate_nan_cuda(self):
+        self._test_nonlinearity_propagate_nan('cuda')
+
     def test_weight_norm(self):
         input = torch.randn(3, 5)
         m = nn.Linear(5, 7)
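For reference, a minimal standalone sketch (not part of the diff) of the behavior the new test exercises: feeding a NaN input through one of these activations should yield NaN back. It assumes a PyTorch install where these ops are available under `torch.nn.functional`.

```python
# Illustrative only: check that a few activations propagate NaN,
# using the same torch.nn.functional entry points the test looks up.
import math

import torch
import torch.nn.functional as F

x = torch.tensor([float('nan')])
for name, args in [('relu', ()), ('leaky_relu', (0.2,)), ('softmax', (0,))]:
    out = getattr(F, name)(x, *args)
    assert math.isnan(out.item()), '%s did not propagate NaN' % name
```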