Skip to content

Commit 9c4ba63

Browse files
committed
Changed tensor comparison return type from uint8 to bool
gh-metadata: pytorch pytorch 21113 gh/izdeby/7/head
1 parent b0bd875 commit 9c4ba63

24 files changed

+91
-88
lines changed

aten/src/ATen/Declarations.cwrap

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -606,13 +606,13 @@
606606
options:
607607
- cname: ltValue
608608
arguments:
609-
- arg: THByteTensor* result
609+
- arg: THBoolTensor* result
610610
output: True
611611
- THTensor* self
612612
- real other
613613
- cname: ltTensor
614614
arguments:
615-
- arg: THByteTensor* result
615+
- arg: THBoolTensor* result
616616
output: True
617617
- arg: THTensor* self
618618
broadcast: other fallback
@@ -647,13 +647,13 @@
647647
options:
648648
- cname: gtValue
649649
arguments:
650-
- arg: THByteTensor* result
650+
- arg: THBoolTensor* result
651651
output: True
652652
- THTensor* self
653653
- real other
654654
- cname: gtTensor
655655
arguments:
656-
- arg: THByteTensor* result
656+
- arg: THBoolTensor* result
657657
output: True
658658
- arg: THTensor* self
659659
broadcast: other fallback
@@ -688,13 +688,13 @@
688688
options:
689689
- cname: leValue
690690
arguments:
691-
- arg: THByteTensor* result
691+
- arg: THBoolTensor* result
692692
output: True
693693
- THTensor* self
694694
- real other
695695
- cname: leTensor
696696
arguments:
697-
- arg: THByteTensor* result
697+
- arg: THBoolTensor* result
698698
output: True
699699
- arg: THTensor* self
700700
broadcast: other fallback
@@ -729,13 +729,13 @@
729729
options:
730730
- cname: geValue
731731
arguments:
732-
- arg: THByteTensor* result
732+
- arg: THBoolTensor* result
733733
output: True
734734
- THTensor* self
735735
- real other
736736
- cname: geTensor
737737
arguments:
738-
- arg: THByteTensor* result
738+
- arg: THBoolTensor* result
739739
output: True
740740
- arg: THTensor* self
741741
broadcast: other fallback
@@ -770,13 +770,13 @@
770770
options:
771771
- cname: eqValue
772772
arguments:
773-
- arg: THByteTensor* result
773+
- arg: THBoolTensor* result
774774
output: True
775775
- THTensor* self
776776
- real other
777777
- cname: eqTensor
778778
arguments:
779-
- arg: THByteTensor* result
779+
- arg: THBoolTensor* result
780780
output: True
781781
- arg: THTensor* self
782782
broadcast: other fallback
@@ -811,13 +811,13 @@
811811
options:
812812
- cname: neValue
813813
arguments:
814-
- arg: THByteTensor* result
814+
- arg: THBoolTensor* result
815815
output: True
816816
- THTensor* self
817817
- real other
818818
- cname: neTensor
819819
arguments:
820-
- arg: THByteTensor* result
820+
- arg: THBoolTensor* result
821821
output: True
822822
- arg: THTensor* self
823823
broadcast: other fallback

