Skip to content

Commit 9ea576d

Browse files
authored
Implement neg for all types (#4075)
The C/C++ unary negation operator is well defined for unsigned types. We should use that behavior. This also implements neg for CharTensor. That behavior currently depends on whether char is signed or unsigned. Fixes #4066, #3225
1 parent 60c03bc commit 9ea576d

File tree

9 files changed

+6
-44
lines changed

9 files changed

+6
-44
lines changed

aten/src/ATen/Declarations.cwrap

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2063,11 +2063,6 @@
20632063
]]
20642064
[[
20652065
name: neg
2066-
types:
2067-
- floating_point
2068-
- Long
2069-
- Int
2070-
- Short
20712066
backends:
20722067
- CPU
20732068
- CUDA
@@ -2084,11 +2079,6 @@
20842079
]]
20852080
[[
20862081
name: neg_
2087-
types:
2088-
- floating_point
2089-
- Long
2090-
- Int
2091-
- Short
20922082
backends:
20932083
- CPU
20942084
- CUDA

aten/src/TH/generic/THTensorMath.c

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2995,14 +2995,14 @@ TENSOR_IMPLEMENT_LOGICAL(ne,!=)
29952995
} \
29962996
}
29972997

2998+
LAB_IMPLEMENT_BASIC_FUNCTION(neg,-)
2999+
29983000
#if defined(TH_REAL_IS_LONG)
29993001
LAB_IMPLEMENT_BASIC_FUNCTION(abs,labs)
3000-
LAB_IMPLEMENT_BASIC_FUNCTION(neg,-)
30013002
#endif /* int64_t only part */
30023003

30033004
#if defined(TH_REAL_IS_SHORT) || defined(TH_REAL_IS_INT)
30043005
LAB_IMPLEMENT_BASIC_FUNCTION(abs,abs)
3005-
LAB_IMPLEMENT_BASIC_FUNCTION(neg,-)
30063006
#endif /* int only part */
30073007

30083008
#if defined(TH_REAL_IS_BYTE)
@@ -3053,7 +3053,6 @@ LAB_IMPLEMENT_BASIC_FUNCTION(round,TH_MATH_NAME(round))
30533053
LAB_IMPLEMENT_BASIC_FUNCTION(abs,TH_MATH_NAME(fabs))
30543054
LAB_IMPLEMENT_BASIC_FUNCTION(trunc,TH_MATH_NAME(trunc))
30553055
LAB_IMPLEMENT_BASIC_FUNCTION(frac,TH_MATH_NAME(TH_frac))
3056-
LAB_IMPLEMENT_BASIC_FUNCTION(neg,-)
30573056
LAB_IMPLEMENT_BASIC_FUNCTION(cinv, TH_MATH_NAME(1.0) / )
30583057

30593058

