12 changes: 12 additions & 0 deletions aten/src/THC/THCNumerics.cuh
@@ -47,6 +47,7 @@ struct THCNumerics<uint8_t> {
   static inline __host__ __device__ uint8_t div(uint8_t a, uint8_t b) { return a / b; }
   static inline __host__ __device__ uint8_t abs(uint8_t a) { return a; }
   static inline __host__ __device__ uint8_t pow(uint8_t a, uint8_t b) { return powi<uint8_t>(a, b); }
+  static inline __host__ __device__ bool isnan(uint8_t a) { return false; }
 };

template <>
@@ -68,6 +69,7 @@ struct THCNumerics<int8_t> {
   static inline __host__ __device__ int8_t div(int8_t a, int8_t b) { return a / b; }
   static inline __host__ __device__ int8_t abs(int8_t a) { return ::abs((int)a); }
   static inline __host__ __device__ int8_t pow(int8_t a, int8_t b) { return powi<int8_t>(a, b); }
+  static inline __host__ __device__ bool isnan(int8_t a) { return false; }
 };

template <>
@@ -89,6 +91,7 @@ struct THCNumerics<int16_t> {
   static inline __host__ __device__ int16_t div(int16_t a, int16_t b) { return a / b; }
   static inline __host__ __device__ int16_t abs(int16_t a) { return ::abs((int)a); }
   static inline __host__ __device__ int16_t pow(int16_t a, int16_t b) { return powi<int16_t>(a, b); }
+  static inline __host__ __device__ bool isnan(int16_t a) { return false; }
 };

template <>
@@ -110,6 +113,7 @@ struct THCNumerics<int32_t> {
   static inline __host__ __device__ int32_t div(int32_t a, int32_t b) { return a / b; }
   static inline __host__ __device__ int32_t abs(int32_t a) { return ::abs(a); }
   static inline __host__ __device__ int32_t pow(int32_t a, int32_t b) { return powi<int32_t>(a, b); }
+  static inline __host__ __device__ bool isnan(int32_t a) { return false; }
 };

template <>
@@ -137,6 +141,7 @@ struct THCNumerics<int64_t> {
   static inline __host__ __device__ int64_t div(int64_t a, int64_t b) { return a / b; };
   static inline __host__ __device__ int64_t abs(int64_t a) { return labs(a); }
   static inline __host__ __device__ int64_t pow(int64_t a, int64_t b) { return powi<int64_t>(a, b); }
+  static inline __host__ __device__ bool isnan(int64_t a) { return false; }
 };
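Note: the integral overloads above always return false because integer types have no NaN representation. Giving every THCNumerics specialization an isnan lets the reduction functors in THCTensorMathReduce.cuh call it unconditionally, with no floating-point-only special case. A minimal host-side sketch of the same trait pattern (the `Numerics` name is hypothetical, not part of this diff):

```cpp
#include <cmath>
#include <cstdio>

// Hypothetical stand-in for THCNumerics: isnan is constant false
// for integral types and defers to std::isnan for floating point.
template <typename T>
struct Numerics {
  static bool isnan(T) { return false; }  // integers are never NaN
};

template <>
struct Numerics<float> {
  static bool isnan(float a) { return std::isnan(a); }
};

int main() {
  printf("%d\n", Numerics<int>::isnan(7));              // 0
  printf("%d\n", Numerics<float>::isnan(0.0f / 0.0f));  // 1
}
```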

#ifdef CUDA_HALF_TENSOR
@@ -614,6 +619,11 @@ static inline __host__ __device__ half lgamma(half a) {
 #endif
   }
 
+  static inline __host__ __device__ bool isnan(half a) {
+    // implemented using the fact that a != a if and only if a is nan
+    return ne(a, a);
+  }
+
 };
 #endif
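Note on the half overload: there is no ready-made isnan for half here, so the code falls back on the IEEE 754 rule that NaN is the only value that compares unequal to itself, making ne(a, a) true exactly when a is NaN. A tiny host-side demonstration of that property, written with float since it behaves identically:

```cpp
#include <cstdio>

// NaN is the only IEEE 754 value for which x != x holds.
int main() {
  float nan = 0.0f / 0.0f;  // quiet NaN
  float one = 1.0f;
  printf("nan != nan -> %d\n", nan != nan);  // 1
  printf("one != one -> %d\n", one != one);  // 0
}
```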

@@ -666,6 +676,7 @@ struct THCNumerics<float> {
   static inline __host__ __device__ float sub (float a, float b) { return a - b; }
   static inline __host__ __device__ float pow (float a, float b) { return powf(a, b); }
   static inline __host__ __device__ float atan2(float a, float b) { return atan2f(a, b); }
+  static inline __host__ __device__ bool isnan(float a) { return ::isnan(a); }
 };

template <>
@@ -717,6 +728,7 @@ struct THCNumerics<double> {
   static inline __host__ __device__ double sub (double a, double b) { return a - b; }
   static inline __host__ __device__ double pow (double a, double b) { return ::pow(a, b); }
   static inline __host__ __device__ double atan2(double a, double b) { return ::atan2(a, b); }
+  static inline __host__ __device__ bool isnan(double a) { return ::isnan(a); }
 };
 
 /// `half` has some type conversion issues associated with it, since it
18 changes: 12 additions & 6 deletions aten/src/THC/THCTensorMathReduce.cuh
@@ -105,21 +105,24 @@ struct SquareFunctor<ResT, half> {
 template <typename T>
 struct ReduceMin {
   inline __device__ T operator()(T a, T b) const {
-    return THCNumerics<T>::lt(a, b) ? a : b;
+    return (THCNumerics<T>::lt(a, b) ||
+            THCNumerics<T>::isnan(a)) ? a : b;
   }
 };

 template <typename T>
 struct ReduceMax {
   inline __device__ T operator()(T a, T b) const {
-    return THCNumerics<T>::gt(a, b) ? a : b;
+    return (THCNumerics<T>::gt(a, b) ||
+            THCNumerics<T>::isnan(a)) ? a : b;
   }
 };
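Why the extra isnan(a) term: lt and gt both return false when either operand is NaN, so the old comparator returned b in every NaN case. That silently drops a NaN sitting in a, while a NaN in b already propagates on its own (the comparison is false, so b is returned). Checking only a therefore covers the one losing case and makes NaN sticky once it enters the running result. A host-side sketch of the difference, with std::isnan standing in for THCNumerics<T>::isnan:

```cpp
#include <cmath>
#include <cstdio>

// Old predicate: NaN in the first slot is silently discarded.
float old_min(float a, float b) { return a < b ? a : b; }

// New predicate: NaN in the first slot wins, so it propagates.
float new_min(float a, float b) {
  return (a < b || std::isnan(a)) ? a : b;
}

int main() {
  float nan = 0.0f / 0.0f;
  printf("old_min(nan, 2) = %f\n", old_min(nan, 2.0f));  // 2.000000 (NaN lost)
  printf("new_min(nan, 2) = %f\n", new_min(nan, 2.0f));  // nan (NaN kept)
}
```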

 template <typename InT, typename AccT>
 struct ReduceMaxTo {
   inline __device__ AccT operator()(InT a, InT b) const {
-    return ScalarConvert<InT, AccT>::to(THCNumerics<InT>::gt(a, b) ? a : b);
+    return ScalarConvert<InT, AccT>::to(
+        (THCNumerics<InT>::gt(a, b) || THCNumerics<InT>::isnan(a)) ? a : b);
   }
 };

@@ -128,7 +131,8 @@ template <>
 struct ReduceMaxTo<half, float> {
   inline __device__ float operator()(float a, half b) const {
     float b_f = __half2float(b);
-    return (THCNumerics<float>::gt(a, b_f) ? a : b_f);
+    return ((THCNumerics<float>::gt(a, b_f) ||
+             THCNumerics<float>::isnan(a)) ? a : b_f);
   }
 };
 #endif // CUDA_HALF_TENSOR
@@ -789,7 +793,8 @@ struct MaxValuePair {
   __host__ __device__
   thrust::pair<T, Index> operator()(const thrust::pair<T, Index>& a,
                                     const thrust::pair<T, Index>& b) {
-    return THCNumerics<T>::ge(a.first, b.first) ? a : b;
+    return (THCNumerics<T>::ge(a.first, b.first) ||
+            THCNumerics<T>::isnan(a.first)) ? a : b;
   }
 };

@@ -798,7 +803,8 @@ struct MinValuePair {
   __host__ __device__
   thrust::pair<T, Index> operator()(const thrust::pair<T, Index>& a,
                                     const thrust::pair<T, Index>& b) {
-    return THCNumerics<T>::le(a.first, b.first) ? a : b;
+    return (THCNumerics<T>::le(a.first, b.first) ||
+            THCNumerics<T>::isnan(a.first)) ? a : b;
   }
 };
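The pair functors drive the indexed reductions (min/max along a dimension, which return values plus indices), so the same predicate change keeps the NaN's index alongside its value. A host-side sketch with std::pair in place of thrust::pair:

```cpp
#include <cmath>
#include <cstdio>
#include <utility>

using Entry = std::pair<float, int>;  // (value, index)

// Sketch of MinValuePair's predicate: keep a when its value is
// smaller or is NaN, so the NaN's index survives the reduction.
Entry min_pair(const Entry& a, const Entry& b) {
  return (a.first <= b.first || std::isnan(a.first)) ? a : b;
}

int main() {
  Entry nan_at_3{0.0f / 0.0f, 3};
  Entry two_at_7{2.0f, 7};
  Entry r = min_pair(nan_at_3, two_at_7);
  printf("value=%f index=%d\n", r.first, r.second);  // value=nan index=3
}
```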

14 changes: 14 additions & 0 deletions test/test_cuda.py
@@ -835,6 +835,20 @@ def test_broadcast_cpu(self):
     def test_broadcast_gpu(self):
         self._test_broadcast(torch.randn(5, 5).cuda())
 
+    def test_min_max_nan(self):
+        tests = [(lambda x: x.min(), 'min'),
+                 (lambda x: x.max(), 'max'),
+                 (lambda x: x.min(0)[0], 'min_dim'),
+                 (lambda x: x.max(0)[0], 'max_dim')]
+        for f, name in tests:
+            a = torch.arange(25.0).view(5, 5)
+            a[2, 2] = float('nan')
+            actual = f(a.cuda()).cpu()
+            expected = f(a).cpu()
+            self.assertEqual(torch.isnan(actual), torch.isnan(expected), 'nans for {}'.format(name))
+            self.assertEqual(actual[~torch.isnan(actual)],
+                             expected[~torch.isnan(expected)], 'nans for {}'.format(name))
+
     @staticmethod
     def _test_broadcast_coalesced(self, tensors, buffer_size):
         b_tensors = [comm.broadcast(t, (0, 1)) for t in tensors]