aten/src/ATen/native/Itertools.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ Tensor _triu_mask(int64_t n, int64_t dims, bool diagonal, TensorOptions opt) {
1212
// or i <= j <= k <= ... (depending on diagonal)
1313
Tensor range = at::arange(n, opt.dtype(kLong));
1414
std::vector<Tensor> index_grids = at::meshgrid(std::vector<Tensor>(dims, range));
15-
Tensor mask = at::ones(index_grids[0].sizes(), opt.dtype(kByte));
15+
Tensor mask = at::ones(index_grids[0].sizes(), opt.dtype(kBool));
1616
if(diagonal) {
1717
for(int64_t i = 0; i < dims - 1; i++) {
1818
mask *= index_grids[i] <= index_grids[i+1];

aten/src/ATen/native/LinearAlgebra.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ Tensor matrix_rank(const Tensor& self, bool symmetric) {
136136

137137
Tensor S = _matrix_rank_helper(self, symmetric);
138138
double tol = _get_epsilon(self.scalar_type()) * std::max(self.size(0), self.size(1));
139-
return (S > S.max().mul_(tol)).sum();
139+
return (S > S.max().mul_(tol)).sum(ScalarType::Long);
140140
}
141141

142142
static void check_1d(const Tensor& t, const char* arg, const char* fn) {

aten/src/ATen/native/cpu/TensorCompareKernel.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ static void max_kernel_impl(
8383
Tensor& max_indices,
8484
const Tensor& self,
8585
c10::optional<int64_t> dim) {
86-
AT_DISPATCH_ALL_TYPES(self.scalar_type(), "max", [&] {
86+
AT_DISPATCH_ALL_TYPES_AND(ScalarType::Bool, self.scalar_type(), "max", [&] {
8787
Reduction<scalar_t, int64_t>::apply(max, max_indices, self, dim, true);
8888
});
8989
}
@@ -93,7 +93,7 @@ static void min_kernel_impl(
9393
Tensor& min_indices,
9494
const Tensor& self,
9595
c10::optional<int64_t> dim) {
96-
AT_DISPATCH_ALL_TYPES(self.scalar_type(), "min", [&] {
96+
AT_DISPATCH_ALL_TYPES_AND(ScalarType::Bool, self.scalar_type(), "min", [&] {
9797
Reduction<scalar_t, int64_t>::apply(min, min_indices, self, dim, false);
9898
});
9999
}

aten/src/ATen/test/atest.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ TEST(atest, atest) {
3636
float b = a.to<float>();
3737
ASSERT_EQ(b, 4);
3838

39-
foo = (foo * foo) == (foo.pow(3));
39+
foo = ((foo * foo) == (foo.pow(3))).to(kByte);
4040
foo = 2 + (foo + 1);
4141
// foo = foo[3];
4242
auto foo_v = foo.accessor<uint8_t, 2>();

aten/src/TH/generic/THTensorMoreMath.cpp

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -6,30 +6,30 @@
66
#include <ATen/CPUGenerator.h>
77
#include <ATen/Utils.h>
88

9-
#define TENSOR_IMPLEMENT_LOGICAL(NAME,OP) \
10-
void THTensor_(NAME##Value)(THByteTensor *r_, THTensor* t, scalar_t value) \
11-
{ \
12-
THByteTensor_resizeNd(r_, t->dim(), THTensor_getSizePtr(t), NULL); \
13-
TH_TENSOR_APPLY2(unsigned char, r_, scalar_t, t, \
14-
*r__data = (*t_data OP value) ? 1 : 0;); \
15-
} \
16-
void THTensor_(NAME##ValueT)(THTensor* r_, THTensor* t, scalar_t value) \
17-
{ \
18-
THTensor_(resizeNd)(r_, t->dim(), THTensor_getSizePtr(t), NULL); \
19-
TH_TENSOR_APPLY2(scalar_t, r_, scalar_t, t, \
20-
*r__data = (*t_data OP value) ? 1 : 0;); \
21-
} \
22-
void THTensor_(NAME##Tensor)(THByteTensor *r_, THTensor *ta, THTensor *tb) \
23-
{ \
24-
THByteTensor_resizeNd(r_, ta->dim(), THTensor_getSizePtr(ta), NULL); \
25-
TH_TENSOR_APPLY3(unsigned char, r_, scalar_t, ta, scalar_t, tb, \
26-
*r__data = (*ta_data OP *tb_data) ? 1 : 0;); \
27-
} \
28-
void THTensor_(NAME##TensorT)(THTensor *r_, THTensor *ta, THTensor *tb) \
29-
{ \
30-
THTensor_(resizeNd)(r_, ta->dim(), THTensor_getSizePtr(ta), NULL); \
31-
TH_TENSOR_APPLY3(scalar_t, r_, scalar_t, ta, scalar_t, tb, \
32-
*r__data = (*ta_data OP *tb_data) ? 1 : 0;); \
9+
#define TENSOR_IMPLEMENT_LOGICAL(NAME,OP) \
10+
void THTensor_(NAME##Value)(THBoolTensor *r_, THTensor* t, scalar_t value) \
11+
{ \
12+
THBoolTensor_resizeNd(r_, t->dim(), THTensor_getSizePtr(t), NULL); \
13+
TH_TENSOR_APPLY2(bool, r_, scalar_t, t, \
14+
*r__data = (*t_data OP value) ? 1 : 0;); \
15+
} \
16+
void THTensor_(NAME##ValueT)(THTensor* r_, THTensor* t, scalar_t value) \
17+
{ \
18+
THTensor_(resizeNd)(r_, t->dim(), THTensor_getSizePtr(t), NULL); \
19+
TH_TENSOR_APPLY2(scalar_t, r_, scalar_t, t, \
20+
*r__data = (*t_data OP value) ? 1 : 0;); \
21+
} \
22+
void THTensor_(NAME##Tensor)(THBoolTensor *r_, THTensor *ta, THTensor *tb) \
23+
{ \
24+
THBoolTensor_resizeNd(r_, ta->dim(), THTensor_getSizePtr(ta), NULL); \
25+
TH_TENSOR_APPLY3(bool, r_, scalar_t, ta, scalar_t, tb, \
26+
*r__data = (*ta_data OP *tb_data) ? 1 : 0;); \
27+
} \
28+
void THTensor_(NAME##TensorT)(THTensor *r_, THTensor *ta, THTensor *tb) \
29+
{ \
30+
THTensor_(resizeNd)(r_, ta->dim(), THTensor_getSizePtr(ta), NULL); \
31+
TH_TENSOR_APPLY3(scalar_t, r_, scalar_t, ta, scalar_t, tb, \
32+
*r__data = (*ta_data OP *tb_data) ? 1 : 0;); \
3333
}
3434

3535
TENSOR_IMPLEMENT_LOGICAL(lt,<)

aten/src/THC/generic/THCTensorMathCompare.cu

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,49 +5,49 @@
55
void THCTensor_(ltValue)(THCState *state, THCudaByteTensor *self_, THCTensor *src, scalar_t value)
66
{
77
THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, self_, src));
8-
THC_logicalValue<uint8_t, scalar_t>(state, self_, src,
8+
THC_logicalValue<bool, scalar_t>(state, self_, src,
99
TensorLTValueOp<scalar_t,
10-
unsigned char>(value));
10+
bool>(value));
1111
}
1212

1313
void THCTensor_(gtValue)(THCState *state, THCudaByteTensor *self_, THCTensor *src, scalar_t value)
1414
{
1515
THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, self_, src));
16-
THC_logicalValue<uint8_t, scalar_t>(state, self_, src,
16+
THC_logicalValue<bool, scalar_t>(state, self_, src,
1717
TensorGTValueOp<scalar_t,
18-
unsigned char>(value));
18+
bool>(value));
1919
}
2020

2121
void THCTensor_(leValue)(THCState *state, THCudaByteTensor *self_, THCTensor *src, scalar_t value)
2222
{
2323
THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, self_, src));
24-
THC_logicalValue<uint8_t, scalar_t>(state, self_, src,
24+
THC_logicalValue<bool, scalar_t>(state, self_, src,
2525
TensorLEValueOp<scalar_t,
26-
unsigned char>(value));
26+
bool>(value));
2727
}
2828

2929
void THCTensor_(geValue)(THCState *state, THCudaByteTensor *self_, THCTensor *src, scalar_t value)
3030
{
3131
THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, self_, src));
32-
THC_logicalValue<uint8_t, scalar_t>(state, self_, src,
32+
THC_logicalValue<bool, scalar_t>(state, self_, src,
3333
TensorGEValueOp<scalar_t,
34-
unsigned char>(value));
34+
bool>(value));
3535
}
3636

3737
void THCTensor_(eqValue)(THCState *state, THCudaByteTensor *self_, THCTensor *src, scalar_t value)
3838
{
3939
THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, self_, src));
40-
THC_logicalValue<uint8_t, scalar_t>(state, self_, src,
40+
THC_logicalValue<bool, scalar_t>(state, self_, src,
4141
TensorEQValueOp<scalar_t,
42-
unsigned char>(value));
42+
bool>(value));
4343
}
4444

4545
void THCTensor_(neValue)(THCState *state, THCudaByteTensor *self_, THCTensor *src, scalar_t value)
4646
{
4747
THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, self_, src));
48-
THC_logicalValue<uint8_t, scalar_t>(state, self_, src,
48+
THC_logicalValue<bool, scalar_t>(state, self_, src,
4949
TensorNEValueOp<scalar_t,
50-
unsigned char>(value));
50+
bool>(value));
5151
}
5252

5353
void THCTensor_(ltValueT)(THCState *state, THCTensor *self_, THCTensor *src, scalar_t value)

aten/src/THC/generic/THCTensorMathCompareT.cu

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,49 +5,49 @@
55
void THCTensor_(ltTensor)(THCState *state, THCudaByteTensor *self_, THCTensor *src1, THCTensor *src2)
66
{
77
THCAssertSameGPU(THCTensor_(checkGPU)(state, 3, self_, src1, src2));
8-
THC_logicalTensor<uint8_t, scalar_t>(state, self_, src1, src2,
8+
THC_logicalTensor<bool, scalar_t>(state, self_, src1, src2,
99
TensorLTOp<scalar_t,
10-
unsigned char>());
10+
bool>());
1111
}
1212

1313
void THCTensor_(gtTensor)(THCState *state, THCudaByteTensor *self_, THCTensor *src1, THCTensor *src2)
1414
{
1515
THCAssertSameGPU(THCTensor_(checkGPU)(state, 3, self_, src1, src2));
16-
THC_logicalTensor<uint8_t, scalar_t>(state, self_, src1, src2,
16+
THC_logicalTensor<bool, scalar_t>(state, self_, src1, src2,
1717
TensorGTOp<scalar_t,
18-
unsigned char>());
18+
bool>());
1919
}
2020

2121
void THCTensor_(leTensor)(THCState *state, THCudaByteTensor *self_, THCTensor *src1, THCTensor *src2)
2222
{
2323
THCAssertSameGPU(THCTensor_(checkGPU)(state, 3, self_, src1, src2));
24-
THC_logicalTensor<uint8_t, scalar_t>(state, self_, src1, src2,
24+
THC_logicalTensor<bool, scalar_t>(state, self_, src1, src2,
2525
TensorLEOp<scalar_t,
26-
unsigned char>());
26+
bool>());
2727
}
2828

2929
void THCTensor_(geTensor)(THCState *state, THCudaByteTensor *self_, THCTensor *src1, THCTensor *src2)
3030
{
3131
THCAssertSameGPU(THCTensor_(checkGPU)(state, 3, self_, src1, src2));
32-
THC_logicalTensor<uint8_t, scalar_t>(state, self_, src1, src2,
32+
THC_logicalTensor<bool, scalar_t>(state, self_, src1, src2,
3333
TensorGEOp<scalar_t,
34-
unsigned char>());
34+
bool>());
3535
}
3636

3737
void THCTensor_(eqTensor)(THCState *state, THCudaByteTensor *self_, THCTensor *src1, THCTensor *src2)
3838
{
3939
THCAssertSameGPU(THCTensor_(checkGPU)(state, 3, self_, src1, src2));
40-
THC_logicalTensor<uint8_t, scalar_t>(state, self_, src1, src2,
40+
THC_logicalTensor<bool, scalar_t>(state, self_, src1, src2,
4141
TensorEQOp<scalar_t,
42-
unsigned char>());
42+
bool>());
4343
}
4444

4545
void THCTensor_(neTensor)(THCState *state, THCudaByteTensor *self_, THCTensor *src1, THCTensor *src2)
4646
{
4747
THCAssertSameGPU(THCTensor_(checkGPU)(state, 3, self_, src1, src2));
48-
THC_logicalTensor<uint8_t, scalar_t>(state, self_, src1, src2,
48+
THC_logicalTensor<bool, scalar_t>(state, self_, src1, src2,
4949
TensorNEOp<scalar_t,
50-
unsigned char>());
50+
bool>());
5151
}
5252

5353
void THCTensor_(ltTensorT)(THCState *state, THCTensor *self_, THCTensor *src1, THCTensor *src2)

c10/core/ScalarType.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,8 @@ namespace c10 {
109109
_(double, Double, d) \
110110
_(c10::qint8, QInt8, i) \
111111
_(c10::quint8, QUInt8, i) \
112-
_(c10::qint32, QInt32, i)
112+
_(c10::qint32, QInt32, i) \
113+
_(bool, Bool, i)
113114

114115
#define AT_FORALL_SCALAR_TYPES_EXCEPT_HALF_AND_QINT(_) \
115116
_(uint8_t, Byte, i) \

test/expect/TestScript.test_listconstruct_erasure.expect

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ ModelProto {
1212
Node {type: "Constant", inputs: [], outputs: [1], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
1313
Node {type: "Less", inputs: [0,1], outputs: [2], attributes: []},
1414
Node {type: "Cast", inputs: [2], outputs: [3], attributes: [{ name: 'to', type: int, value: 2}]},
15-
Node {type: "Cast", inputs: [3], outputs: [4], attributes: [{ name: 'to', type: int, value: 2}]},
15+
Node {type: "Cast", inputs: [3], outputs: [4], attributes: [{ name: 'to', type: int, value: 9}]},
1616
Node {type: "ATen", inputs: [0,4], outputs: [5], attributes: [{ name: 'operator', type: string, value: 'index'}]}
1717
]
1818
}

0 commit comments

Comments
 (0)