Skip to content

Commit e3c0b65

Browse files
committed
Changed tensor comparison return type from uint8 to bool (#21113)
Summary: Pull Request resolved: #21113 ghimport-source-id: 9c4ba63 Test Plan: Imported from OSS Differential Revision: D15552204 Pulled By: izdeby fbshipit-source-id: a608213668649d058e22b510d7755cb99e7d0037
1 parent 1ac19ea commit e3c0b65

30 files changed

+737
-249
lines changed

aten/src/ATen/Declarations.cwrap

Lines changed: 138 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -612,11 +612,33 @@
612612
options:
613613
- cname: ltValue
614614
arguments:
615-
- arg: THByteTensor* result
615+
- arg: THBoolTensor* result
616616
output: True
617617
- THTensor* self
618618
- real other
619619
- cname: ltTensor
620+
arguments:
621+
- arg: THBoolTensor* result
622+
output: True
623+
- arg: THTensor* self
624+
broadcast: other fallback
625+
- THTensor* other
626+
]]
627+
[[
628+
name: _th_lt_byte
629+
cpu_bool: True
630+
cuda_bool: True
631+
variants:
632+
- function
633+
return: argument 0
634+
options:
635+
- cname: ltValueByte
636+
arguments:
637+
- arg: THByteTensor* result
638+
output: True
639+
- THTensor* self
640+
- real other
641+
- cname: ltTensorByte
620642
arguments:
621643
- arg: THByteTensor* result
622644
output: True
@@ -653,11 +675,33 @@
653675
options:
654676
- cname: gtValue
655677
arguments:
656-
- arg: THByteTensor* result
678+
- arg: THBoolTensor* result
657679
output: True
658680
- THTensor* self
659681
- real other
660682
- cname: gtTensor
683+
arguments:
684+
- arg: THBoolTensor* result
685+
output: True
686+
- arg: THTensor* self
687+
broadcast: other fallback
688+
- THTensor* other
689+
]]
690+
[[
691+
name: _th_gt_byte
692+
cpu_bool: True
693+
cuda_bool: True
694+
variants:
695+
- function
696+
return: argument 0
697+
options:
698+
- cname: gtValueByte
699+
arguments:
700+
- arg: THByteTensor* result
701+
output: True
702+
- THTensor* self
703+
- real other
704+
- cname: gtTensorByte
661705
arguments:
662706
- arg: THByteTensor* result
663707
output: True
@@ -694,11 +738,33 @@
694738
options:
695739
- cname: leValue
696740
arguments:
697-
- arg: THByteTensor* result
741+
- arg: THBoolTensor* result
698742
output: True
699743
- THTensor* self
700744
- real other
701745
- cname: leTensor
746+
arguments:
747+
- arg: THBoolTensor* result
748+
output: True
749+
- arg: THTensor* self
750+
broadcast: other fallback
751+
- THTensor* other
752+
]]
753+
[[
754+
name: _th_le_byte
755+
cpu_bool: True
756+
cuda_bool: True
757+
variants:
758+
- function
759+
return: argument 0
760+
options:
761+
- cname: leValueByte
762+
arguments:
763+
- arg: THByteTensor* result
764+
output: True
765+
- THTensor* self
766+
- real other
767+
- cname: leTensorByte
702768
arguments:
703769
- arg: THByteTensor* result
704770
output: True
@@ -735,11 +801,33 @@
735801
options:
736802
- cname: geValue
737803
arguments:
738-
- arg: THByteTensor* result
804+
- arg: THBoolTensor* result
739805
output: True
740806
- THTensor* self
741807
- real other
742808
- cname: geTensor
809+
arguments:
810+
- arg: THBoolTensor* result
811+
output: True
812+
- arg: THTensor* self
813+
broadcast: other fallback
814+
- THTensor* other
815+
]]
816+
[[
817+
name: _th_ge_byte
818+
cpu_bool: True
819+
cuda_bool: True
820+
variants:
821+
- function
822+
return: argument 0
823+
options:
824+
- cname: geValueByte
825+
arguments:
826+
- arg: THByteTensor* result
827+
output: True
828+
- THTensor* self
829+
- real other
830+
- cname: geTensorByte
743831
arguments:
744832
- arg: THByteTensor* result
745833
output: True
@@ -776,11 +864,33 @@
776864
options:
777865
- cname: eqValue
778866
arguments:
779-
- arg: THByteTensor* result
867+
- arg: THBoolTensor* result
780868
output: True
781869
- THTensor* self
782870
- real other
783871
- cname: eqTensor
872+
arguments:
873+
- arg: THBoolTensor* result
874+
output: True
875+
- arg: THTensor* self
876+
broadcast: other fallback
877+
- THTensor* other
878+
]]
879+
[[
880+
name: _th_eq_byte
881+
cpu_bool: True
882+
cuda_bool: True
883+
variants:
884+
- function
885+
return: argument 0
886+
options:
887+
- cname: eqValueByte
888+
arguments:
889+
- arg: THByteTensor* result
890+
output: True
891+
- THTensor* self
892+
- real other
893+
- cname: eqTensorByte
784894
arguments:
785895
- arg: THByteTensor* result
786896
output: True
@@ -817,11 +927,33 @@
817927
options:
818928
- cname: neValue
819929
arguments:
820-
- arg: THByteTensor* result
930+
- arg: THBoolTensor* result
821931
output: True
822932
- THTensor* self
823933
- real other
824934
- cname: neTensor
935+
arguments:
936+
- arg: THBoolTensor* result
937+
output: True
938+
- arg: THTensor* self
939+
broadcast: other fallback
940+
- THTensor* other
941+
]]
942+
[[
943+
name: _th_ne_byte
944+
cpu_bool: True
945+
cuda_bool: True
946+
variants:
947+
- function
948+
return: argument 0
949+
options:
950+
- cname: neValueByte
951+
arguments:
952+
- arg: THByteTensor* result
953+
output: True
954+
- THTensor* self
955+
- real other
956+
- cname: neTensorByte
825957
arguments:
826958
- arg: THByteTensor* result
827959
output: True

aten/src/ATen/native/Itertools.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ Tensor _triu_mask(int64_t n, int64_t dims, bool diagonal, TensorOptions opt) {
1212
// or i <= j <= k <= ... (depending on diagonal)
1313
Tensor range = at::arange(n, opt.dtype(kLong));
1414
std::vector<Tensor> index_grids = at::meshgrid(std::vector<Tensor>(dims, range));
15-
Tensor mask = at::ones(index_grids[0].sizes(), opt.dtype(kByte));
15+
Tensor mask = at::full(index_grids[0].sizes(), true, opt.dtype(kBool));
1616
if(diagonal) {
1717
for(int64_t i = 0; i < dims - 1; i++) {
1818
mask *= index_grids[i] <= index_grids[i+1];

aten/src/ATen/native/LegacyDefinitions.cpp

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,4 +64,124 @@ Tensor gather_cpu(const Tensor & self, int64_t dim, const Tensor & index, bool s
6464
return legacy::cpu::_th_gather(self, dim, index);
6565
}
6666

67+
Tensor & lt_out_cpu(Tensor & result, const Tensor & self, const Tensor & other) {
68+
if (result.dtype() == at::ScalarType::Byte) {
69+
AT_WARN("torch.lt received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \
70+
"please use 'out' parameter with dtype torch.bool instead.");
71+
return legacy::cpu::_th_lt_byte_out(result, self, other);
72+
} else {
73+
return legacy::cpu::_th_lt_out(result, self, other);
74+
}
75+
}
76+
77+
Tensor & lt_scalar_out_cpu(Tensor & result, const Tensor & self, const Scalar value) {
78+
if (result.dtype() == at::ScalarType::Byte) {
79+
AT_WARN("torch.lt received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \
80+
"please use 'out' parameter with dtype torch.bool instead.");
81+
return legacy::cpu::_th_lt_byte_out(result, self, value);
82+
} else {
83+
return legacy::cpu::_th_lt_out(result, self, value);
84+
}
85+
}
86+
87+
Tensor & le_out_cpu(Tensor & result, const Tensor & self, const Tensor & other) {
88+
if (result.dtype() == at::ScalarType::Byte) {
89+
AT_WARN("torch.le received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \
90+
"please use 'out' parameter with dtype torch.bool instead.");
91+
return legacy::cpu::_th_le_byte_out(result, self, other);
92+
} else {
93+
return legacy::cpu::_th_le_out(result, self, other);
94+
}
95+
}
96+
97+
Tensor & le_scalar_out_cpu(Tensor & result, const Tensor & self, const Scalar value) {
98+
if (result.dtype() == at::ScalarType::Byte) {
99+
AT_WARN("torch.le received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \
100+
"please use 'out' parameter with dtype torch.bool instead.");
101+
return legacy::cpu::_th_le_byte_out(result, self, value);
102+
} else {
103+
return legacy::cpu::_th_le_out(result, self, value);
104+
}
105+
}
106+
107+
Tensor & gt_out_cpu(Tensor & result, const Tensor & self, const Tensor & other) {
108+
if (result.dtype() == at::ScalarType::Byte) {
109+
AT_WARN("torch.gt received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \
110+
"please use 'out' parameter with dtype torch.bool instead.");
111+
return legacy::cpu::_th_gt_byte_out(result, self, other);
112+
} else {
113+
return legacy::cpu::_th_gt_out(result, self, other);
114+
}
115+
}
116+
117+
Tensor & gt_scalar_out_cpu(Tensor & result, const Tensor & self, const Scalar value) {
118+
if (result.dtype() == at::ScalarType::Byte) {
119+
AT_WARN("torch.gt received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \
120+
"please use 'out' parameter with dtype torch.bool instead.");
121+
return legacy::cpu::_th_gt_byte_out(result, self, value);
122+
} else {
123+
return legacy::cpu::_th_gt_out(result, self, value);
124+
}
125+
}
126+
127+
Tensor & ge_out_cpu(Tensor & result, const Tensor & self, const Tensor & other) {
128+
if (result.dtype() == at::ScalarType::Byte) {
129+
AT_WARN("torch.ge received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \
130+
"please use 'out' parameter with dtype torch.bool instead.");
131+
return legacy::cpu::_th_ge_byte_out(result, self, other);
132+
} else {
133+
return legacy::cpu::_th_ge_out(result, self, other);
134+
}
135+
}
136+
137+
Tensor & ge_scalar_out_cpu(Tensor & result, const Tensor & self, const Scalar value) {
138+
if (result.dtype() == at::ScalarType::Byte) {
139+
AT_WARN("torch.ge received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \
140+
"please use 'out' parameter with dtype torch.bool instead.");
141+
return legacy::cpu::_th_ge_byte_out(result, self, value);
142+
} else {
143+
return legacy::cpu::_th_ge_out(result, self, value);
144+
}
145+
}
146+
147+
Tensor & eq_out_cpu(Tensor & result, const Tensor & self, const Tensor & other) {
148+
if (result.dtype() == at::ScalarType::Byte) {
149+
AT_WARN("torch.eq received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \
150+
"please use 'out' parameter with dtype torch.bool instead.");
151+
return legacy::cpu::_th_eq_byte_out(result, self, other);
152+
} else {
153+
return legacy::cpu::_th_eq_out(result, self, other);
154+
}
155+
}
156+
157+
Tensor & eq_scalar_out_cpu(Tensor & result, const Tensor & self, const Scalar value) {
158+
if (result.dtype() == at::ScalarType::Byte) {
159+
AT_WARN("torch.eq received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \
160+
"please use 'out' parameter with dtype torch.bool instead.");
161+
return legacy::cpu::_th_eq_byte_out(result, self, value);
162+
} else {
163+
return legacy::cpu::_th_eq_out(result, self, value);
164+
}
165+
}
166+
167+
Tensor & ne_out_cpu(Tensor & result, const Tensor & self, const Tensor & other) {
168+
if (result.dtype() == at::ScalarType::Byte) {
169+
AT_WARN("torch.ne received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \
170+
"please use 'out' parameter with dtype torch.bool instead.");
171+
return legacy::cpu::_th_ne_byte_out(result, self, other);
172+
} else {
173+
return legacy::cpu::_th_ne_out(result, self, other);
174+
}
175+
}
176+
177+
Tensor & ne_scalar_out_cpu(Tensor & result, const Tensor & self, const Scalar value) {
178+
if (result.dtype() == at::ScalarType::Byte) {
179+
AT_WARN("torch.ne received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \
180+
"please use 'out' parameter with dtype torch.bool instead.");
181+
return legacy::cpu::_th_ne_byte_out(result, self, value);
182+
} else {
183+
return legacy::cpu::_th_ne_out(result, self, value);
184+
}
185+
}
186+
67187
}} // namespace at::native

aten/src/ATen/native/cpu/TensorCompareKernel.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ static void max_kernel_impl(
8383
Tensor& max_indices,
8484
const Tensor& self,
8585
c10::optional<int64_t> dim) {
86-
AT_DISPATCH_ALL_TYPES(self.scalar_type(), "max", [&] {
86+
AT_DISPATCH_ALL_TYPES_AND(ScalarType::Bool, self.scalar_type(), "max", [&] {
8787
Reduction<scalar_t, int64_t>::apply(max, max_indices, self, dim, true);
8888
});
8989
}
@@ -93,7 +93,7 @@ static void min_kernel_impl(
9393
Tensor& min_indices,
9494
const Tensor& self,
9595
c10::optional<int64_t> dim) {
96-
AT_DISPATCH_ALL_TYPES(self.scalar_type(), "min", [&] {
96+
AT_DISPATCH_ALL_TYPES_AND(ScalarType::Bool, self.scalar_type(), "min", [&] {
9797
Reduction<scalar_t, int64_t>::apply(min, min_indices, self, dim, false);
9898
});
9999
}

0 commit comments

Comments
 (0)