4 changes: 2 additions & 2 deletions aten/src/TH/generic/THTensorMath.cpp
@@ -2803,13 +2803,13 @@ void THTensor_(cmin)(THTensor *r, THTensor *t, THTensor *src) {
void THTensor_(cmaxValue)(THTensor *r, THTensor *t, real value) {
THTensor_(resizeAs)(r, t);
TH_TENSOR_APPLY2(real, r, real, t,
- *r_data = *t_data > value ? *t_data : value;);
+ *r_data = *t_data < value ? value : *t_data;); // this order propagates NaN
}

void THTensor_(cminValue)(THTensor *r, THTensor *t, real value) {
THTensor_(resizeAs)(r, t);
TH_TENSOR_APPLY2(real, r, real, t,
- *r_data = *t_data < value ? *t_data : value;);
+ *r_data = *t_data > value ? value : *t_data;); // this order propagates NaN
}

void THTensor_(zeros)(THTensor *r_, THLongStorage *size)
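For reference, the reason the reordered ternary propagates NaN (in this CPU path and in the CUDA functors below) is that every comparison involving NaN evaluates to false, so the element being clamped must sit in the branch taken when the comparison fails. A minimal standalone sketch of the two orderings, using plain floats rather than the TH macros:

#include <cmath>
#include <cstdio>

// Old ordering: the comparison with NaN is false, so the clamp value wins and NaN is lost.
static float cmax_old(float t, float value) { return t > value ? t : value; }

// New ordering: the comparison with NaN is still false, but now the (NaN) element wins.
static float cmax_new(float t, float value) { return t < value ? value : t; }

int main() {
  const float nan = std::nanf("");
  std::printf("old: %f\n", cmax_old(nan, 0.0f));  // prints 0.000000
  std::printf("new: %f\n", cmax_new(nan, 0.0f));  // prints nan
  return 0;
}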
8 changes: 4 additions & 4 deletions aten/src/THC/THCTensorMathPointwise.cuh
@@ -655,11 +655,11 @@ struct TensorMaxValueOp {
TensorMaxValueOp(T v) : val(v) {}

__device__ __forceinline__ void operator()(T* out) {
- *out = THCNumerics<T>::gt(*out, val) ? *out : val;
+ *out = THCNumerics<T>::lt(*out, val) ? val : *out; // this order propagates NaN
}

__device__ __forceinline__ void operator()(T* out, T* in) {
- *out = THCNumerics<T>::gt(*in, val) ? *in : val;
+ *out = THCNumerics<T>::lt(*in, val) ? val : *in; // this order propagates NaN
}

T val;
@@ -670,11 +670,11 @@ struct TensorMinValueOp {
TensorMinValueOp(T v) : val(v) {}

__device__ __forceinline__ void operator()(T* out) {
- *out = THCNumerics<T>::lt(*out, val) ? *out : val;
+ *out = THCNumerics<T>::gt(*out, val) ? val : *out; // this order propagates NaN
}

__device__ __forceinline__ void operator()(T* out, T* in) {
- *out = THCNumerics<T>::lt(*in, val) ? *in : val;
+ *out = THCNumerics<T>::gt(*in, val) ? val : *in; // this order propagates NaN
}

T val;
6 changes: 3 additions & 3 deletions aten/src/THCUNN/HardTanh.cu
@@ -18,10 +18,10 @@ struct hardtanhupdateOutput_functor
{
if (*input < min_val_)
*output = min_val_;
- else if (*input <= max_val_)
- *output = *input;
- else
+ else if (*input > max_val_)
*output = max_val_;
+ else
+ *output = *input;
}

__device__ void operator()(T *input) const
12 changes: 6 additions & 6 deletions aten/src/THNN/generic/HardShrink.c
@@ -14,10 +14,10 @@ void THNN_(HardShrink_updateOutput)(
TH_TENSOR_APPLY2(real, output, real, input,
if (*input_data > lambda)
*output_data = *input_data;
- else if (*input_data < -lambda)
- *output_data = *input_data;
- else
+ else if (*input_data >= -lambda)
*output_data = 0;
+ else
+ *output_data = *input_data; // let NaN case pass through here
);
}

@@ -32,10 +32,10 @@ void THNN_(HardShrink_updateGradInput)(
THNN_CHECK_NELEMENT(input, gradOutput);
THTensor_(resizeAs)(gradInput, input);
TH_TENSOR_APPLY3(real, gradInput, real, gradOutput, real, input,
- if (*input_data > lambda || *input_data < -lambda)
- *gradInput_data = *gradOutput_data;
- else
+ if (*input_data >= -lambda && *input_data <= lambda)
*gradInput_data = 0;
+ else
+ *gradInput_data = *gradOutput_data; // let NaN case pass through here
);
}

6 changes: 3 additions & 3 deletions aten/src/THNN/generic/HardTanh.c
@@ -33,10 +33,10 @@ void THNN_(HardTanh_updateOutput)(
TH_TENSOR_APPLY2(real, output, real, input,
if (*input_data < min_val)
*output_data = min_val;
- else if (*input_data <= max_val)
- *output_data = *input_data;
- else
+ else if (*input_data > max_val)
*output_data = max_val;
+ else
+ *output_data = *input_data;
);
}
}
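The activation kernels above get the same guarantee through branch reordering rather than a ternary: because every comparison with NaN is false, moving the pass-through assignment into the final else means a NaN input falls through unclamped (and, for HardShrink, is not zeroed). A small sketch of the HardTanh-style logic, assuming plain floats instead of the TH_TENSOR_APPLY/THCNumerics machinery:

#include <cmath>
#include <cstdio>

// Reordered hardtanh: both explicit comparisons fail for NaN,
// so a NaN input reaches the final else and is returned unchanged.
static float hardtanh_nan(float x, float min_val, float max_val) {
  if (x < min_val)
    return min_val;
  else if (x > max_val)
    return max_val;
  else
    return x;  // NaN lands here
}

int main() {
  std::printf("%f\n", hardtanh_nan(std::nanf(""), -1.0f, 1.0f));  // prints nan
  std::printf("%f\n", hardtanh_nan(2.0f, -1.0f, 1.0f));           // prints 1.000000
  return 0;
}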
40 changes: 40 additions & 0 deletions test/test_nn.py
@@ -1338,6 +1338,46 @@ def test_vector_to_parameters(self):
        sample = next(model.parameters())[0, 0, 0]
        self.assertTrue(torch.equal(sample.data, vec.data[:5]))

    # We don't want to make propagating NaN a hard requirement on ops, but for
    # these easy ones, we should make them do so.
    def _test_nonlinearity_propagate_nan(self, device):
        nan = float('nan')

        def test(nonlinearity, *args, **kwargs):
            x = torch.tensor([nan], device=device)
            fn = getattr(F, nonlinearity)
            try:
                self.assertTrue(math.isnan(fn(x, *args, **kwargs).item()))
            except Exception as e:
                if 'not implemented' not in str(e):
                    raise

        test('relu')
        test('relu', inplace=True)
        test('relu6')
        test('elu')
        test('selu')
        test('rrelu')
        test('rrelu', inplace=True)
        test('hardtanh')
        test('tanh')
        test('sigmoid')
        test('logsigmoid')
        test('hardshrink')
        test('tanhshrink')
        test('softsign')
        test('softmin', 0)
        test('softmax', 0)
        test('log_softmax', 0)
        test('leaky_relu', 0.2)

    def test_nonlinearity_propagate_nan(self):
        self._test_nonlinearity_propagate_nan('cpu')

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_nonlinearity_propagate_nan_cuda(self):
        self._test_nonlinearity_propagate_nan('cuda')

    def test_weight_norm(self):
        input = torch.randn(3, 5)
        m = nn.Linear(5, 7)