aten/src/TH/generic/THVector.h

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ TH_API void THVector_(muls)(real *y, const real *x, const real c, const ptrdiff_
1010
TH_API void THVector_(cdiv)(real *z, const real *x, const real *y, const ptrdiff_t n);
1111
TH_API void THVector_(divs)(real *y, const real *x, const real c, const ptrdiff_t n);
1212
TH_API void THVector_(copy)(real *y, const real *x, const ptrdiff_t n);
13+
TH_API void THVector_(neg)(real *y, const real *x, const ptrdiff_t n);
1314

1415
#if defined(TH_REAL_IS_SHORT) || defined(TH_REAL_IS_INT) || defined(TH_REAL_IS_LONG)
1516
TH_API void THVector_(abs)(real *y, const real *x, const ptrdiff_t n);
@@ -47,10 +48,6 @@ TH_API void THVector_(cinv)(real *y, const real *x, const ptrdiff_t n);
4748

4849
#endif /* floating point only part */
4950

50-
#ifndef TH_REAL_IS_BYTE
51-
TH_API void THVector_(neg)(real *y, const real *x, const ptrdiff_t n);
52-
#endif
53-
5451
/* Initialize the dispatch pointers */
5552
TH_API void THVector_(vectorDispatchInit)(void);
5653

aten/src/TH/generic/THVectorDefault.c

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -206,8 +206,6 @@ VECTOR_IMPLEMENT_FUNCTION(cinv, TH_MATH_NAME(1.0) / )
206206
#undef TH_MATH_NAME
207207
#endif /* floating point only part */
208208

209-
#ifndef TH_REAL_IS_BYTE
210209
VECTOR_IMPLEMENT_FUNCTION(neg,-)
211-
#endif
212210

213211
#endif

aten/src/THC/THCNumerics.cuh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ struct THCNumerics<uint8_t> {
2525
static inline __host__ __device__ bool eq(uint8_t a, uint8_t b) { return a == b; }
2626
static inline __host__ __device__ bool ne(uint8_t a, uint8_t b) { return a != b; }
2727

28+
static inline __host__ __device__ uint8_t neg(int8_t a) { return -a; }
2829
static inline __host__ __device__ uint8_t add(uint8_t a, uint8_t b) { return a + b; }
2930
static inline __host__ __device__ uint8_t mul(uint8_t a, uint8_t b) { return a * b; }
3031
static inline __host__ __device__ uint8_t sub(uint8_t a, uint8_t b) { return a - b; }

aten/src/THC/generic/THCTensorMathPointwise.cu

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -62,13 +62,7 @@ IMPLEMENT_CUDA_TENSOR_BASIC_FUNC( cinv, THCNumerics<real>::cinv, Real)
6262

6363
#endif
6464

65-
#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_HALF) || \
66-
defined(THC_REAL_IS_SHORT) || defined(THC_REAL_IS_INT) || defined(THC_REAL_IS_LONG)
67-
6865
IMPLEMENT_CUDA_TENSOR_BASIC_FUNC( neg, THCNumerics<real>::neg, Real)
69-
70-
#endif
71-
7266
IMPLEMENT_CUDA_TENSOR_BASIC_FUNC( abs, THCNumerics<real>::abs, Real)
7367

7468
#undef IMPLEMENT_CUDA_TENSOR_BASIC_FUNC_

aten/src/THC/generic/THCTensorMathPointwise.h

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,7 @@ THC_API void THCTensor_(cinv)(THCState *state, THCTensor *self, THCTensor *src);
3636

3737
#endif
3838

39-
#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_HALF) || \
40-
defined(THC_REAL_IS_SHORT) || defined(THC_REAL_IS_INT) || defined(THC_REAL_IS_LONG)
41-
4239
THC_API void THCTensor_(neg)(THCState *state, THCTensor *self, THCTensor *src);
43-
44-
#endif
45-
4640
THC_API void THCTensor_(abs)(THCState *state, THCTensor *self, THCTensor *src);
4741
THC_API void THCTensor_(sign)(THCState *state, THCTensor *self, THCTensor *src);
4842
THC_API void THCTensor_(clamp)(THCState *state, THCTensor *self, THCTensor *src, real min_value, real max_value);

test/test_torch.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -525,7 +525,8 @@ def test_csub(self):
525525
@staticmethod
526526
def _test_neg(self, cast):
527527
float_types = ['torch.DoubleTensor', 'torch.FloatTensor', 'torch.LongTensor']
528-
int_types = ['torch.IntTensor', 'torch.ShortTensor']
528+
int_types = ['torch.IntTensor', 'torch.ShortTensor', 'torch.ByteTensor',
529+
'torch.CharTensor']
529530

530531
for t in float_types + int_types:
531532
if t in float_types:
@@ -552,8 +553,6 @@ def test_neg(self):
552553

553554
def test_reciprocal(self):
554555
a = torch.randn(100, 89)
555-
zeros = torch.Tensor().resize_as_(a).zero_()
556-
557556
res_div = 1 / a
558557
res_reciprocal = a.clone()
559558
res_reciprocal.reciprocal_()

torch/csrc/generic/methods/TensorMath.cwrap

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1039,11 +1039,6 @@
10391039

10401040
[[
10411041
name: neg
1042-
types:
1043-
- floating_point
1044-
- Long
1045-
- Int
1046-
- Short
10471042
backends:
10481043
- CPU
10491044
- CUDA
@@ -1061,11 +1056,6 @@
10611056

10621057
[[
10631058
name: neg_
1064-
types:
1065-
- floating_point
1066-
- Long
1067-
- Int
1068-
- Short
10691059
backends:
10701060
- CPU
10711061
- CUDA

0 commit comments

Comments
 (0)