Commit bf29abd

propagate nan in some activations (#8033)
* propagate nan in some activations
* fix py2 not having math.nan
* flake8

1 parent: 8b447fa

6 files changed: +58 −18 lines


aten/src/TH/generic/THTensorMath.cpp

Lines changed: 2 additions & 2 deletions
@@ -2803,13 +2803,13 @@ void THTensor_(cmin)(THTensor *r, THTensor *t, THTensor *src) {
 void THTensor_(cmaxValue)(THTensor *r, THTensor *t, real value) {
   THTensor_(resizeAs)(r, t);
   TH_TENSOR_APPLY2(real, r, real, t,
-                   *r_data = *t_data > value ? *t_data : value;);
+                   *r_data = *t_data < value ? value : *t_data;); // this order propagates NaN
 }
 
 void THTensor_(cminValue)(THTensor *r, THTensor *t, real value) {
   THTensor_(resizeAs)(r, t);
   TH_TENSOR_APPLY2(real, r, real, t,
-                   *r_data = *t_data < value ? *t_data : value;);
+                   *r_data = *t_data > value ? value : *t_data;); // this order propagates NaN
 }
 
 void THTensor_(zeros)(THTensor *r_, THLongStorage *size)
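
The change hinges on IEEE 754 semantics: any ordered comparison involving NaN evaluates to false, so the operand placed in the false branch of the ternary decides whether a NaN input survives. A minimal Python sketch of the before/after behavior (function names are mine, for illustration only):

    import math

    nan = float('nan')

    def cmax_old(x, value):
        # old order: `nan > value` is False, so NaN falls through to `value`
        return x if x > value else value

    def cmax_new(x, value):
        # new order: `nan < value` is False, so NaN falls through to `x`
        return value if x < value else x

    assert cmax_old(nan, 0.0) == 0.0       # NaN silently replaced
    assert math.isnan(cmax_new(nan, 0.0))  # NaN propagated
    assert cmax_old(2.0, 0.0) == cmax_new(2.0, 0.0) == 2.0  # ordinary inputs agree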

aten/src/THC/THCTensorMathPointwise.cuh

Lines changed: 4 additions & 4 deletions
@@ -655,11 +655,11 @@ struct TensorMaxValueOp {
   TensorMaxValueOp(T v) : val(v) {}
 
   __device__ __forceinline__ void operator()(T* out) {
-    *out = THCNumerics<T>::gt(*out, val) ? *out : val;
+    *out = THCNumerics<T>::lt(*out, val) ? val : *out; // this order propagates NaN
   }
 
   __device__ __forceinline__ void operator()(T* out, T* in) {
-    *out = THCNumerics<T>::gt(*in, val) ? *in : val;
+    *out = THCNumerics<T>::lt(*in, val) ? val : *in; // this order propagates NaN
   }
 
   T val;
@@ -670,11 +670,11 @@ struct TensorMinValueOp {
   TensorMinValueOp(T v) : val(v) {}
 
   __device__ __forceinline__ void operator()(T* out) {
-    *out = THCNumerics<T>::lt(*out, val) ? *out : val;
+    *out = THCNumerics<T>::gt(*out, val) ? val : *out; // this order propagates NaN
   }
 
   __device__ __forceinline__ void operator()(T* out, T* in) {
-    *out = THCNumerics<T>::lt(*in, val) ? *in : val;
+    *out = THCNumerics<T>::gt(*in, val) ? val : *in; // this order propagates NaN
   }
 
   T val;
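
The CUDA functors get the mirror-image rewrite via THCNumerics, keeping CPU and GPU behavior in sync. Assuming clamp_min/clamp_max route through cmaxValue/cminValue on CPU and these functors on CUDA (a routing I am inferring, not something the diff states), the effect should be observable from Python on either device:

    import math
    import torch

    devices = ['cpu'] + (['cuda'] if torch.cuda.is_available() else [])
    for device in devices:
        x = torch.tensor([float('nan')], device=device)
        # assumes clamp dispatches to the kernels patched above
        assert math.isnan(x.clamp(min=0.0).item())
        assert math.isnan(x.clamp(max=0.0).item())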

aten/src/THCUNN/HardTanh.cu

Lines changed: 3 additions & 3 deletions
@@ -18,10 +18,10 @@ struct hardtanhupdateOutput_functor
   {
     if (*input < min_val_)
       *output = min_val_;
-    else if (*input <= max_val_)
-      *output = *input;
-    else
+    else if (*input > max_val_)
       *output = max_val_;
+    else
+      *output = *input;
   }
 
   __device__ void operator()(T *input) const
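
For HardTanh the fix is branch reordering rather than a ternary swap: previously a NaN input failed both comparisons and fell into the max_val_ branch; now it still fails both comparisons but lands on the final else, which returns the input unchanged. A Python sketch of the two control flows:

    import math

    def hardtanh_old(x, min_val=-1.0, max_val=1.0):
        if x < min_val:
            return min_val
        elif x <= max_val:
            return x
        else:
            return max_val  # NaN ends up here: both comparisons are False

    def hardtanh_new(x, min_val=-1.0, max_val=1.0):
        if x < min_val:
            return min_val
        elif x > max_val:
            return max_val
        else:
            return x        # NaN ends up here and passes through

    nan = float('nan')
    assert hardtanh_old(nan) == 1.0
    assert math.isnan(hardtanh_new(nan))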

aten/src/THNN/generic/HardShrink.c

Lines changed: 6 additions & 6 deletions
@@ -14,10 +14,10 @@ void THNN_(HardShrink_updateOutput)(
   TH_TENSOR_APPLY2(real, output, real, input,
     if (*input_data > lambda)
      *output_data = *input_data;
-    else if (*input_data < -lambda)
-      *output_data = *input_data;
-    else
+    else if (*input_data >= -lambda)
       *output_data = 0;
+    else
+      *output_data = *input_data; // let NaN case pass through here
   );
 }
 
@@ -32,10 +32,10 @@ void THNN_(HardShrink_updateGradInput)(
   THNN_CHECK_NELEMENT(input, gradOutput);
   THTensor_(resizeAs)(gradInput, input);
   TH_TENSOR_APPLY3(real, gradInput, real, gradOutput, real, input,
-    if (*input_data > lambda || *input_data < -lambda)
-      *gradInput_data = *gradOutput_data;
-    else
+    if (*input_data >= -lambda && *input_data <= lambda)
       *gradInput_data = 0;
+    else
+      *gradInput_data = *gradOutput_data; // let NaN case pass through here
   );
 }
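
HardShrink inverts its predicates so the zeroing branch tests membership in [-lambda, lambda] and the catch-all else keeps the input; a NaN fails the membership test and passes through. The updateGradInput hunk applies the same inversion. Sketching the forward pass in Python:

    import math

    def hardshrink_old(x, lambd=0.5):
        if x > lambd:
            return x
        elif x < -lambd:
            return x
        else:
            return 0.0  # NaN fails both tests and gets zeroed

    def hardshrink_new(x, lambd=0.5):
        if x > lambd:
            return x
        elif x >= -lambd:
            return 0.0
        else:
            return x    # NaN fails both tests and passes through

    nan = float('nan')
    assert hardshrink_old(nan) == 0.0
    assert math.isnan(hardshrink_new(nan))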

aten/src/THNN/generic/HardTanh.c

Lines changed: 3 additions & 3 deletions
@@ -33,10 +33,10 @@ void THNN_(HardTanh_updateOutput)(
     TH_TENSOR_APPLY2(real, output, real, input,
       if (*input_data < min_val)
         *output_data = min_val;
-      else if (*input_data <= max_val)
-        *output_data = *input_data;
-      else
+      else if (*input_data > max_val)
         *output_data = max_val;
+      else
+        *output_data = *input_data;
     );
   }
 }
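
This is the CPU twin of the HardTanh.cu change above, with identical branch logic. From Python, the observable effect matches what the new test asserts:

    import math
    import torch
    import torch.nn.functional as F

    x = torch.tensor([float('nan')])
    assert math.isnan(F.hardtanh(x).item())    # final else now passes NaN through
    assert math.isnan(F.hardshrink(x).item())  # same for the HardShrink change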

test/test_nn.py

Lines changed: 40 additions & 0 deletions
@@ -1338,6 +1338,46 @@ def test_vector_to_parameters(self):
         sample = next(model.parameters())[0, 0, 0]
         self.assertTrue(torch.equal(sample.data, vec.data[:5]))
 
+    # We don't want to make propagating NaN a hard requirement on ops, but for
+    # these easy ones, we should make them do so.
+    def _test_nonlinearity_propagate_nan(self, device):
+        nan = float('nan')
+
+        def test(nonlinearity, *args, **kwargs):
+            x = torch.tensor([nan], device=device)
+            fn = getattr(F, nonlinearity)
+            try:
+                self.assertTrue(math.isnan(fn(x, *args, **kwargs).item()))
+            except Exception as e:
+                if 'not implemented' not in str(e):
+                    raise
+
+        test('relu')
+        test('relu', inplace=True)
+        test('relu6')
+        test('elu')
+        test('selu')
+        test('rrelu')
+        test('rrelu', inplace=True)
+        test('hardtanh')
+        test('tanh')
+        test('sigmoid')
+        test('logsigmoid')
+        test('hardshrink')
+        test('tanhshrink')
+        test('softsign')
+        test('softmin', 0)
+        test('softmax', 0)
+        test('log_softmax', 0)
+        test('leaky_relu', 0.2)
+
+    def test_nonlinearity_propagate_nan(self):
+        self._test_nonlinearity_propagate_nan('cpu')
+
+    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
+    def test_nonlinearity_propagate_nan_cuda(self):
+        self._test_nonlinearity_propagate_nan('cuda')
+
     def test_weight_norm(self):
         input = torch.randn(3, 5)
         m = nn.Linear(5, 7)
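
Note the helper is deliberately lenient: a backend that raises 'not implemented' for some nonlinearity is skipped rather than failed, and the bare numbers passed to softmin/softmax/log_softmax and leaky_relu are extra positional arguments (a dim, a negative slope). The same check is easy to reproduce by hand outside the test harness:

    import math
    import torch
    import torch.nn.functional as F

    x = torch.tensor([float('nan')])
    for fn, args in [(F.relu, ()), (F.softmax, (0,)), (F.leaky_relu, (0.2,))]:
        assert math.isnan(fn(x, *args).item()), fn.__name__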
