144 changes: 138 additions & 6 deletions aten/src/ATen/Declarations.cwrap
@@ -612,11 +612,33 @@
options:
- cname: ltValue
arguments:
- arg: THByteTensor* result
- arg: THBoolTensor* result
output: True
- THTensor* self
- real other
- cname: ltTensor
arguments:
- arg: THBoolTensor* result
output: True
- arg: THTensor* self
broadcast: other fallback
- THTensor* other
]]
[[
name: _th_lt_byte
cpu_bool: True
cuda_bool: True
variants:
- function
return: argument 0
options:
- cname: ltValueByte
arguments:
- arg: THByteTensor* result
output: True
- THTensor* self
- real other
- cname: ltTensorByte
arguments:
- arg: THByteTensor* result
output: True
@@ -653,11 +675,33 @@
options:
- cname: gtValue
arguments:
- arg: THByteTensor* result
- arg: THBoolTensor* result
output: True
- THTensor* self
- real other
- cname: gtTensor
arguments:
- arg: THBoolTensor* result
output: True
- arg: THTensor* self
broadcast: other fallback
- THTensor* other
]]
[[
name: _th_gt_byte
cpu_bool: True
cuda_bool: True
variants:
- function
return: argument 0
options:
- cname: gtValueByte
arguments:
- arg: THByteTensor* result
output: True
- THTensor* self
- real other
- cname: gtTensorByte
arguments:
- arg: THByteTensor* result
output: True
@@ -694,11 +738,33 @@
options:
- cname: leValue
arguments:
- arg: THByteTensor* result
- arg: THBoolTensor* result
output: True
- THTensor* self
- real other
- cname: leTensor
arguments:
- arg: THBoolTensor* result
output: True
- arg: THTensor* self
broadcast: other fallback
- THTensor* other
]]
[[
name: _th_le_byte
cpu_bool: True
cuda_bool: True
variants:
- function
return: argument 0
options:
- cname: leValueByte
arguments:
- arg: THByteTensor* result
output: True
- THTensor* self
- real other
- cname: leTensorByte
arguments:
- arg: THByteTensor* result
output: True
@@ -735,11 +801,33 @@
options:
- cname: geValue
arguments:
- arg: THByteTensor* result
- arg: THBoolTensor* result
output: True
- THTensor* self
- real other
- cname: geTensor
arguments:
- arg: THBoolTensor* result
output: True
- arg: THTensor* self
broadcast: other fallback
- THTensor* other
]]
[[
name: _th_ge_byte
cpu_bool: True
cuda_bool: True
variants:
- function
return: argument 0
options:
- cname: geValueByte
arguments:
- arg: THByteTensor* result
output: True
- THTensor* self
- real other
- cname: geTensorByte
arguments:
- arg: THByteTensor* result
output: True
@@ -776,11 +864,33 @@
options:
- cname: eqValue
arguments:
- arg: THByteTensor* result
- arg: THBoolTensor* result
output: True
- THTensor* self
- real other
- cname: eqTensor
arguments:
- arg: THBoolTensor* result
output: True
- arg: THTensor* self
broadcast: other fallback
- THTensor* other
]]
[[
name: _th_eq_byte
cpu_bool: True
cuda_bool: True
variants:
- function
return: argument 0
options:
- cname: eqValueByte
arguments:
- arg: THByteTensor* result
output: True
- THTensor* self
- real other
- cname: eqTensorByte
arguments:
- arg: THByteTensor* result
output: True
@@ -817,11 +927,33 @@
options:
- cname: neValue
arguments:
- arg: THByteTensor* result
- arg: THBoolTensor* result
output: True
- THTensor* self
- real other
- cname: neTensor
arguments:
- arg: THBoolTensor* result
output: True
- arg: THTensor* self
broadcast: other fallback
- THTensor* other
]]
[[
name: _th_ne_byte
cpu_bool: True
cuda_bool: True
variants:
- function
return: argument 0
options:
- cname: neValueByte
arguments:
- arg: THByteTensor* result
output: True
- THTensor* self
- real other
- cname: neTensorByte
arguments:
- arg: THByteTensor* result
output: True
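Taken together, the cwrap changes above flip the default result type of the six comparison ops (lt/gt/le/ge/eq/ne) from THByteTensor to THBoolTensor, while the new `_th_*_byte` variants keep the old uint8 kernels reachable for the deprecation path added later in this diff. A minimal sketch of the intended user-visible behavior, assuming the standard ATen C++ API (the snippet itself is not part of this PR):

```cpp
#include <ATen/ATen.h>
#include <iostream>

int main() {
  at::Tensor a = at::arange(3);   // [0, 1, 2], dtype Long
  at::Tensor m = at::lt(a, 1);    // after this change: dtype Bool, not Byte
  std::cout << m.dtype() << "\n"; // expected to print "bool"
  return 0;
}
```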
2 changes: 1 addition & 1 deletion aten/src/ATen/native/Itertools.cpp
@@ -12,7 +12,7 @@ Tensor _triu_mask(int64_t n, int64_t dims, bool diagonal, TensorOptions opt) {
// or i <= j <= k <= ... (depending on diagonal)
Tensor range = at::arange(n, opt.dtype(kLong));
std::vector<Tensor> index_grids = at::meshgrid(std::vector<Tensor>(dims, range));
Tensor mask = at::ones(index_grids[0].sizes(), opt.dtype(kByte));
Tensor mask = at::full(index_grids[0].sizes(), true, opt.dtype(kBool));
if(diagonal) {
for(int64_t i = 0; i < dims - 1; i++) {
mask *= index_grids[i] <= index_grids[i+1];
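The Itertools change is the same migration in miniature: the all-ones Byte seed mask becomes an all-true Bool mask, so the `mask *= index_grids[i] <= index_grids[i+1]` accumulation now operates on matching Bool operands. A hedged before/after sketch (standalone illustration, not PR code):

```cpp
#include <ATen/ATen.h>

at::Tensor triu_mask_seed(bool use_bool) {
  // Old behavior: a Byte mask of ones; comparisons also yielded uint8.
  if (!use_bool) return at::ones({3, 3}, at::kByte);
  // New behavior: a Bool mask of trues, matching the Bool results that
  // the comparison ops now multiply into it.
  return at::full({3, 3}, true, at::kBool);
}
```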
120 changes: 120 additions & 0 deletions aten/src/ATen/native/LegacyDefinitions.cpp
@@ -64,4 +64,124 @@ Tensor gather_cpu(const Tensor & self, int64_t dim, const Tensor & index, bool s
return legacy::cpu::_th_gather(self, dim, index);
}

Tensor & lt_out_cpu(Tensor & result, const Tensor & self, const Tensor & other) {
if (result.dtype() == at::ScalarType::Byte) {
AT_WARN("torch.lt received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \
"please use 'out' parameter with dtype torch.bool instead.");
return legacy::cpu::_th_lt_byte_out(result, self, other);
} else {
return legacy::cpu::_th_lt_out(result, self, other);
}
}

Tensor & lt_scalar_out_cpu(Tensor & result, const Tensor & self, const Scalar value) {
if (result.dtype() == at::ScalarType::Byte) {
AT_WARN("torch.lt received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \
"please use 'out' parameter with dtype torch.bool instead.");
return legacy::cpu::_th_lt_byte_out(result, self, value);
} else {
return legacy::cpu::_th_lt_out(result, self, value);
}
}

Tensor & le_out_cpu(Tensor & result, const Tensor & self, const Tensor & other) {
if (result.dtype() == at::ScalarType::Byte) {
AT_WARN("torch.le received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \
"please use 'out' parameter with dtype torch.bool instead.");
return legacy::cpu::_th_le_byte_out(result, self, other);
} else {
return legacy::cpu::_th_le_out(result, self, other);
}
}

Tensor & le_scalar_out_cpu(Tensor & result, const Tensor & self, const Scalar value) {
if (result.dtype() == at::ScalarType::Byte) {
AT_WARN("torch.le received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \
"please use 'out' parameter with dtype torch.bool instead.");
return legacy::cpu::_th_le_byte_out(result, self, value);
} else {
return legacy::cpu::_th_le_out(result, self, value);
}
}

Tensor & gt_out_cpu(Tensor & result, const Tensor & self, const Tensor & other) {
if (result.dtype() == at::ScalarType::Byte) {
AT_WARN("torch.gt received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \
"please use 'out' parameter with dtype torch.bool instead.");
return legacy::cpu::_th_gt_byte_out(result, self, other);
} else {
return legacy::cpu::_th_gt_out(result, self, other);
}
}

Tensor & gt_scalar_out_cpu(Tensor & result, const Tensor & self, const Scalar value) {
if (result.dtype() == at::ScalarType::Byte) {
AT_WARN("torch.gt received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \
"please use 'out' parameter with dtype torch.bool instead.");
return legacy::cpu::_th_gt_byte_out(result, self, value);
} else {
return legacy::cpu::_th_gt_out(result, self, value);
}
}

Tensor & ge_out_cpu(Tensor & result, const Tensor & self, const Tensor & other) {
if (result.dtype() == at::ScalarType::Byte) {
AT_WARN("torch.ge received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \
"please use 'out' parameter with dtype torch.bool instead.");
return legacy::cpu::_th_ge_byte_out(result, self, other);
} else {
return legacy::cpu::_th_ge_out(result, self, other);
}
}

Tensor & ge_scalar_out_cpu(Tensor & result, const Tensor & self, const Scalar value) {
if (result.dtype() == at::ScalarType::Byte) {
AT_WARN("torch.ge received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \
"please use 'out' parameter with dtype torch.bool instead.");
return legacy::cpu::_th_ge_byte_out(result, self, value);
} else {
return legacy::cpu::_th_ge_out(result, self, value);
}
}

Tensor & eq_out_cpu(Tensor & result, const Tensor & self, const Tensor & other) {
if (result.dtype() == at::ScalarType::Byte) {
AT_WARN("torch.eq received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \
"please use 'out' parameter with dtype torch.bool instead.");
return legacy::cpu::_th_eq_byte_out(result, self, other);
} else {
return legacy::cpu::_th_eq_out(result, self, other);
}
}

Tensor & eq_scalar_out_cpu(Tensor & result, const Tensor & self, const Scalar value) {
if (result.dtype() == at::ScalarType::Byte) {
AT_WARN("torch.eq received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \
"please use 'out' parameter with dtype torch.bool instead.");
return legacy::cpu::_th_eq_byte_out(result, self, value);
} else {
return legacy::cpu::_th_eq_out(result, self, value);
}
}

Tensor & ne_out_cpu(Tensor & result, const Tensor & self, const Tensor & other) {
if (result.dtype() == at::ScalarType::Byte) {
AT_WARN("torch.ne received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \
"please use 'out' parameter with dtype torch.bool instead.");
return legacy::cpu::_th_ne_byte_out(result, self, other);
} else {
return legacy::cpu::_th_ne_out(result, self, other);
}
}

Tensor & ne_scalar_out_cpu(Tensor & result, const Tensor & self, const Scalar value) {
if (result.dtype() == at::ScalarType::Byte) {
AT_WARN("torch.ne received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \
"please use 'out' parameter with dtype torch.bool instead.");
return legacy::cpu::_th_ne_byte_out(result, self, value);
} else {
return legacy::cpu::_th_ne_out(result, self, value);
}
}

}} // namespace at::native
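Each of the twelve wrappers above follows the same shape: a uint8 `out` tensor triggers a deprecation warning and routes to the legacy byte kernel, while any other dtype takes the new bool path. A hedged usage sketch of the warning behavior as seen from the public ATen API (the exact `*_out` overload names are an assumption, not confirmed by this diff):

```cpp
#include <ATen/ATen.h>

void demo() {
  at::Tensor a = at::randn({4});
  at::Tensor b = at::randn({4});

  at::Tensor out_u8 = at::empty({4}, at::kByte);
  at::lt_out(out_u8, a, b);    // deprecated path: warns, still fills the uint8 tensor

  at::Tensor out_bool = at::empty({4}, at::kBool);
  at::lt_out(out_bool, a, b);  // preferred path: no warning, Bool result
}
```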
4 changes: 2 additions & 2 deletions aten/src/ATen/native/cpu/TensorCompareKernel.cpp
@@ -83,7 +83,7 @@ static void max_kernel_impl(
Tensor& max_indices,
const Tensor& self,
c10::optional<int64_t> dim) {
AT_DISPATCH_ALL_TYPES(self.scalar_type(), "max", [&] {
AT_DISPATCH_ALL_TYPES_AND(ScalarType::Bool, self.scalar_type(), "max", [&] {
Reduction<scalar_t, int64_t>::apply(max, max_indices, self, dim, true);
});
}
@@ -93,7 +93,7 @@ static void min_kernel_impl(
Tensor& min_indices,
const Tensor& self,
c10::optional<int64_t> dim) {
AT_DISPATCH_ALL_TYPES(self.scalar_type(), "min", [&] {
AT_DISPATCH_ALL_TYPES_AND(ScalarType::Bool, self.scalar_type(), "min", [&] {
Reduction<scalar_t, int64_t>::apply(min, min_indices, self, dim, false);
});
}
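The kernel-side change swaps AT_DISPATCH_ALL_TYPES for AT_DISPATCH_ALL_TYPES_AND, which instantiates the reduction lambda for the standard types plus one extra ScalarType, here Bool, so min/max reductions accept the Bool masks the comparison ops now produce. A small sketch of the newly supported call, assuming the public ATen API:

```cpp
#include <ATen/ATen.h>

void demo() {
  // A Bool mask, e.g. from a comparison (dtype Bool after this PR).
  at::Tensor mask = at::arange(4) > 1;  // [false, false, true, true]
  // These dispatched through AT_DISPATCH_ALL_TYPES_AND(ScalarType::Bool, ...)
  // and would previously have failed to dispatch for Bool inputs.
  at::Tensor lowest  = at::min(mask);   // expected: false
  at::Tensor highest = at::max(mask);   // expected: true
}
```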