Skip to content

Commit 254dedf

Browse files
ssnl authored and facebook-github-bot committed
Propagate NaN through threshold (#10277)
Summary: Fixes #10238. Pull Request resolved: #10277. Reviewed By: SsnL. Differential Revision: D9199825. Pulled By: soumith. fbshipit-source-id: 8ee7f9a72d9546d429f311c3f6028461d3c93fe2
1 parent 0bbcc7b commit 254dedf

File tree

3 files changed

+10
-8
lines changed

3 files changed

+10
-8
lines changed

aten/src/THCUNN/Threshold.cu

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ struct ThresholdUpdateOutput
1717
__device__ __forceinline__ void operator()(T *out, T *in)
1818
{
1919
T x = *in;
20-
*out = (x > threshold_) ? x : val_;
20+
*out = (x <= threshold_) ? val_ : x; // this order propagates NaN
2121
}
2222
};
2323

@@ -35,7 +35,7 @@ struct ThresholdUpdateOutputIP
3535

3636
__device__ __forceinline__ void operator()(T *x)
3737
{
38-
*x = (*x > threshold_) ? *x : val_;
38+
*x = (*x <= threshold_) ? val_ : *x; // this order propagates NaN
3939
}
4040
};
4141

@@ -51,7 +51,7 @@ struct ThresholdUpdateGradInput
5151
__device__ __forceinline__ void operator()(
5252
T *gradInput, T *input, T *gradOutput) const
5353
{
54-
*gradInput = (*input > threshold_) ? *gradOutput : ScalarConvert<int, T>::to(0);
54+
*gradInput = (*input <= threshold_) ? ScalarConvert<int, T>::to(0) : *gradOutput; // this order propagates NaN
5555
}
5656
};
5757

@@ -67,7 +67,7 @@ struct ThresholdUpdateGradInputIP
6767
__device__ __forceinline__ void operator()(
6868
T *gradOutput, T *input) const
6969
{
70-
*gradOutput = (*input > threshold_) ? *gradOutput : ScalarConvert<int, T>::to(0);
70+
*gradOutput = (*input <= threshold_) ? ScalarConvert<int, T>::to(0) : *gradOutput; // this order propagates NaN
7171
}
7272
};
7373

aten/src/THNN/generic/Threshold.c

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ void THNN_(Threshold_updateOutput)(
2424
{
2525
THTensor_(resizeAs)(output, input);
2626
TH_TENSOR_APPLY2(real, output, real, input,
27-
*output_data = (*input_data > threshold) ? *input_data : val;
27+
*output_data = (*input_data <= threshold) ? val : *input_data; // this order propagates NaN
2828
);
2929
}
3030
}
@@ -52,10 +52,10 @@ void THNN_(Threshold_updateGradInput)(
5252
{
5353
THTensor_(resizeAs)(gradInput, input);
5454
TH_TENSOR_APPLY3(real, gradInput, real, gradOutput, real, input,
55-
if ((*input_data) > threshold)
56-
*gradInput_data = *gradOutput_data;
57-
else
55+
if ((*input_data) <= threshold)
5856
*gradInput_data = 0;
57+
else
58+
*gradInput_data = *gradOutput_data; // let NaN case pass through here
5959
);
6060
}
6161
}

test/test_nn.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1614,6 +1614,8 @@ def test(nonlinearity, *args, **kwargs):
16141614
test('softmax', 0)
16151615
test('log_softmax', 0)
16161616
test('leaky_relu', 0.2)
1617+
test('threshold', 3, 2)
1618+
test('threshold', 3, 2, inplace=True)
16171619

16181620
def test_nonlinearity_propagate_nan(self):
16191621
self._test_nonlinearity_propagate_nan('cpu')

0 commit comments

Comments
 (0)