50 changes: 49 additions & 1 deletion aten/src/ATen/Declarations.cwrap
@@ -1018,6 +1018,18 @@
return: real
arguments:
- THTensor* self
]]
[[
name: _th_all
types:
- Byte
variants:
- method
- function
backends:
- CPU
- CUDA
options:
- cname: logicalAnd
return: argument 0
scalar_check: self_->isScalar() || (keepdim == false && self_->dim() == 1)
@@ -1045,6 +1057,18 @@
return: real
arguments:
- THTensor* self
]]
[[
name: _th_any
types:
- Byte
variants:
- method
- function
backends:
- CPU
- CUDA
options:
- cname: logicalAny
return: argument 0
scalar_check: self_->isScalar() || (keepdim == false && self_->dim() == 1)
@@ -1641,6 +1665,18 @@
if_true: 0
if_false: 1
default: 0
]]
[[
name: _th_var
types:
- floating_point
backends:
- CPU
- CUDA
variants:
- method
- function
options:
- cname: var
return: argument 0
scalar_check: self_->isScalar() || (keepdim == false && self_->dim() == 1)
@@ -1676,6 +1712,18 @@
if_true: 0
if_false: 1
default: 0
]]
[[
name: _th_std
types:
- floating_point
backends:
- CPU
- CUDA
variants:
- method
- function
options:
- cname: std
return: argument 0
scalar_check: self_->isScalar() || (keepdim == false && self_->dim() == 1)
@@ -1711,7 +1759,7 @@
default: AS_REAL(2)
]]
[[
name: norm
name: _th_norm
types:
- floating_point
backends:
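Aside: every scalar_check above encodes the same shape rule: a dim-reduction yields a scalar exactly when the input is already a scalar, or is 1-d with keepdim off. A standalone sketch of that rule (plain C++, illustrative names, not the cwrap-generated code):

#include <cassert>
#include <cstdint>
#include <initializer_list>

// Result rank of reduce(self, dim, keepdim): keepdim keeps the reduced
// dimension as size 1, otherwise it is dropped.
int64_t reduced_ndim(int64_t self_ndim, bool keepdim) {
  if (self_ndim == 0) return 0;  // a scalar stays a scalar
  return keepdim ? self_ndim : self_ndim - 1;
}

// Mirrors: self_->isScalar() || (keepdim == false && self_->dim() == 1)
bool reduces_to_scalar(int64_t self_ndim, bool keepdim) {
  return self_ndim == 0 || (!keepdim && self_ndim == 1);
}

int main() {
  for (int64_t d = 0; d <= 3; ++d)
    for (bool k : {false, true})
      assert(reduces_to_scalar(d, k) == (reduced_ndim(d, k) == 0));
  return 0;
}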
3 changes: 2 additions & 1 deletion aten/src/ATen/ExpandUtils.cpp
@@ -21,7 +21,8 @@ std::vector<int64_t> infer_size(IntList a, IntList b) {
") must match the size of tensor b (", sizeB,
") at non-singleton dimension ", i);

expandedSizes[i] = std::max(sizeA, sizeB);
// 1s map to the other size (even 0).
expandedSizes[i] = sizeA == 1 ? sizeB : sizeA;
}

return expandedSizes;
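This one-line change is the crux for zero-size dims: a size-1 dimension now broadcasts to the other operand's size even when that size is 0, where the old std::max(sizeA, sizeB) would wrongly keep 1. A standalone sketch of the rule (illustrative, not the ATen function):

#include <algorithm>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Right-aligned NumPy-style broadcast of two shapes.
std::vector<int64_t> broadcast_shapes(std::vector<int64_t> a,
                                      std::vector<int64_t> b) {
  size_t ndim = std::max(a.size(), b.size());
  std::vector<int64_t> out(ndim);
  for (size_t i = 0; i < ndim; ++i) {
    int64_t sa = i < ndim - a.size() ? 1 : a[i - (ndim - a.size())];
    int64_t sb = i < ndim - b.size() ? 1 : b[i - (ndim - b.size())];
    if (sa != sb && sa != 1 && sb != 1)
      throw std::runtime_error("shapes do not broadcast");
    out[i] = sa == 1 ? sb : sa;  // 1s map to the other size (even 0)
  }
  return out;
}
// broadcast_shapes({5, 1}, {0}) -> {5, 0}; the max rule would give {5, 1}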
2 changes: 1 addition & 1 deletion aten/src/ATen/native/Distance.cpp
@@ -5,6 +5,6 @@
namespace at { namespace native {

Tensor pairwise_distance(const Tensor& x1, const Tensor& x2, double p, double eps, bool keepdim) {
return norm(x1 - x2 + eps, p, 1, keepdim);
return at::norm(x1 - x2 + eps, p, 1, keepdim);
}
}} // namespace at::native
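The fix here qualifies the call as at::norm, presumably so it resolves against the new native overload added in ReduceOps.cpp below rather than the Tensor method. For reference, the per-row reduction it performs is ||x1 - x2 + eps||_p; a scalar sketch of that formula (plain C++, not the ATen kernel):

#include <cmath>
#include <vector>

// p-norm of one difference row: (sum_i |v_i|^p)^(1/p).
// eps is added to the difference upstream to keep the gradient finite at 0.
double pnorm(const std::vector<double>& v, double p) {
  double acc = 0.0;
  for (double x : v) acc += std::pow(std::abs(x), p);
  return std::pow(acc, 1.0 / p);
}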
1 change: 1 addition & 0 deletions aten/src/ATen/native/Embedding.cpp
@@ -71,6 +71,7 @@ Tensor embedding_sparse_backward(

// check if all our grad come from padding_idx
if (grad.numel() == 0) {
// FIXME: USE_TH_SIZE_ZERO_DIM

return sparse_type._sparse_coo_tensor_unsafe(indices_.type().tensor(),
dense_type.tensor(), weight_size);
}
28 changes: 19 additions & 9 deletions aten/src/ATen/native/Indexing.cpp
@@ -69,9 +69,13 @@ static std::vector<Tensor> expandByteTensors(const Tensor & self, TensorList ind
}
// Replace with nonzeros
auto nonzero = index.nonzero();
auto is_empty = nonzero.numel() == 0;
#ifndef USE_TH_SIZE_ZERO_DIM
auto special_empty = nonzero.numel() == 0;
#else
auto special_empty = false;
#endif
for (int64_t j = 0; j < index.dim(); j++) {
if (is_empty) {
if (special_empty) {
// We can't call select on an empty tensor so we just create an empty
// tensor.
result.emplace_back(nonzero.type().tensor());
@@ -143,13 +147,15 @@ static Tensor unsqueezeN(const Tensor & src, int64_t before, int64_t after) {
}

static Tensor wrapIndexOnce(const Tensor & index, int64_t dim, int64_t dim_size) {
auto max_idx = index.max().toCLong();
auto min_idx = index.min().toCLong();
if (max_idx >= dim_size) {
AT_ERROR("index ", max_idx, " is out of bounds for dimension ", dim, " with size ", dim_size);
}
if (min_idx < -dim_size) {
AT_ERROR("index ", min_idx, " is out of bounds for dimension ", dim, " with size ", dim_size);
if (index.numel() != 0) {
auto max_idx = index.max().toCLong();
auto min_idx = index.min().toCLong();
if (max_idx >= dim_size) {
AT_ERROR("index ", max_idx, " is out of bounds for dimension ", dim, " with size ", dim_size);
}
if (min_idx < -dim_size) {
AT_ERROR("index ", min_idx, " is out of bounds for dimension ", dim, " with size ", dim_size);
}
}
return index.remainder(dim_size);
}
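Context for the hunk above: max()/min() are undefined on an empty tensor, so the bounds check now only runs when index.numel() != 0; the trailing remainder wraps negative indices Python-style and is a no-op over zero elements. A plain-C++ sketch of the same logic (illustrative names, not the ATen function):

#include <cstdint>
#include <stdexcept>
#include <string>
#include <vector>

// Validate indices against [-dim_size, dim_size), then wrap negatives
// once, Python-style (-1 -> dim_size - 1).  Empty input skips the check.
std::vector<int64_t> wrap_index_once(std::vector<int64_t> idx, int64_t dim_size) {
  if (!idx.empty()) {
    for (int64_t i : idx)
      if (i >= dim_size || i < -dim_size)
        throw std::out_of_range("index " + std::to_string(i) +
                                " is out of bounds for size " +
                                std::to_string(dim_size));
  }
  for (int64_t& i : idx)
    i = ((i % dim_size) + dim_size) % dim_size;  // like Tensor::remainder
  return idx;
}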
@@ -208,6 +214,7 @@ static Tensor computeLinearIndex(const Tensor & src, TensorList indices) {
return linearIndex;
}

#ifndef USE_TH_SIZE_ZERO_DIM
static bool hasEmptyTensor(TensorList tensors) {
for (auto& tensor : tensors) {
if (tensor.defined() && tensor.numel() == 0) {
@@ -216,14 +223,17 @@ static bool hasEmptyTensor(TensorList tensors) {
}
return false;
}
#endif

static std::tuple<Tensor, Tensor> makeLinearIndex(Tensor self, TensorList orig) {
checkIndexTensorTypes(orig);
// first expand ByteTensor (boolean masks) into 1 or more LongTensors
auto indices = expandByteTensors(self, orig);
#ifndef USE_TH_SIZE_ZERO_DIM
if (hasEmptyTensor(indices)) {
return std::make_tuple(self, self.type().toScalarType(kLong).tensor());
}
#endif
// next broadcast all index tensors together
indices = expand_outplace(indices);
// add missing null Tensors so that it matches self.dim()
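Context for the expandByteTensors hunk at the top of this file: a byte mask index is replaced by the columns of its nonzero() result, and with USE_TH_SIZE_ZERO_DIM an all-zero mask now just produces empty index tensors instead of taking the special_empty path. A 1-d sketch of that conversion (illustrative, not the ATen code):

#include <cstdint>
#include <vector>

// nonzero() for a 1-d byte mask: the positions where the mask is set.
// An all-zero mask simply yields an empty index vector, which the
// indexing kernels can now consume directly.
std::vector<int64_t> nonzero_1d(const std::vector<uint8_t>& mask) {
  std::vector<int64_t> idx;
  for (size_t i = 0; i < mask.size(); ++i)
    if (mask[i]) idx.push_back(static_cast<int64_t>(i));
  return idx;
}
// nonzero_1d({0, 1, 0, 1}) -> {1, 3};  nonzero_1d({0, 0}) -> {}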
154 changes: 116 additions & 38 deletions aten/src/ATen/native/ReduceOps.cpp
Original file line number Diff line number Diff line change
@@ -4,10 +4,12 @@
#include "ATen/NativeFunctions.h"
#include "ATen/WrapDimUtils.h"
#include "ATen/WrapDimUtilsMulti.h"
#include "ReduceOpsUtils.h"
#include "cpu/ReduceOpsKernel.h"

#include <algorithm>
#include <functional>
#include <limits>
#include <numeric>
#include <vector>
#include <map>
@@ -94,10 +96,12 @@ static inline Tensor mean(const Tensor &self, optional<ScalarType> dtype) {
"Can only calculate the mean of floating types. Got ",
at::toString(scalarType),
" instead.");
Tensor result = at::native::sum(self);
if (self.numel() > 0)
result.div_(self.numel());
return result;
if (self.numel() > 0) {
Tensor result = at::native::sum(self);
return result.div_(self.numel());
} else {
return self.type().scalarTensor(std::numeric_limits<double>::quiet_NaN());
}
}

Tensor mean(const Tensor &self, ScalarType dtype) {
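Previously an empty input fell through with the sum (0) undivided; the new branch returns NaN, matching NumPy's convention for the mean of nothing. The scalar logic, sketched standalone:

#include <limits>
#include <vector>

// mean of an empty sequence is NaN (NumPy convention), not sum/0 or 0.
double mean(const std::vector<double>& v) {
  if (v.empty()) return std::numeric_limits<double>::quiet_NaN();
  double sum = 0.0;
  for (double x : v) sum += x;
  return sum / static_cast<double>(v.size());
}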
@@ -154,32 +158,6 @@ Tensor _prod_cpu(const Tensor &self) {

// DIM REDUCE #################################################################

static bool _dimreduce_return_trivial(Tensor &result, const Tensor &self,
int64_t ident) {
if (self.numel() == 1 && self.ndimension() == 0) {
result.resize_({});
result.fill_(self);
return true;
}
// Return identity
if (self.numel() == 0 && self.ndimension() == 1) {
result.resize_({0});
result.fill_(ident);
return true;
}
return false;
}

static Tensor &_dimreduce_setup(Tensor &result, const Tensor &self,
int64_t dim) {
IntList self_sizes = self.sizes();
std::vector<int64_t> result_sizes;
result_sizes.insert(result_sizes.end(), self_sizes.begin(), self_sizes.end());
result_sizes[dim] = 1;
result.resize_(result_sizes);
return result;
}

static inline Tensor &mean_out(Tensor &result, const Tensor &self, int64_t dim,
bool keepdim, optional<ScalarType> dtype) {
ScalarType scalarType = result.type().scalarType();
@@ -192,7 +170,12 @@ static inline Tensor &mean_out(Tensor &result, const Tensor &self, int64_t dim,
result, self.toType(result.type().scalarType()), dim, keepdim);
if (result.numel() > 0 && self.ndimension() > 0) {
int64_t numel = self.size(dim);
result.div_(numel);
if (numel > 0) {
result.div_(numel);
} else {
// NumPy equivalent
result.fill_(std::numeric_limits<double>::quiet_NaN());
}
}
return result;
}
@@ -235,7 +218,7 @@ Tensor& sum_out(Tensor& result, const Tensor& self, IntList dim, ScalarType dtyp
Tensor &_sum_out_cpu(Tensor &result, const Tensor &self, int64_t dim_,
bool keepdim) {
int64_t dim = maybe_wrap_dim(dim_, self.dim());
if (_dimreduce_return_trivial(result, self, 0))
if (_dimreduce_return_trivial(result, self, 0, dim, keepdim))
return result;
if (self.is_contiguous() && result.is_contiguous()) {
_dimreduce_setup(result, self, dim);
@@ -273,7 +256,7 @@ Tensor& prod_out(Tensor& result, const Tensor& self, int64_t dim, ScalarType dty
Tensor &_prod_out_cpu(Tensor &result, const Tensor &self, int64_t dim_,
bool keepdim) {
int64_t dim = maybe_wrap_dim(dim_, self.dim());
if (_dimreduce_return_trivial(result, self, 1))
if (_dimreduce_return_trivial(result, self, 1, dim, keepdim))
return result;
if (self.is_contiguous() && result.is_contiguous()) {
_dimreduce_setup(result, self, dim);
@@ -294,7 +277,12 @@ static inline Tensor mean(const Tensor &self, int64_t dim, bool keepdim, optiona
Tensor result = at::native::sum(self, dim, keepdim);
if (result.numel() > 0 && self.ndimension() > 0) {
int64_t numel = self.size(dim);
result.div_(numel);
if (numel > 0) {
result.div_(numel);
} else {
// NumPy equivalent
result.fill_(std::numeric_limits<double>::quiet_NaN());
}
}
return result;
}
@@ -357,10 +345,15 @@ Tensor _prod(const Tensor &self, int64_t dim_, bool keepdim) {

Tensor& logsumexp_out(Tensor& result, const Tensor &self, int64_t dim_, bool keepdim) {
int64_t dim = maybe_wrap_dim(dim_, self.dim());
auto maxes = at::max_values(self, dim, true);
result = at::where((maxes == INFINITY).__or__(maxes == -INFINITY),
maxes,
maxes + at::log(at::sum(at::exp(self - maxes), dim, true)));
// can't take max of empty tensor.
if (self.numel() != 0) {
auto maxes = at::max_values(self, dim, true);
result = at::where((maxes == INFINITY).__or__(maxes == -INFINITY),
maxes,
maxes + at::log(at::sum(at::exp(self - maxes), dim, true)));
} else {
result = at::log(at::sum(at::exp(self), dim, true));
}
if (! keepdim)
result.squeeze_(dim);
return result;
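The guard above is the standard max-shift trick: logsumexp(x) = m + log(sum_i exp(x_i - m)) with m = max(x), and the where() keeps a +/-inf max from turning into inf - inf = NaN. In the empty case the unshifted form is used directly, so an empty reduction gives log(0) = -inf. A standalone scalar sketch:

#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>

// logsumexp(x) = m + log(sum_i exp(x_i - m)), m = max(x).
// Shifting by m prevents overflow in exp; when m is +/-inf the shifted
// form would compute inf - inf, so m itself is returned instead.
double logsumexp(const std::vector<double>& x) {
  if (x.empty())
    return -std::numeric_limits<double>::infinity();  // log(0)
  double m = -std::numeric_limits<double>::infinity();
  for (double v : x) m = std::max(m, v);
  if (std::isinf(m)) return m;
  double acc = 0.0;
  for (double v : x) acc += std::exp(v - m);
  return m + std::log(acc);
}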
@@ -588,4 +581,89 @@ Tensor& _sum_out(Tensor &result, const Tensor &self, IntList dims, bool keepdim)
return reduce_multi_associative_out<_sum, _sum_out>(result, self, dims, keepdim);
}

Tensor norm(const Tensor& self, Scalar p, int64_t dim, bool keepdim) {
Tensor result = self.type().tensor();
return at::native::norm_out(result, self, p, dim, keepdim);
}

Tensor &norm_out(Tensor &result, const Tensor &self, Scalar p, int64_t dim, bool keepdim) {
AT_CHECK(self.type().backend() == Backend::CPU || self.type().backend() == Backend::CUDA,
"norm only supports CPU AND CUDA backend, got: ", at::toString(self.type().backend()));
AT_CHECK(at::isFloatingType(self.type().scalarType()), "norm only supports floating-point dtypes");
dim = maybe_wrap_dim(dim, self.dim());
if (_dimreduce_return_trivial(result, self, 0, dim, keepdim)) {
return result;
} else {
return at::_th_norm_out(result, self, p, dim, keepdim);
}
}

Tensor all(const Tensor& self, int64_t dim, bool keepdim) {
Tensor result = self.type().tensor();
return at::native::all_out(result, self, dim, keepdim);
}

Tensor &all_out(Tensor &result, const Tensor &self, int64_t dim, bool keepdim) {
AT_CHECK(self.type().backend() == Backend::CPU || self.type().backend() == Backend::CUDA,
"all only supports CPU AND CUDA backend, got: ", at::toString(self.type().backend()));
AT_CHECK(self.type().scalarType() == at::ScalarType::Byte, "all only supports torch.uint8 dtype");
dim = maybe_wrap_dim(dim, self.dim());
if (_dimreduce_return_trivial(result, self, 1, dim, keepdim)) {
return result;
} else {
return at::_th_all_out(result, self, dim, keepdim);
}
}

Tensor any(const Tensor& self, int64_t dim, bool keepdim) {
Tensor result = self.type().tensor();
return at::native::any_out(result, self, dim, keepdim);
}

Tensor &any_out(Tensor &result, const Tensor &self, int64_t dim, bool keepdim) {
AT_CHECK(self.type().backend() == Backend::CPU || self.type().backend() == Backend::CUDA,
"any only supports CPU AND CUDA backend, got: ", at::toString(self.type().backend()));
AT_CHECK(self.type().scalarType() == at::ScalarType::Byte, "any only supports torch.uint8 dtype");
dim = maybe_wrap_dim(dim, self.dim());
if (_dimreduce_return_trivial(result, self, 0, dim, keepdim)) {
return result;
} else {
return at::_th_any_out(result, self, dim, keepdim);
}
}

Tensor var(const Tensor& self, int64_t dim, bool unbiased, bool keepdim) {
Tensor result = self.type().tensor();
return at::native::var_out(result, self, dim, unbiased, keepdim);
}

Tensor &var_out(Tensor &result, const Tensor &self, int64_t dim, bool unbiased, bool keepdim) {
AT_CHECK(self.type().backend() == Backend::CPU || self.type().backend() == Backend::CUDA,
"var only supports CPU AND CUDA backend, got: ", at::toString(self.type().backend()));
AT_CHECK(at::isFloatingType(self.type().scalarType()), "var only supports floating-point dtypes");
dim = maybe_wrap_dim(dim, self.dim());
if (_dimreduce_return_trivial(result, self, std::numeric_limits<double>::quiet_NaN(), dim, keepdim)) {
return result;
} else {
return at::_th_var_out(result, self, dim, unbiased, keepdim);
}
}

Tensor std(const Tensor& self, int64_t dim, bool unbiased, bool keepdim) {
Tensor result = self.type().tensor();
return at::native::std_out(result, self, dim, unbiased, keepdim);
}

Tensor &std_out(Tensor &result, const Tensor &self, int64_t dim, bool unbiased, bool keepdim) {
AT_CHECK(self.type().backend() == Backend::CPU || self.type().backend() == Backend::CUDA,
"std only supports CPU AND CUDA backend, got: ", at::toString(self.type().backend()));
AT_CHECK(at::isFloatingType(self.type().scalarType()), "std only supports floating-point dtypes");
dim = maybe_wrap_dim(dim, self.dim());
if (_dimreduce_return_trivial(result, self, std::numeric_limits<double>::quiet_NaN(), dim, keepdim)) {
return result;
} else {
return at::_th_std_out(result, self, dim, unbiased, keepdim);
}
}

}} // namespace at::native
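One detail worth calling out in the block above: the third argument to _dimreduce_return_trivial is the reduction's identity, i.e. the value an empty slice reduces to: 0 for norm and any, 1 for all, NaN for var and std. A plain-C++ sketch of those conventions (var_of uses the textbook unbiased/biased split; names are illustrative, not ATen API):

#include <cstdint>
#include <limits>
#include <vector>

uint8_t all_of(const std::vector<uint8_t>& v) {
  uint8_t acc = 1;                        // identity for logical-and
  for (uint8_t x : v) acc = acc && x;     // all([]) is vacuously true
  return acc;
}

uint8_t any_of(const std::vector<uint8_t>& v) {
  uint8_t acc = 0;                        // identity for logical-or
  for (uint8_t x : v) acc = acc || x;     // any([]) is false
  return acc;
}

double var_of(const std::vector<double>& v, bool unbiased) {
  if (v.empty()) return std::numeric_limits<double>::quiet_NaN();
  double mean = 0.0;
  for (double x : v) mean += x;
  mean /= static_cast<double>(v.size());
  double ss = 0.0;
  for (double x : v) ss += (x - mean) * (x - mean);
  return ss / static_cast<double>(v.size() - (unbiased ? 1 : 0));
}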