2 changes: 2 additions & 0 deletions docs/source/tensors.rst
@@ -238,6 +238,7 @@ view of a storage and defines numeric operations on it.
.. automethod:: pow_
.. automethod:: prod
.. automethod:: pstrf
.. automethod:: put_
.. automethod:: qr
.. automethod:: random_
.. automethod:: reciprocal
@@ -285,6 +286,7 @@ view of a storage and defines numeric operations on it.
.. automethod:: symeig
.. automethod:: t
.. automethod:: t_
.. automethod:: take
.. automethod:: tan
.. automethod:: tan_
.. automethod:: tanh
2 changes: 1 addition & 1 deletion docs/source/torch.rst
@@ -36,6 +36,7 @@ Indexing, Slicing, Joining, Mutating Ops
.. autofunction:: squeeze
.. autofunction:: stack
.. autofunction:: t
.. autofunction:: take
.. autofunction:: transpose
.. autofunction:: unbind
.. autofunction:: unsqueeze
@@ -200,4 +201,3 @@ BLAS and LAPACK Operations
.. autofunction:: svd
.. autofunction:: symeig
.. autofunction:: trtrs

7 changes: 7 additions & 0 deletions test/test_cuda.py
@@ -128,6 +128,10 @@ def large_2d_lapack(t):
    return t(1000, 1000).normal_()


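# Index arguments must live on the same backend as the tensor under test, so pick
# torch.cuda.LongTensor for CUDA tensor types and torch.LongTensor otherwise.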
def long_type(t):
    return torch.cuda.LongTensor if 'cuda' in t.__module__ else torch.LongTensor


def new_t(*sizes):
    def tmp(t):
        return t(*sizes).copy_(torch.randn(*sizes))
@@ -249,6 +253,8 @@ def tmp(t):
    ('norm', small_3d, lambda t: [3, -2], '3_norm_neg_dim'),
    ('ones', small_3d, lambda t: [1, 2, 3, 4, 5],),
    ('permute', new_t(1, 2, 3, 4), lambda t: [2, 1, 3, 0],),
    ('put_', new_t(2, 5, 3), lambda t: [long_type(t)([[0], [-2]]), t([[3], [4]])],),
    ('put_', new_t(2, 2), lambda t: [long_type(t)([[1], [-3]]), t([[1], [2]]), True], 'accumulate'),
    ('prod', small_2d_oneish, lambda t: [],),
    ('prod', small_3d, lambda t: [1], 'dim'),
    ('prod', small_3d, lambda t: [-1], 'neg_dim'),
@@ -274,6 +280,7 @@ def tmp(t):
    ('squeeze', new_t(1, 2, 1, 4), lambda t: [2], 'dim'),
    ('squeeze', new_t(1, 2, 1, 4), lambda t: [-2], 'neg_dim'),
    ('t', new_t(1, 2), lambda t: [],),
    ('take', new_t(3, 4), lambda t: [long_type(t)([[0], [-2]])],),
    ('transpose', new_t(1, 2, 3, 4), lambda t: [1, 2],),
    ('transpose', new_t(1, 2, 3, 4), lambda t: [-1, -2], 'neg_dim'),
    ('to_list', small_3d, lambda t: [],),
34 changes: 34 additions & 0 deletions test/test_torch.py
@@ -3429,6 +3429,40 @@ def test_index_add(self):
            dest2[idx[i]] = dest2[idx[i]] + src[i]
        self.assertEqual(dest, dest2)

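    # take should behave like index_select on the flattened tensor: flatten src and
    # idx, gather the elements, then reshape the result to idx's shape.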
    def test_take(self):
        def check(src, idx):
            expected = src.contiguous().view(-1).index_select(
                0, idx.contiguous().view(-1)).view_as(idx)
            actual = src.take(idx)
            self.assertEqual(actual.size(), idx.size())
            self.assertEqual(expected, actual)

        src = torch.randn(2, 3, 5)
        idx = torch.LongTensor([[0, 2], [3, 4]])
        check(src, idx)
        check(src.transpose(1, 2), idx)

    def test_put_(self):
        def check(dst, idx, value):
            expected = dst.clone().view(-1).index_copy_(
                0, idx.contiguous().view(-1), value.contiguous().view(-1))
            expected = expected.view_as(dst)
            dst.put_(idx, value)
            self.assertEqual(expected, dst)

        dst = torch.randn(2, 3, 5)
        idx = torch.LongTensor([[0, 2], [3, 4]])
        values = torch.randn(2, 2)
        check(dst, idx, values)
        check(dst.transpose(1, 2), idx, values)

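    # dst is indexed as if flattened; with accumulate=True duplicate indices add,
    # so flat slot 0 receives 1 + 1 + 3 and flat slot 1 receives 1 + 2 + 4.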
    def test_put_accumulate(self):
        dst = torch.ones(2, 2)
        idx = torch.LongTensor([[0, 1], [0, 1]])
        src = torch.Tensor([1, 2, 3, 4])
        dst.put_(idx, src, accumulate=True)
        self.assertEqual(dst.tolist(), [[5, 7], [1, 1]])

    # Fill idx with valid indices.
    @staticmethod
    def _fill_indices(self, idx, dim, dim_size, elems_per_row, m, n, o):
7 changes: 7 additions & 0 deletions tools/autograd/derivatives.yaml
@@ -390,6 +390,10 @@
- name: pstrf(Tensor self, bool upper, Scalar tol)
  self: not_implemented("pstrf")

- name: put_(Tensor self, Tensor index, Tensor source, bool accumulate)
  self: grad.clone().put_(index, zeros_like(source), accumulate)
  source: grad.take(index)

- name: qr(Tensor self)
self: not_implemented("qr")

@@ -490,6 +494,9 @@
  self: grad.t()
  __view__: True

- name: take(Tensor self, Tensor index)
  self: zeros_like(self).put_(index, grad, true)

- name: tan(Tensor self)
  self: grad / self.cos().pow(2)

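Read together, these entries say that the gradient of take flows back by scattering the incoming grad into a zero tensor at the taken positions, while the gradient of put_ with respect to source is obtained by gathering grad at the same indices. A minimal sketch of the first rule, assuming a recent PyTorch in which plain tensors carry autograd state (the names x and idx are only illustrative):

    import torch

    x = torch.randn(2, 3, requires_grad=True)
    idx = torch.tensor([0, 4, 5])        # linear indices into the flattened x

    out = x.take(idx)
    out.backward(torch.ones_like(out))   # upstream gradient of ones

    # zeros_like(self).put_(index, grad, true): scatter the upstream grad back.
    expected = torch.zeros_like(x).put_(idx, torch.ones_like(out), accumulate=True)
    assert torch.equal(x.grad, expected)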
33 changes: 30 additions & 3 deletions torch/_tensor_docs.py
@@ -1132,6 +1132,33 @@ def callable(a, b) -> number
See :func:`torch.pstrf`
""")

add_docstr_all('put_',
"""
put_(indices, tensor, accumulate=False) -> Tensor

Copies the elements from :attr:`tensor` into the positions specified by
indices. For the purpose of indexing, the ``self`` tensor is treated as if it
were a 1D tensor.

If :attr:`accumulate` is ``True``, the elements in :attr:`tensor` are added to
:attr:`self`. If :attr:`accumulate` is ``False``, the behavior is undefined if
:attr:`indices` contain duplicate elements.

Args:
    indices (LongTensor): the indices into self
    tensor (Tensor): the tensor containing values to copy
    accumulate (bool): whether to accumulate into self

Example::

    >>> src = torch.Tensor([[4, 3, 5],
    ...                     [6, 7, 8]])
    >>> src.put_(torch.LongTensor([1, 3]), torch.Tensor([9, 10]))
      4   9   5
     10   7   8
    [torch.FloatTensor of size 2x3]
""")

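A supplementary sketch (not part of the diff) of the accumulate path; with ``accumulate=True``, duplicate indices add into the same slot instead of overwriting it:

    >>> dst = torch.zeros(2, 3)
    >>> dst.put_(torch.LongTensor([1, 1, 4]), torch.Tensor([1, 2, 3]), accumulate=True)
     0  3  0
     0  3  0
    [torch.FloatTensor of size 2x3]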
add_docstr_all('qr',
"""
qr() -> (Tensor, Tensor)
@@ -1543,11 +1570,11 @@ def callable(a, b) -> number
In-place version of :meth:`~Tensor.t`
""")

add_docstr_all('tan',
add_docstr_all('take',
"""
tan() -> Tensor
take(indices) -> Tensor

See :func:`torch.tan`
See :func:`torch.take`
""")

add_docstr_all('tan_',
22 changes: 22 additions & 0 deletions torch/_torch_docs.py
@@ -4405,6 +4405,28 @@

""")

add_docstr(torch._C.take, """\
take(input, indices) -> Tensor

Returns a new `Tensor` with the elements of :attr:`input` at the given indices.
The input tensor is treated as if it were viewed as a 1D tensor. The result
takes the same shape as the indices.

Args:
    input (Tensor): the input `Tensor`
    indices (LongTensor): the indices into the tensor

Example::

    >>> src = torch.Tensor([[4, 3, 5],
    ...                     [6, 7, 8]])
    >>> torch.take(src, torch.LongTensor([0, 2, 5]))
     4
     5
     8
    [torch.FloatTensor of size 3]
""")

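A supplementary sketch (not part of the diff), continuing the example above, of two properties of the implementation added in THTensorMath.c below: indices may be negative (they wrap, Python-style) and the input need not be contiguous:

    >>> torch.take(src.t(), torch.LongTensor([0, -1]))
     4
     8
    [torch.FloatTensor of size 2]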
add_docstr(torch._C.tan,
"""
tan(input, out=None) -> Tensor
2 changes: 2 additions & 0 deletions torch/csrc/Module.cpp
@@ -329,6 +329,7 @@ IMPLEMENT_STATELESS(zeros_like)
IMPLEMENT_STATELESS(ones)
IMPLEMENT_STATELESS(ones_like)
IMPLEMENT_STATELESS(index_select)
IMPLEMENT_STATELESS(take)
IMPLEMENT_STATELESS(ger)
IMPLEMENT_STATELESS(mv)
IMPLEMENT_STATELESS(mm)
@@ -678,6 +679,7 @@ static PyMethodDef TorchMethods[] = {
{"ones", (PyCFunction)THPModule_ones, METH_VARARGS | METH_KEYWORDS, NULL},
{"ones_like", (PyCFunction)THPModule_ones_like, METH_VARARGS | METH_KEYWORDS, NULL},
{"index_select", (PyCFunction)THPModule_index_select, METH_VARARGS | METH_KEYWORDS, NULL},
{"take", (PyCFunction)THPModule_take, METH_VARARGS | METH_KEYWORDS, NULL},
{"addmm", (PyCFunction)THPModule_addmm, METH_VARARGS | METH_KEYWORDS, NULL},
{"addmv", (PyCFunction)THPModule_addmv, METH_VARARGS | METH_KEYWORDS, NULL},
{"addr", (PyCFunction)THPModule_addr, METH_VARARGS | METH_KEYWORDS, NULL},
28 changes: 27 additions & 1 deletion torch/csrc/generic/methods/Tensor.cwrap
@@ -636,7 +636,33 @@ PyObject * THPTensor_(stride)(PyObject *self, PyObject *args, PyObject *kwargs)
    - THIndexTensor* index
    - THTensor* source
]]

[[
  name: take
  cname: take
  variants:
    - method
    - function
  return: argument 0
  arguments:
    - arg: THTensor* result
      output: True
    - THTensor* self
    - THIndexTensor* index
]]
[[
  name: put_
  cname: put
  backends:
    - CPU
    - CUDA
  return: argument 0
  arguments:
    - THTensor* self
    - THIndexTensor* index
    - THTensor* source
    - arg: bool accumulate
      default: "false"
]]
[[
  name: indexAdd_
  python_name: index_add_
27 changes: 27 additions & 0 deletions torch/lib/ATen/Declarations.cwrap
@@ -472,6 +472,33 @@
    - THIndexTensor* index
    - THTensor* source
]]
[[
  name: take
  cname: take
  variants:
    - method
    - function
  return: argument 0
  arguments:
    - arg: THTensor* result
      output: True
    - THTensor* self
    - THIndexTensor* index
]]
[[
  name: put_
  cname: put
  backends:
    - CPU
    - CUDA
  return: argument 0
  arguments:
    - THTensor* self
    - THIndexTensor* index
    - THTensor* source
    - arg: bool accumulate
      default: "false"
]]
[[
  name: indexAdd_
  python_name: index_add_
75 changes: 75 additions & 0 deletions torch/lib/TH/generic/THTensorMath.c
@@ -349,6 +349,81 @@ void THTensor_(indexCopy)(THTensor *tensor, int dim, THLongTensor *index, THTensor *src)
  THLongTensor_free(index);
}

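/* Map a linear (flattened, row-major) index to the storage offset of the
   corresponding element by walking sizes and strides from the innermost
   dimension outwards.  Needed when the tensor is not contiguous. */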
static ptrdiff_t THTensor_(dataOffset)(THTensor* tensor, ptrdiff_t linearIndex) {
  int64_t *size = tensor->size;
  int64_t *stride = tensor->stride;
  int nDim = tensor->nDimension;
  ptrdiff_t dataOffset = 0;
  for (int i = nDim - 1; i >= 0; i--) {
    dataOffset += (linearIndex % size[i]) * stride[i];
    linearIndex /= size[i];
  }
  return dataOffset;
}

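/* Check that a linear index lies in [-numel, numel) and wrap negative
   indices around, Python-style. */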
static int64_t THTensor_(wrapLinearIndex)(int64_t linearIndex, int64_t numel) {
  THArgCheck(linearIndex < numel && linearIndex >= -numel, 2, "out of range: %d out of %d", (int)linearIndex, (int)numel);
  return linearIndex < 0 ? linearIndex + numel : linearIndex;
}

void THTensor_(take)(THTensor *r_, THTensor *src, THLongTensor *index)
{
  THTensor_(resizeNd)(r_, index->nDimension, index->size, NULL);
  THTensor* dst = THTensor_(newContiguous)(r_);

  index = THLongTensor_newContiguous(index);
  long* index_data = THLongTensor_data(index);
  ptrdiff_t srcElements = THTensor_(nElement)(src);
  real* src_data = THTensor_(data)(src);
  real* dst_data = THTensor_(data)(dst);

  ptrdiff_t nIndices = THLongTensor_nElement(index);
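  /* Fast path: for a contiguous src the wrapped linear index is already a valid
     offset into its storage; otherwise translate it through the strides. */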
  if (THTensor_(isContiguous)(src)) {
    ptrdiff_t i;
    #pragma omp parallel for if(nIndices > TH_OMP_OVERHEAD_THRESHOLD) private(i)
    for (i = 0; i < nIndices; i++) {
      int64_t linearIndex = THTensor_(wrapLinearIndex)(index_data[i], srcElements);
      dst_data[i] = src_data[linearIndex];
    }
  } else {
    ptrdiff_t i;
    #pragma omp parallel for if(nIndices > TH_OMP_OVERHEAD_THRESHOLD) private(i)
    for (i = 0; i < nIndices; i++) {
      int64_t linearIndex = THTensor_(wrapLinearIndex)(index_data[i], srcElements);
      int64_t dataOffset = THTensor_(dataOffset)(src, linearIndex);
      dst_data[i] = src_data[dataOffset];
    }
  }

  THLongTensor_free(index);
  THTensor_(freeCopyTo)(dst, r_);
}

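/* Scatter src into tensor at the given linear indices (tensor is treated as 1D);
   with accumulate != 0 the values are added to the existing entries instead of
   overwriting them. */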
void THTensor_(put)(THTensor *tensor, THLongTensor *index, THTensor *src, int accumulate)
{
  THArgCheck(THLongTensor_nElement(index) == THTensor_(nElement)(src), 3,
             "src should have the same number of elements as index");

  index = THLongTensor_newContiguous(index);
  src = THTensor_(newContiguous)(src);
  real* data = THTensor_(data)(tensor);
  ptrdiff_t numel = THTensor_(nElement)(tensor);
  int is_contiguous = THTensor_(isContiguous)(tensor);

  TH_TENSOR_APPLY2(int64_t, index, real, src,
    int64_t linearIndex = THTensor_(wrapLinearIndex)(*index_data, numel);
    int64_t dataOffset = is_contiguous ? linearIndex : THTensor_(dataOffset)(tensor, linearIndex);
    if (accumulate) {
      data[dataOffset] += *src_data;
    } else {
      data[dataOffset] = *src_data;
    }
  );

  THTensor_(free)(src);
  THLongTensor_free(index);
}

void THTensor_(indexAdd)(THTensor *tensor, int dim, THLongTensor *index, THTensor *src)
{
  ptrdiff_t i, numel;
2 changes: 2 additions & 0 deletions torch/lib/TH/generic/THTensorMath.h
@@ -15,6 +15,8 @@ TH_API void THTensor_(indexSelect)(THTensor *tensor, THTensor *src, int dim, THLongTensor *index);
TH_API void THTensor_(indexCopy)(THTensor *tensor, int dim, THLongTensor *index, THTensor *src);
TH_API void THTensor_(indexAdd)(THTensor *tensor, int dim, THLongTensor *index, THTensor *src);
TH_API void THTensor_(indexFill)(THTensor *tensor, int dim, THLongTensor *index, real val);
TH_API void THTensor_(take)(THTensor *tensor, THTensor *src, THLongTensor *index);
TH_API void THTensor_(put)(THTensor *tensor, THLongTensor *index, THTensor *src, int accumulate);

TH_API void THTensor_(gather)(THTensor *tensor, THTensor *src, int dim, THLongTensor *index);
TH_API void THTensor_(scatter)(THTensor *tensor, int dim, THLongTensor *index, THTensor *src);