Skip to content

Commit 8ce1f84

Browse files
committed
Revert "Changed tensor comparison return type from uint8 to bool (#21113)"
This reverts commit 865c7ee.
1 parent 6cf9ed4 commit 8ce1f84

30 files changed: +249 additions, −737 deletions

aten/src/ATen/Declarations.cwrap

Lines changed: 6 additions & 138 deletions
Original file line numberDiff line numberDiff line change
@@ -611,34 +611,12 @@
611611
return: argument 0
612612
options:
613613
- cname: ltValue
614-
arguments:
615-
- arg: THBoolTensor* result
616-
output: True
617-
- THTensor* self
618-
- real other
619-
- cname: ltTensor
620-
arguments:
621-
- arg: THBoolTensor* result
622-
output: True
623-
- arg: THTensor* self
624-
broadcast: other fallback
625-
- THTensor* other
626-
]]
627-
[[
628-
name: _th_lt_byte
629-
cpu_bool: True
630-
cuda_bool: True
631-
variants:
632-
- function
633-
return: argument 0
634-
options:
635-
- cname: ltValueByte
636614
arguments:
637615
- arg: THByteTensor* result
638616
output: True
639617
- THTensor* self
640618
- real other
641-
- cname: ltTensorByte
619+
- cname: ltTensor
642620
arguments:
643621
- arg: THByteTensor* result
644622
output: True
@@ -674,34 +652,12 @@
674652
return: argument 0
675653
options:
676654
- cname: gtValue
677-
arguments:
678-
- arg: THBoolTensor* result
679-
output: True
680-
- THTensor* self
681-
- real other
682-
- cname: gtTensor
683-
arguments:
684-
- arg: THBoolTensor* result
685-
output: True
686-
- arg: THTensor* self
687-
broadcast: other fallback
688-
- THTensor* other
689-
]]
690-
[[
691-
name: _th_gt_byte
692-
cpu_bool: True
693-
cuda_bool: True
694-
variants:
695-
- function
696-
return: argument 0
697-
options:
698-
- cname: gtValueByte
699655
arguments:
700656
- arg: THByteTensor* result
701657
output: True
702658
- THTensor* self
703659
- real other
704-
- cname: gtTensorByte
660+
- cname: gtTensor
705661
arguments:
706662
- arg: THByteTensor* result
707663
output: True
@@ -737,34 +693,12 @@
737693
return: argument 0
738694
options:
739695
- cname: leValue
740-
arguments:
741-
- arg: THBoolTensor* result
742-
output: True
743-
- THTensor* self
744-
- real other
745-
- cname: leTensor
746-
arguments:
747-
- arg: THBoolTensor* result
748-
output: True
749-
- arg: THTensor* self
750-
broadcast: other fallback
751-
- THTensor* other
752-
]]
753-
[[
754-
name: _th_le_byte
755-
cpu_bool: True
756-
cuda_bool: True
757-
variants:
758-
- function
759-
return: argument 0
760-
options:
761-
- cname: leValueByte
762696
arguments:
763697
- arg: THByteTensor* result
764698
output: True
765699
- THTensor* self
766700
- real other
767-
- cname: leTensorByte
701+
- cname: leTensor
768702
arguments:
769703
- arg: THByteTensor* result
770704
output: True
@@ -800,34 +734,12 @@
800734
return: argument 0
801735
options:
802736
- cname: geValue
803-
arguments:
804-
- arg: THBoolTensor* result
805-
output: True
806-
- THTensor* self
807-
- real other
808-
- cname: geTensor
809-
arguments:
810-
- arg: THBoolTensor* result
811-
output: True
812-
- arg: THTensor* self
813-
broadcast: other fallback
814-
- THTensor* other
815-
]]
816-
[[
817-
name: _th_ge_byte
818-
cpu_bool: True
819-
cuda_bool: True
820-
variants:
821-
- function
822-
return: argument 0
823-
options:
824-
- cname: geValueByte
825737
arguments:
826738
- arg: THByteTensor* result
827739
output: True
828740
- THTensor* self
829741
- real other
830-
- cname: geTensorByte
742+
- cname: geTensor
831743
arguments:
832744
- arg: THByteTensor* result
833745
output: True
@@ -863,34 +775,12 @@
863775
return: argument 0
864776
options:
865777
- cname: eqValue
866-
arguments:
867-
- arg: THBoolTensor* result
868-
output: True
869-
- THTensor* self
870-
- real other
871-
- cname: eqTensor
872-
arguments:
873-
- arg: THBoolTensor* result
874-
output: True
875-
- arg: THTensor* self
876-
broadcast: other fallback
877-
- THTensor* other
878-
]]
879-
[[
880-
name: _th_eq_byte
881-
cpu_bool: True
882-
cuda_bool: True
883-
variants:
884-
- function
885-
return: argument 0
886-
options:
887-
- cname: eqValueByte
888778
arguments:
889779
- arg: THByteTensor* result
890780
output: True
891781
- THTensor* self
892782
- real other
893-
- cname: eqTensorByte
783+
- cname: eqTensor
894784
arguments:
895785
- arg: THByteTensor* result
896786
output: True
@@ -926,34 +816,12 @@
926816
return: argument 0
927817
options:
928818
- cname: neValue
929-
arguments:
930-
- arg: THBoolTensor* result
931-
output: True
932-
- THTensor* self
933-
- real other
934-
- cname: neTensor
935-
arguments:
936-
- arg: THBoolTensor* result
937-
output: True
938-
- arg: THTensor* self
939-
broadcast: other fallback
940-
- THTensor* other
941-
]]
942-
[[
943-
name: _th_ne_byte
944-
cpu_bool: True
945-
cuda_bool: True
946-
variants:
947-
- function
948-
return: argument 0
949-
options:
950-
- cname: neValueByte
951819
arguments:
952820
- arg: THByteTensor* result
953821
output: True
954822
- THTensor* self
955823
- real other
956-
- cname: neTensorByte
824+
- cname: neTensor
957825
arguments:
958826
- arg: THByteTensor* result
959827
output: True

aten/src/ATen/native/Itertools.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ Tensor _triu_mask(int64_t n, int64_t dims, bool diagonal, TensorOptions opt) {
1212
// or i <= j <= k <= ... (depending on diagonal)
1313
Tensor range = at::arange(n, opt.dtype(kLong));
1414
std::vector<Tensor> index_grids = at::meshgrid(std::vector<Tensor>(dims, range));
15-
Tensor mask = at::full(index_grids[0].sizes(), true, opt.dtype(kBool));
15+
Tensor mask = at::ones(index_grids[0].sizes(), opt.dtype(kByte));
1616
if(diagonal) {
1717
for(int64_t i = 0; i < dims - 1; i++) {
1818
mask *= index_grids[i] <= index_grids[i+1];

aten/src/ATen/native/LegacyDefinitions.cpp

Lines changed: 0 additions & 120 deletions
Original file line numberDiff line numberDiff line change
@@ -64,124 +64,4 @@ Tensor gather_cpu(const Tensor & self, int64_t dim, const Tensor & index, bool s
6464
return legacy::cpu::_th_gather(self, dim, index);
6565
}
6666

67-
Tensor & lt_out_cpu(Tensor & result, const Tensor & self, const Tensor & other) {
68-
if (result.dtype() == at::ScalarType::Byte) {
69-
AT_WARN("torch.lt received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \
70-
"please use 'out' parameter with dtype torch.bool instead.");
71-
return legacy::cpu::_th_lt_byte_out(result, self, other);
72-
} else {
73-
return legacy::cpu::_th_lt_out(result, self, other);
74-
}
75-
}
76-
77-
Tensor & lt_scalar_out_cpu(Tensor & result, const Tensor & self, const Scalar value) {
78-
if (result.dtype() == at::ScalarType::Byte) {
79-
AT_WARN("torch.lt received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \
80-
"please use 'out' parameter with dtype torch.bool instead.");
81-
return legacy::cpu::_th_lt_byte_out(result, self, value);
82-
} else {
83-
return legacy::cpu::_th_lt_out(result, self, value);
84-
}
85-
}
86-
87-
Tensor & le_out_cpu(Tensor & result, const Tensor & self, const Tensor & other) {
88-
if (result.dtype() == at::ScalarType::Byte) {
89-
AT_WARN("torch.le received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \
90-
"please use 'out' parameter with dtype torch.bool instead.");
91-
return legacy::cpu::_th_le_byte_out(result, self, other);
92-
} else {
93-
return legacy::cpu::_th_le_out(result, self, other);
94-
}
95-
}
96-
97-
Tensor & le_scalar_out_cpu(Tensor & result, const Tensor & self, const Scalar value) {
98-
if (result.dtype() == at::ScalarType::Byte) {
99-
AT_WARN("torch.le received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \
100-
"please use 'out' parameter with dtype torch.bool instead.");
101-
return legacy::cpu::_th_le_byte_out(result, self, value);
102-
} else {
103-
return legacy::cpu::_th_le_out(result, self, value);
104-
}
105-
}
106-
107-
Tensor & gt_out_cpu(Tensor & result, const Tensor & self, const Tensor & other) {
108-
if (result.dtype() == at::ScalarType::Byte) {
109-
AT_WARN("torch.gt received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \
110-
"please use 'out' parameter with dtype torch.bool instead.");
111-
return legacy::cpu::_th_gt_byte_out(result, self, other);
112-
} else {
113-
return legacy::cpu::_th_gt_out(result, self, other);
114-
}
115-
}
116-
117-
Tensor & gt_scalar_out_cpu(Tensor & result, const Tensor & self, const Scalar value) {
118-
if (result.dtype() == at::ScalarType::Byte) {
119-
AT_WARN("torch.gt received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \
120-
"please use 'out' parameter with dtype torch.bool instead.");
121-
return legacy::cpu::_th_gt_byte_out(result, self, value);
122-
} else {
123-
return legacy::cpu::_th_gt_out(result, self, value);
124-
}
125-
}
126-
127-
Tensor & ge_out_cpu(Tensor & result, const Tensor & self, const Tensor & other) {
128-
if (result.dtype() == at::ScalarType::Byte) {
129-
AT_WARN("torch.ge received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \
130-
"please use 'out' parameter with dtype torch.bool instead.");
131-
return legacy::cpu::_th_ge_byte_out(result, self, other);
132-
} else {
133-
return legacy::cpu::_th_ge_out(result, self, other);
134-
}
135-
}
136-
137-
Tensor & ge_scalar_out_cpu(Tensor & result, const Tensor & self, const Scalar value) {
138-
if (result.dtype() == at::ScalarType::Byte) {
139-
AT_WARN("torch.ge received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \
140-
"please use 'out' parameter with dtype torch.bool instead.");
141-
return legacy::cpu::_th_ge_byte_out(result, self, value);
142-
} else {
143-
return legacy::cpu::_th_ge_out(result, self, value);
144-
}
145-
}
146-
147-
Tensor & eq_out_cpu(Tensor & result, const Tensor & self, const Tensor & other) {
148-
if (result.dtype() == at::ScalarType::Byte) {
149-
AT_WARN("torch.eq received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \
150-
"please use 'out' parameter with dtype torch.bool instead.");
151-
return legacy::cpu::_th_eq_byte_out(result, self, other);
152-
} else {
153-
return legacy::cpu::_th_eq_out(result, self, other);
154-
}
155-
}
156-
157-
Tensor & eq_scalar_out_cpu(Tensor & result, const Tensor & self, const Scalar value) {
158-
if (result.dtype() == at::ScalarType::Byte) {
159-
AT_WARN("torch.eq received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \
160-
"please use 'out' parameter with dtype torch.bool instead.");
161-
return legacy::cpu::_th_eq_byte_out(result, self, value);
162-
} else {
163-
return legacy::cpu::_th_eq_out(result, self, value);
164-
}
165-
}
166-
167-
Tensor & ne_out_cpu(Tensor & result, const Tensor & self, const Tensor & other) {
168-
if (result.dtype() == at::ScalarType::Byte) {
169-
AT_WARN("torch.ne received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \
170-
"please use 'out' parameter with dtype torch.bool instead.");
171-
return legacy::cpu::_th_ne_byte_out(result, self, other);
172-
} else {
173-
return legacy::cpu::_th_ne_out(result, self, other);
174-
}
175-
}
176-
177-
Tensor & ne_scalar_out_cpu(Tensor & result, const Tensor & self, const Scalar value) {
178-
if (result.dtype() == at::ScalarType::Byte) {
179-
AT_WARN("torch.ne received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \
180-
"please use 'out' parameter with dtype torch.bool instead.");
181-
return legacy::cpu::_th_ne_byte_out(result, self, value);
182-
} else {
183-
return legacy::cpu::_th_ne_out(result, self, value);
184-
}
185-
}
186-
18767
}} // namespace at::native

aten/src/ATen/native/cpu/TensorCompareKernel.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ static void max_kernel_impl(
8383
Tensor& max_indices,
8484
const Tensor& self,
8585
c10::optional<int64_t> dim) {
86-
AT_DISPATCH_ALL_TYPES_AND(ScalarType::Bool, self.scalar_type(), "max", [&] {
86+
AT_DISPATCH_ALL_TYPES(self.scalar_type(), "max", [&] {
8787
Reduction<scalar_t, int64_t>::apply(max, max_indices, self, dim, true);
8888
});
8989
}
@@ -93,7 +93,7 @@ static void min_kernel_impl(
9393
Tensor& min_indices,
9494
const Tensor& self,
9595
c10::optional<int64_t> dim) {
96-
AT_DISPATCH_ALL_TYPES_AND(ScalarType::Bool, self.scalar_type(), "min", [&] {
96+
AT_DISPATCH_ALL_TYPES(self.scalar_type(), "min", [&] {
9797
Reduction<scalar_t, int64_t>::apply(min, min_indices, self, dim, false);
9898
});
9999
}

0 commit comments

Comments (0)