58 changes: 58 additions & 0 deletions aten/src/ATen/native/Linear.cpp
@@ -440,4 +440,62 @@ Tensor bilinear(const Tensor& input1, const Tensor& input2, const Tensor& weight
return output;
}

// implements tensordot, a matrix-multiplication-like contraction that sums over
// the dimensions given in the two dimension lists
Tensor tensordot(const Tensor& input1, const Tensor& input2, IntList dims1, IntList dims2) {
AT_CHECK(dims1.size() == dims2.size(), "both dimension lists should have same length");
int64_t csize = 1; // total size of the contracted dimensions
Tensor t1 = input1;
Tensor t2 = input2;
for (size_t i = 0; i < dims1.size(); i++) {
int s1 = input1.size(dims1[i]);
int s2 = input2.size(dims2[i]);
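    // when one operand has size 1 in a contracted dimension, the contraction
    // over it is just a sum scaled by that single entry, so the sum can be
    // taken right away; keepdim=true keeps later dimension indices valid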
if (s2 == 1) { // broadcasted dimensions can be summed right away
t1 = t1.sum(dims1[i], true);
} else if (s1 == 1) {
t2 = t2.sum(dims2[i], true);
} else {
AT_CHECK(s1 == s2, "contracted dimensions need to match, but first has size ", s1, " in dim ", dims1[i],
" and second has size ", s2, " in dim ", dims2[i]);
csize *= s1;
}
}

auto cdims1 = dim_list_to_bitset(dims1, input1.dim());
auto cdims2 = dim_list_to_bitset(dims2, input2.dim());
std::vector<int64_t> p1, p2, rsizes; // p1, p2: input permutations, rsizes: sizes of the result
p1.reserve(input1.dim());
p2.reserve(input2.dim());
rsizes.reserve(input1.dim() + input2.dim() - (int64_t) dims1.size());
int64_t size1 = 1; // number of non-contracted elements in input1
int64_t size2 = 1; // number of non-contracted elements in input2

// fill the permutations and compute sizes
for (int64_t i = 0; i < input1.dim(); i++) {
if (! cdims1[i]) {
p1.emplace_back(i);
size1 *= t1.size(i);
rsizes.emplace_back(t1.size(i));
}
}
for (size_t i = 0; i < dims1.size(); i++) {
p1.emplace_back(dims1[i]);
}
for (size_t i = 0; i < dims2.size(); i++) {
p2.emplace_back(dims2[i]);
}
for (int64_t i = 0; i < input2.dim(); i++) {
if (! cdims2[i]) {
p2.emplace_back(i);
size2 *= t2.size(i);
rsizes.emplace_back(t2.size(i));
}
}
  // permute and reshape for matrix multiplication
t1 = t1.permute(p1).reshape({size1, csize});
t2 = t2.permute(p2).reshape({csize, size2});
// multiply and reshape to target size
return at::mm(t1, t2).reshape(rsizes);
}

}} // namespace at::native
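For intuition, the C++ above reduces a general contraction to a single matrix multiply: broadcastable contracted dimensions are summed out, the remaining contracted dimensions are permuted to the inside of each operand, both operands are flattened to matrices, and one `mm` plus a reshape produces the result. Below is a minimal Python sketch of the same reduction; the name `tensordot_ref` is illustrative and not part of this PR, the size-1 broadcasting branch is omitted, and non-negative dimension indices are assumed.

import torch

def tensordot_ref(a, b, dims_a, dims_b):
    # dimensions not being contracted, in order, for each operand
    keep_a = [d for d in range(a.dim()) if d not in dims_a]
    keep_b = [d for d in range(b.dim()) if d not in dims_b]
    csize = 1  # total size of the contracted dimensions
    for da, db in zip(dims_a, dims_b):
        assert a.size(da) == b.size(db), "contracted dimensions need to match"
        csize *= a.size(da)
    size1 = 1  # number of non-contracted elements of a
    for d in keep_a:
        size1 *= a.size(d)
    size2 = 1  # number of non-contracted elements of b
    for d in keep_b:
        size2 *= b.size(d)
    rsizes = [a.size(d) for d in keep_a] + [b.size(d) for d in keep_b]
    # move contracted dims to the end of a and the front of b,
    # flatten both to matrices, multiply, and restore the result shape
    t1 = a.permute(keep_a + list(dims_a)).reshape(size1, csize)
    t2 = b.permute(list(dims_b) + keep_b).reshape(csize, size2)
    return torch.mm(t1, t2).reshape(rsizes)

a = torch.arange(60.).reshape(3, 4, 5)
b = torch.arange(24.).reshape(4, 3, 2)
assert torch.allclose(tensordot_ref(a, b, [1, 0], [0, 1]),
                      torch.tensordot(a, b, dims=([1, 0], [0, 1])))

Folding everything into one `mm` is the design choice that lets the existing, heavily optimized matrix-multiply kernels do all the work.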
3 changes: 3 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -1618,6 +1618,9 @@
CPU: _tanh_out_cpu
CUDA: _tanh_out_cuda

- func: tensordot(Tensor self, Tensor other, IntList dims_self, IntList dims_other) -> Tensor
variants: function

- func: transpose(Tensor self, int64_t dim0, int64_t dim1) -> Tensor

- func: transpose_(Tensor self, int64_t dim0, int64_t dim1) -> Tensor
1 change: 1 addition & 0 deletions docs/source/torch.rst
@@ -264,6 +264,7 @@ Other Operations
.. autofunction:: histc
.. autofunction:: meshgrid
.. autofunction:: renorm
.. autofunction:: tensordot
.. autofunction:: trace
.. autofunction:: tril
.. autofunction:: triu
20 changes: 20 additions & 0 deletions test/test_torch.py
@@ -3260,6 +3260,26 @@ def test_sort(self):
# Test that we still have proper sorting with duplicate keys
self.assertIsOrdered('descending', x, res2val, res2ind, 'random with duplicate keys')

@unittest.skipIf(not TEST_NUMPY, 'Numpy not found')
def test_tensordot(self):
devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
for d in devices:
a = torch.arange(60., device=d).reshape(3, 4, 5)
b = torch.arange(24., device=d).reshape(4, 3, 2)
c = torch.tensordot(a, b, dims=([1, 0], [0, 1])).cpu()
cn = torch.from_numpy(np.tensordot(a.cpu().numpy(), b.cpu().numpy(),
axes=([1, 0], [0, 1])))
self.assertEqual(c, cn)
a = torch.randn(2, 3, 4, 5, device=d)
b = torch.randn(4, 5, 6, 7, device=d)
c = torch.tensordot(a, b, dims=2).cpu()
cn = torch.from_numpy(np.tensordot(a.cpu().numpy(), b.cpu().numpy(),
axes=2))
self.assertEqual(c, cn)
c = torch.tensordot(a, b).cpu()
cn = torch.from_numpy(np.tensordot(a.cpu().numpy(), b.cpu().numpy()))
self.assertEqual(c, cn)

def test_topk(self):
def topKViaSort(t, k, dim, dir):
sorted, indices = t.sort(dim, dir)
56 changes: 56 additions & 0 deletions torch/functional.py
@@ -17,6 +17,7 @@
'isnan',
'split',
'stft',
'tensordot',
'unique',
]

@@ -429,6 +430,61 @@ def argmin(input, dim=None, keepdim=False):
return torch._argmin(input, dim, keepdim)


def tensordot(a, b, dims=2):
"""Returns a contraction of a and b over multiple dimensions.

:attr:`tensordot` implements a generalized matrix product.

Args:
a (Tensor): Left tensor to contract
b (Tensor): Right tensor to contract
dims (int or tuple of two lists of integers): number of dimensions to
contract or explicit lists of dimensions for :attr:`a` and
:attr:`b` respectively

When called with an integer argument :attr:`dims` = :math:`d`, and the number of
dimensions of :attr:`a` and :attr:`b` is :math:`m` and :math:`n`, respectively,
it computes

.. math::
    r_{i_0,...,i_{m-d-1}, j_0,...,j_{n-d-1}}
    = \sum_{k_0,...,k_{d-1}} a_{i_0,...,i_{m-d-1},k_0,...,k_{d-1}} \times b_{k_0,...,k_{d-1}, j_0,...,j_{n-d-1}}.

When called with :attr:`dims` in list form, the listed dimensions are contracted
instead of the last :math:`d` dimensions of :attr:`a` and the first :math:`d`
dimensions of :attr:`b`. The sizes of the contracted dimensions must match,
although :attr:`tensordot` handles broadcast dimensions of size 1.

Examples::

>>> a = torch.arange(60.).reshape(3, 4, 5)
>>> b = torch.arange(24.).reshape(4, 3, 2)
>>> torch.tensordot(a, b, dims=([1, 0], [0, 1]))
tensor([[4400., 4730.],
[4532., 4874.],
[4664., 5018.],
[4796., 5162.],
[4928., 5306.]])

>>> a = torch.randn(3, 4, 5, device='cuda')
>>> b = torch.randn(4, 5, 6, device='cuda')
>>> torch.tensordot(a, b, dims=2).cpu()
tensor([[ 8.3504, -2.5436, 6.2922, 2.7556, -1.0732, 3.2741],
[ 3.3161, 0.0704, 5.0187, -0.4079, -4.3126, 4.8744],
[ 0.8223, 3.9445, 3.2168, -0.2400, 3.4117, 1.7780]])

"""
if isinstance(dims, (list, tuple)) or \
(isinstance(dims, torch.Tensor) and dims.numel() > 1):
dims_a, dims_b = dims
else:
if isinstance(dims, torch.Tensor):
dims = dims.item()
dims_a = list(range(-dims, 0))
dims_b = list(range(dims))
return torch._C._VariableFunctions.tensordot(a, b, dims_a, dims_b)
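
As the branch above makes explicit, an integer :attr:`dims` is shorthand for contracting the last ``dims`` dimensions of ``a`` against the first ``dims`` dimensions of ``b``, so the two spellings below reach the native function with identical dimension lists (the shapes are arbitrary):

import torch

a = torch.randn(2, 3, 4, 5)
b = torch.randn(4, 5, 6)
# dims=2 expands to dims_a = [-2, -1] and dims_b = [0, 1] in the wrapper above
c1 = torch.tensordot(a, b, dims=2)
c2 = torch.tensordot(a, b, dims=([-2, -1], [0, 1]))
assert torch.allclose(c1, c2)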


def argsort(input, dim=None, descending=False):
"""Returns the indices that sort a tensor along a given dimension in ascending
order by value.