Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 14 additions & 13 deletions aten/src/ATen/native/Linear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -727,12 +727,12 @@ Tensor bilinear(const Tensor& input1, const Tensor& input2, const Tensor& weight
Tensor tensordot(const Tensor& input1, const Tensor& input2, IntArrayRef dims1, IntArrayRef dims2) {
TORCH_CHECK(dims1.size() == dims2.size(), "both dimension lists should have same length");
TORCH_CHECK(input1.scalar_type() == input2.scalar_type(), "both inputs should have same dtype");
int64_t csize = 1; // total size of the contracted dimensions
SymInt csize = 1; // total size of the contracted dimensions
Tensor t1 = input1;
Tensor t2 = input2;
for (const auto i : c10::irange(dims1.size())) {
int s1 = input1.size(dims1[i]);
int s2 = input2.size(dims2[i]);
SymInt s1 = input1.sym_size(dims1[i]);
SymInt s2 = input2.sym_size(dims2[i]);
if (s2 == 1) { // broadcasted dimensions can be summed right away
t1 = t1.sum(dims1[i], true);
} else if (s1 == 1) {
Expand All @@ -746,19 +746,20 @@ Tensor tensordot(const Tensor& input1, const Tensor& input2, IntArrayRef dims1,

auto cdims1 = at::dim_list_to_bitset(dims1, input1.dim());
auto cdims2 = at::dim_list_to_bitset(dims2, input2.dim());
std::vector<int64_t> p1, p2, rsizes; // p1, p2: input permutations, rsizes: sizes of the result
std::vector<int64_t> p1, p2; // p1, p2: input permutations
std::vector<SymInt> rsizes; // rsizes: sizes of the result
p1.reserve(input1.dim());
p2.reserve(input2.dim());
rsizes.reserve(input1.dim() + input2.dim() - (int64_t) dims1.size());
int64_t size1 = 1; // number of non-contracted elements in input1
int64_t size2 = 1; // number of non-contracted elements in input2
SymInt size1 = 1; // number of non-contracted elements in input1
SymInt size2 = 1; // number of non-contracted elements in input2

// fill the permutations and compute sizes
for (const auto i : c10::irange(input1.dim())) {
if (! cdims1[i]) {
p1.emplace_back(i);
size1 *= t1.size(i);
rsizes.emplace_back(t1.size(i));
size1 *= t1.sym_size(i);
rsizes.emplace_back(t1.sym_size(i));
}
}
for (const auto x : dims1) {
Expand All @@ -770,15 +771,15 @@ Tensor tensordot(const Tensor& input1, const Tensor& input2, IntArrayRef dims1,
for (const auto i : c10::irange(input2.dim())) {
if (! cdims2[i]) {
p2.emplace_back(i);
size2 *= t2.size(i);
rsizes.emplace_back(t2.size(i));
size2 *= t2.sym_size(i);
rsizes.emplace_back(t2.sym_size(i));
}
}
// permut and reshape for matrix multiplication
t1 = t1.permute(p1).reshape({size1, csize});
t2 = t2.permute(p2).reshape({csize, size2});
t1 = t1.permute(p1).reshape_symint({size1, csize});
t2 = t2.permute(p2).reshape_symint({csize, size2});
// multiply and reshape to target size
return at::mm(t1, t2).reshape(rsizes);
return at::mm(t1, t2).reshape_symint(rsizes);
}

Tensor &tensordot_out(const Tensor& input1, const Tensor& input2, IntArrayRef dims1, IntArrayRef dims2, Tensor& result) {
Expand Down
6 changes: 3 additions & 3 deletions aten/src/ATen/native/LinearAlgebra.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1284,11 +1284,11 @@ Tensor inner(const Tensor& self, const Tensor& other) {

// Last dimension should match (tensordot does not enforce this)
TORCH_CHECK(
self.size(-1) == other.size(-1),
self.sym_size(-1) == other.sym_size(-1),
"inner() the last dimension must match on both input tensors but got shapes ",
self.sizes(),
self.sym_sizes(),
" and ",
other.sizes());
other.sym_sizes());

return at::tensordot(self, other, -1, -1);
}
Expand Down
2 changes: 0 additions & 2 deletions test/functorch/test_aotdispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2502,7 +2502,6 @@ def forward(self, x):
xfail('hsplit', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('i0', ''), # aten.i0.default - couldn't find symbolic meta function/decomposition
xfail('index_fill', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('inner', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('kron', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('kthvalue', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('linalg.cholesky_ex', ''), # could not find kernel for aten.linalg_solve_triangular.default
Expand Down Expand Up @@ -2612,7 +2611,6 @@ def forward(self, x):
xfail('svd', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('svd_lowrank', ''), # could not find kernel
xfail('take_along_dim', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('tensordot', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('trace', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('triangular_solve', ''), # aten.triangular_solve.default - couldn't find symbolic meta function/de...
xfail('_upsample_bilinear2d_aa'), # RuntimeError: isIntList() INTERNAL ASSERT FAILED Expected IntList but got GenericList
Expand Down
2 changes: 0 additions & 2 deletions test/test_proxy_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1455,7 +1455,6 @@ def f(a, b, c, d, e):
xfail('histogramdd', ''), # aten._histogramdd_bin_edges.default - couldn't find symbolic meta function/decomposition
xfail('hsplit', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('index_reduce', ''), # Float
xfail('inner', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('isin', ''), # aten.isin.Tensor_Tensor - couldn't find symbolic meta function/decomposition
xfail('kron', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('kthvalue', ''), # aten.kthvalue.default - couldn't find symbolic meta function/decomposition
Expand Down Expand Up @@ -1559,7 +1558,6 @@ def f(a, b, c, d, e):
xfail('stft', ''), # argument 'size' must be tuple of ints, but found element of type torch._C.SymIntNode at...
xfail('svd_lowrank', ''), # aten.mm.default - couldn't find symbolic meta function/decomposition
xfail('take_along_dim', ''), # dtype of indices should be Long but got Float
xfail('tensordot', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('triangular_solve', ''), # aten.triangular_solve.default - couldn't find symbolic meta function/decomposition
xfail('vsplit', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('unique_consecutive', ''), # aten.unique_consecutive.default - couldn't find symbolic meta function/decomposition
Expand Down