Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion aten/src/ATen/WrapDimUtilsMulti.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ static inline std::bitset<dim_bitset_size> dim_list_to_bitset(IntList dims, int6
std::bitset<dim_bitset_size> seen;
for (size_t i = 0; i < dims.size(); i++) {
size_t dim = maybe_wrap_dim(dims[i], ndims);
AT_CHECK(!seen[dim], "dim ", dim, " appears multiple times in the list of reduced dims");
AT_CHECK(!seen[dim], "dim ", dim, " appears multiple times in the list of dims");
seen[dim] = true;
}
return seen;
Expand Down
80 changes: 72 additions & 8 deletions aten/src/ATen/native/ReduceOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "ATen/ExpandUtils.h"
#include "ATen/NativeFunctions.h"
#include "ATen/WrapDimUtils.h"
#include "ATen/WrapDimUtilsMulti.h"
#include "cpu/ReduceOpsKernel.h"

#include <algorithm>
Expand Down Expand Up @@ -155,7 +156,7 @@ static Tensor &_dimreduce_setup(Tensor &result, const Tensor &self,
return result;
}

static inline Tensor &sum_out(Tensor &result, const Tensor &self, int64_t dim,
static inline Tensor &sum_out(Tensor &result, const Tensor &self, IntList dim,
bool keepdim, optional<ScalarType> dtype) {
// result type is favored over dtype; check that they match if provided (NumPy doesn't check)
AT_CHECK(!dtype.has_value() || (result.type().scalarType() == dtype.value()),
Expand All @@ -164,14 +165,14 @@ static inline Tensor &sum_out(Tensor &result, const Tensor &self, int64_t dim,
return at::_sum_out(result, self.toType(result.type().scalarType()), dim, keepdim);
}

Tensor& sum_out(Tensor& result, const Tensor& self, int64_t dim, bool keepdim, ScalarType dtype) {
Tensor& sum_out(Tensor& result, const Tensor& self, IntList dim, bool keepdim, ScalarType dtype) {
return at::native::sum_out(result, self, dim, keepdim, at::optional<ScalarType>(dtype));
}
Tensor& sum_out(Tensor& result, const Tensor& self, int64_t dim, bool keepdim) {
Tensor& sum_out(Tensor& result, const Tensor& self, IntList dim, bool keepdim) {
return at::native::sum_out(result, self, dim, keepdim, nullopt);
}

Tensor& sum_out(Tensor& result, const Tensor& self, int64_t dim, ScalarType dtype) {
Tensor& sum_out(Tensor& result, const Tensor& self, IntList dim, ScalarType dtype) {
return at::native::sum_out(result, self, dim, false, dtype);
}

Expand Down Expand Up @@ -233,19 +234,19 @@ Tensor &_prod_out_cuda(Tensor &result, const Tensor &self, int64_t dim,
return at::_th_prod_out(result, self, dim, keepdim);
}

static inline Tensor sum(const Tensor &self, int64_t dim_, bool keepdim, optional<ScalarType> dtype) {
static inline Tensor sum(const Tensor &self, IntList dim_, bool keepdim, optional<ScalarType> dtype) {
return at::_sum(integer_upcast(self, dtype), dim_, keepdim);
}

Tensor sum(const Tensor& self, int64_t dim, bool keepdim, ScalarType dtype) {
Tensor sum(const Tensor& self, IntList dim, bool keepdim, ScalarType dtype) {
return at::native::sum(self, dim, keepdim, at::optional<ScalarType>(dtype));
}

Tensor sum(const Tensor& self, int64_t dim, bool keepdim) {
Tensor sum(const Tensor& self, IntList dim, bool keepdim) {
return at::native::sum(self, dim, keepdim, nullopt);
}

Tensor sum(const Tensor& self, int64_t dim, ScalarType dtype) {
Tensor sum(const Tensor& self, IntList dim, ScalarType dtype) {
return at::native::sum(self, dim, false, dtype);
}

Expand Down Expand Up @@ -278,5 +279,68 @@ Tensor _prod(const Tensor &self, int64_t dim_, bool keepdim) {
}

// \DIM REDUCE ################################################################

// MULTI DIM REDUCE ###########################################################

// Reduce `self` over every dimension listed in `dims_` by repeatedly
// applying the single-dimension reduction `reduce_1`.
//
// Dimensions are collapsed from highest index to lowest so that, when
// keepdim is false, the positions of the dimensions still awaiting
// reduction are not shifted by earlier collapses.
template <Tensor (reduce_1)(const Tensor &, int64_t, bool)>
inline Tensor reduce_multi_associative(const Tensor &self, IntList dims_, bool keepdim) {
  // Fast path: a single dim needs no wrap/duplicate bookkeeping here;
  // reduce_1 performs its own dim wrapping.
  if (dims_.size() == 1) {
    return reduce_1(self, dims_[0], keepdim);
  }
  // An empty dim list reduces nothing.
  if (dims_.size() == 0) {
    return self;
  }
  int64_t ndims = self.dim();
  // Wraps negative dims and rejects duplicates.
  auto mask = dim_list_to_bitset(dims_, ndims);
  Tensor out = self;
  for (int64_t d = ndims - 1; d >= 0; d--) {
    if (mask[d]) {
      out = reduce_1(out, d, keepdim);
    }
  }
  return out;
}

// Out-variant of reduce_multi_associative: reduces `self` over every dim in
// `dims_`, writing only the FINAL single-dim reduction into `result` via
// `reduce_1_out`; all earlier reductions go through the allocating `reduce_1`.
// Dims are processed from highest index to lowest so that, when keepdim is
// false, pending dim indices stay valid after each collapse.
//
// NOTE(review): if `dims_` is empty, the loop performs no reductions and
// `result` is returned without ever being written — confirm callers never
// pass an empty dim list here (the non-out variant returns `self` instead).
template <Tensor (reduce_1)(const Tensor &, int64_t, bool),
Tensor& (reduce_1_out)(Tensor& result, const Tensor &, int64_t, bool)>
inline Tensor& reduce_multi_associative_out(Tensor &result, const Tensor &self, IntList dims_, bool keepdim) {
// Fast path: a single dim reduces straight into `result`.
if (dims_.size() == 1) {
return reduce_1_out(result, self, dims_[0], keepdim);
}
int64_t ndims = self.dim();
// Wraps negative dims and rejects duplicates.
auto reduce_dims = dim_list_to_bitset(dims_, ndims);
Tensor t = self;
// Count reductions so the last one (and only the last) targets `result`.
int64_t last_reduction = dims_.size()-1;
int64_t num_reduction = 0;
for (int64_t dim = ndims-1; dim >= 0; dim--) {
if (reduce_dims[dim]) {
if (num_reduction < last_reduction) {
// Intermediate step: allocate a fresh tensor.
t = reduce_1(t, dim, keepdim);
} else {
// Final step: write into the caller-provided output.
reduce_1_out(result, t, dim, keepdim);
}
num_reduction++;
}
}
return result;
}


// Single-dim sum into `result`, dispatched to the device-specific kernel.
Tensor& _sum_out(Tensor &result, const Tensor &self, int64_t dim, bool keepdim) {
  return self.is_cuda() ? _sum_out_cuda(result, self, dim, keepdim)
                        : _sum_out_cpu(result, self, dim, keepdim);
}

// Multi-dim sum: folds the dim list into repeated single-dim _sum calls
// (highest dim first) via reduce_multi_associative.
Tensor _sum(const Tensor &self, IntList dims, bool keepdim) {
return reduce_multi_associative<_sum>(self, dims, keepdim);
}

// Multi-dim sum into `result`: intermediate reductions allocate via _sum;
// only the final reduction writes into `result` via the single-dim _sum_out.
Tensor& _sum_out(Tensor &result, const Tensor &self, IntList dims, bool keepdim)
{
return reduce_multi_associative_out<_sum, _sum_out>(result, self, dims, keepdim);
}

}} // namespace at::native
19 changes: 8 additions & 11 deletions aten/src/ATen/native/native_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -717,28 +717,25 @@
CPU: _sum_cpu
CUDA: _sum_cuda

- func: sum(Tensor self, int64_t dim, bool keepdim, *, ScalarType dtype) -> Tensor
- func: sum(Tensor self, IntList[1] dim, bool keepdim, *, ScalarType dtype) -> Tensor

- func: sum(Tensor self, int64_t dim, bool keepdim=False) -> Tensor
- func: sum(Tensor self, IntList[1] dim, bool keepdim=False) -> Tensor

- func: sum(Tensor self, int64_t dim, *, ScalarType dtype) -> Tensor
- func: sum(Tensor self, IntList[1] dim, *, ScalarType dtype) -> Tensor

- func: _sum(Tensor self, int64_t dim, bool keepdim=False) -> Tensor
- func: _sum(Tensor self, IntList[1] dim, bool keepdim=False) -> Tensor

- func: sum_out(Tensor result, Tensor self, int64_t dim, bool keepdim, *, ScalarType dtype) -> Tensor
- func: sum_out(Tensor result, Tensor self, IntList[1] dim, bool keepdim, *, ScalarType dtype) -> Tensor
variants: function

- func: sum_out(Tensor result, Tensor self, int64_t dim, bool keepdim=False) -> Tensor
- func: sum_out(Tensor result, Tensor self, IntList[1] dim, bool keepdim=False) -> Tensor
variants: function

- func: sum_out(Tensor result, Tensor self, int64_t dim, *, ScalarType dtype) -> Tensor
- func: sum_out(Tensor result, Tensor self, IntList[1] dim, *, ScalarType dtype) -> Tensor
variants: function

- func: _sum_out(Tensor result, Tensor self, int64_t dim, bool keepdim=False) -> Tensor
- func: _sum_out(Tensor result, Tensor self, IntList[1] dim, bool keepdim=False) -> Tensor
variants: function
dispatch:
CPU: _sum_out_cpu
CUDA: _sum_out_cuda

- func: sqrt(Tensor self) -> Tensor

Expand Down
2 changes: 2 additions & 0 deletions test/test_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -2621,6 +2621,8 @@ class dont_convert(tuple):
('sum', (), NO_ARGS, 'scalar'),
('sum', (), (0,), 'scalar_dim', [0]),
('sum', (), (0, True,), 'scalar_keepdim_dim', [0]),
('sum', (S, S, S), ([1, 2],), 'multi_dim'),
('sum', (S, S, S), ([1, 2], True,), 'multi_dim_keepdim'),
('prod', (S, S, S), NO_ARGS),
('prod', (S, S, S), (1,), 'dim', [0]),
('prod', (S, S, S), (1, True,), 'keepdim_dim', [0]),
Expand Down
7 changes: 7 additions & 0 deletions test/test_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1503,6 +1503,8 @@ def make_tensors(*shape):
check_sum_dim(make_tensors(50, 50, 50), 0)
check_sum_dim(make_tensors(50, 50, 50), 1)
check_sum_dim(make_tensors(50, 50, 50), 2)
check_sum_dim(make_tensors(50, 50, 50), (1, 2))
check_sum_dim(make_tensors(50, 50, 50), (1, -1))

def make_contiguous_slice(size, dtype):
contig = make_contiguous((1, size), dtype)
Expand All @@ -1522,6 +1524,11 @@ def test_sum_out(self):
res2 = torch.Tensor()
torch.sum(x, 1, out=res2)
self.assertEqual(res1, res2)
x = torch.rand(100, 100, 100)
res1 = x.sum(2).sum(1)
res2 = torch.Tensor()
torch.sum(x, (2, 1), out=res2)
self.assertEqual(res1, res2)

# TODO: these tests only check if it's possible to pass a return value
# it'd be good to expand them
Expand Down
2 changes: 1 addition & 1 deletion tools/autograd/derivatives.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -610,7 +610,7 @@
- name: _sum(Tensor self)
self: grad.expand(self.sizes())

- name: _sum(Tensor self, int64_t dim, bool keepdim)
- name: _sum(Tensor self, IntList dim, bool keepdim)
self: sum_backward(grad, self.sizes(), dim, keepdim)

- name: svd(Tensor self, bool some)
Expand Down
15 changes: 13 additions & 2 deletions tools/autograd/templates/Functions.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "Functions.h"
#include <ATen/WrapDimUtils.h>
#include <ATen/WrapDimUtilsMulti.h>

// define constants like M_PI and C keywords for MSVC
#ifdef _MSC_VER
Expand Down Expand Up @@ -132,9 +133,19 @@ Tensor permute_backwards(const Tensor & grad, IntList fwd_dims) {
return grad.permute(dims);
}

Tensor sum_backward(const Tensor & grad, IntList sizes, int64_t dim, bool keepdim) {
Tensor sum_backward(const Tensor & grad, IntList sizes, IntList dims, bool keepdim) {
if (!keepdim && sizes.size() > 0) {
return grad.unsqueeze(dim).expand(sizes);
if (dims.size()==1) {
return grad.unsqueeze(dims[0]).expand(sizes);
} else {
auto dims_to_unsqueeze = dim_list_to_bitset(dims, sizes.size());
Tensor res = grad;
for (size_t i = 0; i < sizes.size(); i++){
if (dims_to_unsqueeze[i])
res = res.unsqueeze(i);
}
return res.expand(sizes);
}
} else {
return grad.expand(sizes);
}
Expand Down
8 changes: 6 additions & 2 deletions torch/_torch_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4094,7 +4094,8 @@ def parse_kwargs(desc):
.. function:: sum(input, dim, keepdim=False, out=None) -> Tensor

Returns the sum of each row of the :attr:`input` tensor in the given
dimension :attr:`dim`.
dimension :attr:`dim`. If :attr:`dim` is a list of dimensions,
reduce over all of them.

If :attr:`keepdim` is ``True``, the output tensor is of the same size
as :attr:`input` except in the dimension :attr:`dim` where it is of size 1.
Expand All @@ -4103,7 +4104,7 @@ def parse_kwargs(desc):

Args:
input (Tensor): the input tensor
dim (int): the dimension to reduce
dim (int or tuple of ints): the dimension or dimensions to reduce
keepdim (bool): whether the output tensor has :attr:`dim` retained or not
out (Tensor, optional): the output tensor

Expand All @@ -4117,6 +4118,9 @@ def parse_kwargs(desc):
[ 0.3637, -0.9906, -0.4752, -1.5197]])
>>> torch.sum(a, 1)
tensor([-0.4598, -0.1381, 1.3708, -2.6217])
>>> b = torch.arange(4 * 5 * 6).view(4, 5, 6)
>>> torch.sum(b, (2, 1))
tensor([ 435., 1335., 2235., 3135.])
""")

add_docstr(torch.svd,
Expand Down