Skip to content

Commit 39d4814

Browse files
sighingnowapaszke
authored andcommitted
Make any and all on ByteTensor behave like sum/prod. (#4627)
1 parent 241a1e0 commit 39d4814

File tree

12 files changed

+423
-38
lines changed

12 files changed

+423
-38
lines changed

aten/src/ATen/Declarations.cwrap

Lines changed: 38 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1045,25 +1045,55 @@
10451045
name: all
10461046
types:
10471047
- Byte
1048+
variants:
1049+
- method
1050+
- function
10481051
backends:
10491052
- CPU
10501053
- CUDA
1051-
cname: logicalall
1052-
return: real
1053-
arguments:
1054-
- THTensor* self
1054+
options:
1055+
- cname: logicalAndAll
1056+
return: real
1057+
arguments:
1058+
- THTensor* self
1059+
- cname: logicalAnd
1060+
return: argument 0
1061+
scalar_check: self_->isScalar() || (keepdim == false && self_->dim() == 1)
1062+
arguments:
1063+
- arg: THTensor* result
1064+
output: True
1065+
- THTensor* self
1066+
- arg: long dim
1067+
wrap_dim: self
1068+
- arg: bool keepdim
1069+
default: "false"
10551070
]]
10561071
[[
10571072
name: any
10581073
types:
10591074
- Byte
1075+
variants:
1076+
- method
1077+
- function
10601078
backends:
10611079
- CPU
10621080
- CUDA
1063-
cname: logicalany
1064-
return: real
1065-
arguments:
1066-
- THTensor* self
1081+
options:
1082+
- cname: logicalAnyAll
1083+
return: real
1084+
arguments:
1085+
- THTensor* self
1086+
- cname: logicalAny
1087+
return: argument 0
1088+
scalar_check: self_->isScalar() || (keepdim == false && self_->dim() == 1)
1089+
arguments:
1090+
- arg: THTensor* result
1091+
output: True
1092+
- THTensor* self
1093+
- arg: long dim
1094+
wrap_dim: self
1095+
- arg: bool keepdim
1096+
default: "false"
10671097
]]
10681098
[[
10691099
name: getDevice

aten/src/TH/generic/THTensorMath.c

Lines changed: 197 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3856,16 +3856,205 @@ LAB_IMPLEMENT_BASIC_FUNCTION(abs,abs)
38563856

38573857
#if defined(TH_REAL_IS_BYTE)
38583858

3859-
#define TENSOR_IMPLEMENT_LOGICAL_SUM(NAME, OP, INIT_VALUE) \
3860-
int THTensor_(NAME)(THTensor *tensor) \
3861-
{ \
3862-
int sum = INIT_VALUE; \
3863-
TH_TENSOR_APPLY(real, tensor, sum = sum OP *tensor_data;); \
3864-
return sum; \
3859+
int THTensor_(logicalAndAll)(THTensor *tensor)
3860+
{
3861+
real prod = 1;
3862+
int serial_path = 0;
3863+
#ifdef _OPENMP
3864+
int inOMP = omp_in_parallel();
3865+
if(inOMP) {
3866+
serial_path = 1;
3867+
} else {
3868+
TH_TENSOR_APPLY_REDUCTION_OMP(real, tensor, &&:prod, prod = prod && *tensor_data;);
3869+
}
3870+
#else
3871+
serial_path = 1;
3872+
#endif
3873+
if (serial_path) {
3874+
TH_TENSOR_APPLY(real, tensor, prod = prod && *tensor_data;);
3875+
}
3876+
return prod;
3877+
}
3878+
3879+
int THTensor_(logicalAnyAll)(THTensor *tensor)
3880+
{
3881+
real sum = 0;
3882+
int serial_path = 0;
3883+
#ifdef _OPENMP
3884+
int inOMP = omp_in_parallel();
3885+
if(inOMP) {
3886+
serial_path = 1;
3887+
} else {
3888+
TH_TENSOR_APPLY_REDUCTION_OMP(real, tensor, ||:sum, sum = sum || *tensor_data;);
3889+
}
3890+
#else
3891+
serial_path = 1;
3892+
#endif
3893+
if (serial_path) {
3894+
TH_TENSOR_APPLY(real, tensor, sum = sum || *tensor_data;);
3895+
}
3896+
return (bool)sum;
3897+
}
3898+
3899+
void THTensor_(logicalAnd)(THTensor *r_, THTensor *t, int dimension, int keepdim)
3900+
{
3901+
THLongStorage *dim;
3902+
3903+
THArgCheck(dimension >= 0 && dimension < THTensor_(nDimension)(t), 2, "dimension %d out of range",
3904+
dimension + TH_INDEX_BASE);
3905+
3906+
THTensor_(preserveReduceDimSemantics)(r_, THTensor_(nDimension)(t), dimension, keepdim);
3907+
dim = THTensor_(newSizeOf)(t);
3908+
THLongStorage_set(dim, dimension, 1);
3909+
THTensor_(resize)(r_, dim, NULL);
3910+
THLongStorage_free(dim);
3911+
3912+
int serial_path = 0;
3913+
#ifdef _OPENMP
3914+
int inOMP = omp_in_parallel();
3915+
if (inOMP) {
3916+
serial_path = 1;
3917+
} else {
3918+
int r_Contig = THTensor_(isContiguous)(r_);
3919+
real *tp = THTensor_(data)(t);
3920+
real *rp = THTensor_(data)(r_);
3921+
if(r_Contig && (tp != rp)){
3922+
ptrdiff_t iter = 0;
3923+
ptrdiff_t r_Size = THTensor_(nElement)(r_);
3924+
int r_Dim = r_->nDimension;
3925+
#pragma omp parallel for if ( r_Size > TH_OMP_OVERHEAD_THRESHOLD)
3926+
for (iter = 0; iter < r_Size; iter++) {
3927+
int j;
3928+
int64_t quot;
3929+
int64_t rem = iter;
3930+
ptrdiff_t tBasicIndex = 0;
3931+
3932+
for(j = 0; j < r_Dim; ++j) {
3933+
if(j != dimension){
3934+
quot = rem/r_->stride[j];
3935+
rem = rem%r_->stride[j];
3936+
tBasicIndex += quot*t->stride[j];
3937+
}
3938+
}
3939+
real *t_data = tp+tBasicIndex;
3940+
real *r__data = rp+iter;
3941+
*r__data = 1;
3942+
for(j=0; j < t->size[dimension]; ++j) {
3943+
*r__data = *r__data && *(t_data + j*t->stride[dimension]);
3944+
}
3945+
}
3946+
} else {
3947+
serial_path = 1;
3948+
}
3949+
}
3950+
#else
3951+
serial_path = 1;
3952+
#endif
3953+
3954+
if(serial_path) {
3955+
// two implementations optimized for data locality
3956+
if (t->stride[dimension] == 1) {
3957+
TH_TENSOR_DIM_APPLY2(real, t, real, r_, dimension,
3958+
accreal prod = 1;
3959+
int64_t i;
3960+
for(i = 0; i < t_size; i++)
3961+
prod = prod && t_data[i*t_stride];
3962+
*r__data = (real)prod;);
3963+
} else {
3964+
THTensor_(fill)(r_, 1);
3965+
THTensor *temp_ = THTensor_(newWithTensor)(r_);
3966+
// r_.expand_as(t)
3967+
temp_->size[dimension] = t->size[dimension];
3968+
temp_->stride[dimension] = 0;
3969+
3970+
TH_TENSOR_APPLY2(real, temp_, real, t, *temp__data = *temp__data && *t_data;);
3971+
THTensor_(free)(temp_);
3972+
}
3973+
}
3974+
if (!keepdim) {
3975+
THTensor_(squeeze1d)(r_, r_, dimension);
3976+
}
3977+
}
3978+
3979+
void THTensor_(logicalAny)(THTensor *r_, THTensor *t, int dimension, int keepdim)
3980+
{
3981+
THLongStorage *dim;
3982+
3983+
THArgCheck(dimension >= 0 && dimension < THTensor_(nDimension)(t), 2, "dimension %d out of range",
3984+
dimension + TH_INDEX_BASE);
3985+
3986+
THTensor_(preserveReduceDimSemantics)(r_, THTensor_(nDimension)(t), dimension, keepdim);
3987+
dim = THTensor_(newSizeOf)(t);
3988+
THLongStorage_set(dim, dimension, 1);
3989+
THTensor_(resize)(r_, dim, NULL);
3990+
THLongStorage_free(dim);
3991+
3992+
int serial_path = 0;
3993+
#ifdef _OPENMP
3994+
int inOMP = omp_in_parallel();
3995+
if (inOMP) {
3996+
serial_path = 1;
3997+
} else {
3998+
int r_Contig = THTensor_(isContiguous)(r_);
3999+
real *tp = THTensor_(data)(t);
4000+
real *rp = THTensor_(data)(r_);
4001+
if(r_Contig && (tp != rp)){
4002+
ptrdiff_t iter = 0;
4003+
ptrdiff_t r_Size = THTensor_(nElement)(r_);
4004+
int r_Dim = r_->nDimension;
4005+
#pragma omp parallel for if ( r_Size > TH_OMP_OVERHEAD_THRESHOLD)
4006+
for (iter = 0; iter < r_Size; iter++) {
4007+
int j;
4008+
int64_t quot;
4009+
int64_t rem = iter;
4010+
ptrdiff_t tBasicIndex = 0;
4011+
4012+
for(j = 0; j < r_Dim; ++j) {
4013+
if(j != dimension){
4014+
quot = rem/r_->stride[j];
4015+
rem = rem%r_->stride[j];
4016+
tBasicIndex += quot*t->stride[j];
4017+
}
4018+
}
4019+
real *t_data = tp+tBasicIndex;
4020+
real *r__data = rp+iter;
4021+
*r__data = 0;
4022+
for(j=0; j < t->size[dimension]; ++j) {
4023+
*r__data = *r__data || *(t_data + j*t->stride[dimension]);
4024+
}
4025+
}
4026+
} else {
4027+
serial_path = 1;
4028+
}
38654029
}
4030+
#else
4031+
serial_path = 1;
4032+
#endif
4033+
if (serial_path) {
4034+
// two implementations optimized for data locality
4035+
if (t->stride[dimension] == 1) {
4036+
TH_TENSOR_DIM_APPLY2(real, t, real, r_, dimension,
4037+
accreal sum = 0;
4038+
int64_t i;
4039+
for(i = 0; i < t_size; i++)
4040+
sum = sum || t_data[i*t_stride];
4041+
*r__data = (real)sum;);
4042+
} else {
4043+
THTensor_(zero)(r_);
4044+
THTensor *temp_ = THTensor_(newWithTensor)(r_);
4045+
// r_.expand_as(t)
4046+
temp_->size[dimension] = t->size[dimension];
4047+
temp_->stride[dimension] = 0;
38664048

3867-
TENSOR_IMPLEMENT_LOGICAL_SUM(logicalall, &&, 1)
3868-
TENSOR_IMPLEMENT_LOGICAL_SUM(logicalany, ||, 0)
4049+
TH_TENSOR_APPLY2(real, temp_, real, t, *temp__data = *temp__data || *t_data;);
4050+
THTensor_(free)(temp_);
4051+
}
4052+
}
4053+
4054+
if (!keepdim) {
4055+
THTensor_(squeeze1d)(r_, r_, dimension);
4056+
}
4057+
}
38694058

38704059
#endif /* Byte only part */
38714060

aten/src/TH/generic/THTensorMath.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -208,8 +208,10 @@ TH_API void THTensor_(dirichlet_grad)(THTensor *self, THTensor *x, THTensor *alp
208208

209209
#if defined(TH_REAL_IS_BYTE)
210210

211-
TH_API int THTensor_(logicalall)(THTensor *self);
212-
TH_API int THTensor_(logicalany)(THTensor *self);
211+
TH_API int THTensor_(logicalAndAll)(THTensor *self);
212+
TH_API int THTensor_(logicalAnyAll)(THTensor *self);
213+
TH_API void THTensor_(logicalAnd)(THTensor *r_, THTensor *t, int dimension, int keepdim);
214+
TH_API void THTensor_(logicalAny)(THTensor *r_, THTensor *t, int dimension, int keepdim);
213215

214216
#endif /* TH_REAL_IS_BYTE */
215217

aten/src/THC/THCTensorMath.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,10 @@
4949
#include "generic/THCTensorTopK.h"
5050
#include "THCGenerateAllTypes.h"
5151

52-
THC_API int THCudaByteTensor_logicalall(THCState *state, THCudaByteTensor *self);
53-
THC_API int THCudaByteTensor_logicalany(THCState *state, THCudaByteTensor *self);
52+
THC_API int THCudaByteTensor_logicalAndAll(THCState *state, THCudaByteTensor *self);
53+
THC_API int THCudaByteTensor_logicalAnyAll(THCState *state, THCudaByteTensor *self);
54+
55+
THC_API void THCudaByteTensor_logicalAnd(THCState *state, THCudaByteTensor *self, THCudaByteTensor *src, int dimension, int keepdim);
56+
THC_API void THCudaByteTensor_logicalAny(THCState *state, THCudaByteTensor *self, THCudaByteTensor *src, int dimension, int keepdim);
5457

5558
#endif

aten/src/THC/THCTensorMathReduce.cu

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#include "THCTensorMathReduce.cuh"
22

33
THC_API int
4-
THCudaByteTensor_logicalall(THCState *state, THCudaByteTensor *self) {
4+
THCudaByteTensor_logicalAndAll(THCState *state, THCudaByteTensor *self) {
55
THCAssertSameGPU(THCudaByteTensor_checkGPU(state, 1, self));
66
unsigned char result;
77
if (!THC_reduceAll(state, self,
@@ -16,7 +16,7 @@ THCudaByteTensor_logicalall(THCState *state, THCudaByteTensor *self) {
1616
}
1717

1818
THC_API int
19-
THCudaByteTensor_logicalany(THCState *state, THCudaByteTensor *self) {
19+
THCudaByteTensor_logicalAnyAll(THCState *state, THCudaByteTensor *self) {
2020
THCAssertSameGPU(THCudaByteTensor_checkGPU(state, 1, self));
2121
unsigned char result;
2222
if (!THC_reduceAll(state, self,
@@ -29,3 +29,35 @@ THCudaByteTensor_logicalany(THCState *state, THCudaByteTensor *self) {
2929

3030
return (int) result;
3131
}
32+
33+
THC_API void
34+
THCudaByteTensor_logicalAnd(THCState* state, THCudaByteTensor *self, THCudaByteTensor *src, int dimension, int keepdim) {
35+
THCAssertSameGPU(THCudaByteTensor_checkGPU(state, 2, self, src));
36+
if (!THC_reduceDim(state, self, src,
37+
thrust::identity<unsigned char>(),
38+
LogicalAll(),
39+
LogicalAll(),
40+
(unsigned char) 1,
41+
dimension,
42+
keepdim)) {
43+
THArgCheck(false, 2, CUTORCH_DIM_WARNING);
44+
}
45+
46+
THCudaCheck(cudaGetLastError());
47+
}
48+
49+
THC_API void
50+
THCudaByteTensor_logicalAny(THCState* state, THCudaByteTensor *self, THCudaByteTensor *src, int dimension, int keepdim) {
51+
THCAssertSameGPU(THCudaByteTensor_checkGPU(state, 2, self, src));
52+
if (!THC_reduceDim(state, self, src,
53+
thrust::identity<unsigned char>(),
54+
LogicalAny(),
55+
LogicalAny(),
56+
(unsigned char) 0,
57+
dimension,
58+
keepdim)) {
59+
THArgCheck(false, 2, CUTORCH_DIM_WARNING);
60+
}
61+
62+
THCudaCheck(cudaGetLastError());
63+
}

test/test_torch.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -776,6 +776,23 @@ def test_all_any_empty(self):
776776
self.assertTrue(x.all())
777777
self.assertFalse(x.any())
778778

779+
def test_all_any_with_dim(self):
780+
def test(x):
781+
r1 = x.prod(dim=0, keepdim=False)
782+
r2 = x.all(dim=0, keepdim=False)
783+
self.assertEqual(r1.shape, r2.shape)
784+
self.assertTrue((r1 == r2).all())
785+
786+
r3 = x.sum(dim=1, keepdim=True).clamp(0, 1)
787+
r4 = x.any(dim=1, keepdim=True)
788+
self.assertEqual(r3.shape, r4.shape)
789+
self.assertTrue((r3 == r4).all())
790+
791+
test(torch.ByteTensor([[0, 0, 0],
792+
[0, 0, 1],
793+
[0, 1, 1],
794+
[1, 1, 1]]))
795+
779796
@unittest.skipIf(not torch.cuda.is_available(), 'no CUDA')
780797
def test_all_any_empty_cuda(self):
781798
x = torch.cuda.ByteTensor()

0 commit comments

Comments
 (0)