250 changes: 171 additions & 79 deletions aten/src/ATen/native/Linear.cpp
@@ -2,7 +2,6 @@
#include "ATen/NativeFunctions.h"
#include "ATen/WrapDimUtilsMulti.h"


namespace at { namespace native {


@@ -109,11 +108,28 @@ Tensor einsum(std::string eqn, TensorList tensors) {
constexpr size_t number_of_letters = 26;
std::string in_eqn;
size_t pos;
// we need a number of per-(letter-)index bookkeeping arrays for analysing the equation. The index runs from 0='a' through 25='z'.
std::array<std::int64_t, number_of_letters> number_of_occurrences; // number of occurrences in the equation of this index
number_of_occurrences.fill(0);
std::array<std::int64_t, number_of_letters> last_occurrence; // the last operator (left to right) using this index
last_occurrence.fill(-1);
// The equation is given in terms of single lowercase letters ('a'..'z') and potentially an ellipsis.
// Internally, we represent it using indices from 0 to num_total_idxes-1, with each letter
// mapped to an index and the ellipsis ('...') being mapped to a number of consecutive indices.
// The mapping of letters to internal indices is given in letter_mapping. A value of -1 means that
// the letter has not been assigned an index yet (because it has not been seen).
// The ellipsis is defined by first_ell_idx (the first index) and num_ell_idxes (the number of indices).
// A value of -1 for num_ell_idxes specifies that we have not seen an ellipsis yet.
// Note: The internal indices are NOT the dimensions used internally. There is a mapping to them below.
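// As a worked illustration (not used by the code): for "ab...,b...->a..." applied to a
// 4-dim and a 3-dim tensor, 'a' and 'b' get internal indices 0 and 1 and the ellipsis the
// two consecutive indices 2 and 3, so num_total_idxes ends up as 4.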

std::array<std::int64_t, number_of_letters> letter_mapping; // map letter to internal (numerical) label
letter_mapping.fill(-1);
int64_t num_ell_idxes = -1;
int64_t first_ell_idx = 0;

// The internal representation of the left hand side of the equation (with ellipsis expanded) is stored in input_op_idxes.
// For each operand, we have a vector mapping each dimension to an internal index.
// We also keep track of the number of occurrences for each letter (to infer a right hand side if not given) and
// of the last occurrence of each index.
std::vector<std::vector<int64_t>> input_op_idxes; // the parsed operand indices
std::array<std::int64_t, number_of_letters> num_letter_occurrences; // number of occurrences in the equation of this letter
num_letter_occurrences.fill(0);
std::vector<std::int64_t> last_idx_occurrence; // the last operator (left to right) using this index

if ((pos = eqn.find("->")) != std::string::npos) { // check whether we have a right hand side. in_eqn is the left hand side
in_eqn = eqn.substr(0, pos);
@@ -125,120 +141,196 @@ Tensor einsum(std::string eqn, TensorList tensors) {
int64_t operand = 0;
std::stringstream eqn_stream(in_eqn);
std::string term;
int64_t num_total_idxes = 0;
while (! eqn_stream.eof()) {
std::getline(eqn_stream, term, ','); // term = string with indices of current term
int64_t dims_in_operand = 0;
for (auto &c : term) { // c = character with a single index
AT_CHECK(('a' <= c) && (c <= 'z'), "only lowercase letters a-z allowed as indices");
int64_t index_num = c-'a'; // index_num = index to be used in the vectors above
number_of_occurrences[index_num]++;
// when there are two occurrences we need to take a diagonal with respect to the dimensions
// occurring multiple times before continuing the processing.
// e.g. einsum('ii->i', [A]) should return the diagonal
// This waits for the general diagonal handling discussed in #6479
// for now, we error out here
AT_CHECK(last_occurrence[index_num] < operand, "diagonals (multiple occurrences of the same index for one tensor) not implemented yet");
last_occurrence[index_num] = operand;
dims_in_operand++;
AT_CHECK((int64_t) tensors.size()>operand, "more operands in equation than tensors"); // we cannot have a longer equation than operands. We need to check here before we use the dimension

int64_t ell_char_count = 0; // handling of ellipsis '...' is a bit tedious, we count the '.'
// if there is an ellipsis, the number of dimensions it represents must be total dim - letter dimensions
int64_t candidate_num_ell_idxes = tensors[operand].dim() - term.size() + 3;
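// e.g. a 5-dim operand with term "ab...c" (term.size() == 6) has 5 - 6 + 3 = 2 ellipsis dimensions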
int64_t dims_in_term = 0; // dimensions we have seen
std::vector<int64_t> current_op_idxes; // mapping of operand dimensions to indices for current term
for (auto &c : term) { // c = character with a single letter or '.'
if (c == '.') {
ell_char_count++;
AT_CHECK(ell_char_count <= 3, "can only have '.' in one ellipsis '...' in term ", operand, " of the equation");
if (ell_char_count == 3) { // this completes the ellipsis
if (num_ell_idxes == -1) { // if we have not seen an ellipsis before, keep track of indices and size
first_ell_idx = num_total_idxes;
num_ell_idxes = candidate_num_ell_idxes;
num_total_idxes += num_ell_idxes;
}
else { // we have seen an ellipsis before, so we check compatibility
AT_CHECK(candidate_num_ell_idxes == num_ell_idxes,
"ellipsis must represent ", num_ell_idxes, " dimensions in all terms");
}
for (int64_t i = 0; i < num_ell_idxes; ++i) { // map ellipsis dimensions in operand to indices
current_op_idxes.push_back(first_ell_idx + i);
last_idx_occurrence.push_back(operand);
}
dims_in_term += num_ell_idxes; // keep track of dimensions
}
} else { // a letter (hopefully)
AT_CHECK((ell_char_count == 0) || (ell_char_count == 3), "'.' must only occur in ellipsis, operand ", operand);
AT_CHECK(('a' <= c) && (c <= 'z'), "only lowercase letters a-z allowed as indices");
int64_t letter_num = c-'a'; // letter_num = position in letter_mapping
if (letter_mapping[letter_num] == -1) { // new letter, add internal index and mapping
letter_mapping[letter_num] = num_total_idxes;
num_total_idxes++;
last_idx_occurrence.push_back(operand);
} else { // letter we have already seen
last_idx_occurrence[letter_mapping[letter_num]] = operand;
}
num_letter_occurrences[letter_num]++;
current_op_idxes.push_back(letter_mapping[letter_num]);
dims_in_term++;
}
}
AT_CHECK((int64_t) tensors.size()>operand, "more operands in equation than tensors"); // we cannot have a longer equation than operands. We need to check here before we check the dimensions
AT_CHECK(dims_in_operand == tensors[operand].dim(),
"dimension mismatch for operand ", operand, ": equation ", dims_in_operand, ", tensor ", tensors[operand].dim());
AT_CHECK(dims_in_term == tensors[operand].dim(), "dimension mismatch for operand ", operand, ": equation ", dims_in_term, ", tensor ", tensors[operand].dim());
input_op_idxes.push_back(std::move(current_op_idxes));
operand++;
}
AT_CHECK((int64_t) tensors.size()==operand, "more tensors than operands in equation"); // we need ==, but > is captured above, so the error message can be specific that it is <.
// in the check below, we need ==, but > is captured above, so the error message can be specific that it is <.
AT_CHECK((int64_t) tensors.size()==operand, "more tensors than operands in equation");

// the following parses or infers output (right hand side)
// it also assigns the sorted_position ((letter) index -> dimension in Tensors) and position_labels (dimensions in Tensors -> index)
// for the output indices
std::array<std::int64_t, number_of_letters> sorted_position; // the position of the index in the tensor dimensions
sorted_position.fill(-1);
// it also assigns the idxes_to_preprocessed_dims (index -> dimension in preprocessed / output tensors)
// for the output indices. -1 means that the index has not been assigned a dimension yet
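// As an illustration: for "ij,jk" the inferred right hand side is "ik", so the mapping below
// becomes i->0, k->1 (output dimensions) and j->2 (summed dimension).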
std::vector<int64_t> idxes_to_preprocessed_dims(num_total_idxes, -1); // the position of the index in the tensor dimensions
int64_t num_output_dims = 0;
std::vector<int64_t> position_labels;
if (pos != std::string::npos) { // parse the user provided right hand side
int64_t ell_char_count = 0;
for (auto &c : eqn.substr(pos+2)) {
AT_CHECK(('a' <= c) && (c <= 'z'), "only lowercase letters a-z allowed as indices");
int64_t index_num = c-'a';
AT_CHECK(sorted_position[index_num] == -1, "index ", c, " occurs twice in output");
sorted_position[index_num] = num_output_dims;
position_labels.push_back(index_num);
num_output_dims++;
if (c == '.') { // '.' as part of ellipsis
ell_char_count++;
AT_CHECK(ell_char_count <= 3, "can only have '.' in one ellipsis '...' in right hand side of the equation");
if (ell_char_count == 3) { // ellipsis complete
AT_CHECK(num_ell_idxes >= 0, "ellipsis '...' may only appear in right hand side if it does in left hand side");
for (int64_t i = 0; i < num_ell_idxes; ++i) {
idxes_to_preprocessed_dims[first_ell_idx + i] = num_output_dims;
num_output_dims++;
}
}
} else { // letter (hopefully)
AT_CHECK((ell_char_count == 0) || (ell_char_count == 3), "'.' must only occur in ellipsis in the right hand side");
AT_CHECK(('a' <= c) && (c <= 'z'), "only lowercase letters a-z allowed as indices");
int64_t letter_num = c-'a';
AT_CHECK(idxes_to_preprocessed_dims[letter_mapping[letter_num]] == -1, "index ", c, " occurs twice in output");
idxes_to_preprocessed_dims[letter_mapping[letter_num]] = num_output_dims;
num_output_dims++;
}
}
} else { // create an inferred right hand side
// the ellipsis (if in the lhs) comes first
if (num_ell_idxes >= 0) {
for (int64_t i = 0; i < num_ell_idxes; ++i) {
idxes_to_preprocessed_dims[first_ell_idx + i] = num_output_dims;
num_output_dims++;
}
}
} else { // create a right hand side: the indices that occur exactly once in alphabetic order
// then the indices that occur exactly once in alphabetic order
for (size_t idx = 0; idx < number_of_letters; idx++) {
if (number_of_occurrences[idx] == 1) {
sorted_position[idx] = num_output_dims;
position_labels.push_back(idx);
num_output_dims++;
if (num_letter_occurrences[idx] == 1) {
idxes_to_preprocessed_dims[letter_mapping[idx]] = num_output_dims;
num_output_dims++;
}
}
}
// now we assign the sorted_position ((letter) index -> dimension in Tensors) and position_labels (dimensions in Tensors -> index)
// now we assign the idxes_to_preprocessed_dims (index -> dimension in preprocessed / output tensors)
// for the non-output indices - those that are eventually summed over
int64_t position = num_output_dims; // we now determine the order of the remaining indices (insofar as they are in the equation)
for (size_t idx = 0; idx < number_of_letters; idx++) {
if ((number_of_occurrences[idx] > 0) && (sorted_position[idx]==-1)) {
sorted_position[idx] = position;
position_labels.push_back(idx);
int64_t position = num_output_dims;
for (int64_t i = 0; i < num_total_idxes; i++) {
if (idxes_to_preprocessed_dims[i]==-1) {
idxes_to_preprocessed_dims[i] = position;
position++;
}
}
// we now "homogenize the dimensions", i.e. create all dimensions in each tensor and sort the dimensions according to the mapping in
// sorted_position / position_labels

// we now "homogenize the dimensions", i.e.
// - take diagonals for duplicated indices
// - permute the dimensions to match the order given by idxes_to_preprocessed_dims
// - unsqueeze to create all dimensions for each index in each tensor where they are missing
// we also check that sizes match
// after this, all operands will have compatible shapes (i.e. all dimensions are aligned and broadcastable)
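// Illustration: for "ii->i" on a 4x4 tensor, the two 'i' dimensions are collapsed into one
// 4-element diagonal dimension; for "ab,b->a", the second operand is unsqueezed to shape (1, b)
// so that it lines up with the first operand's (a, b).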
std::vector<Tensor> permuted_ops;
eqn_stream.clear();
eqn_stream.seekg(0, std::ios_base::beg);
std::vector<Tensor> preprocessed_operands;
std::vector<std::int64_t> size_of_dims(num_total_idxes, -1); // keep track of sizes for each index, -1 means we have not seen a size yet
for (int64_t op = 0; op < (int64_t) tensors.size(); op++) {
std::array<int64_t, number_of_letters> axes; // the dimension which the letter refers to in the permuted tensor
axes.fill(-1);
std::vector<int64_t> permutation; // permutation for this tensor
std::getline(eqn_stream, term, ',');
int64_t dim = 0;
for (auto &c : term) {
int64_t index_num = c-'a';
axes[index_num] = dim;
dim++;
auto preprocessed_op = tensors[op];
std::vector<int64_t> idx_to_dim(num_total_idxes, -1); // the dimension which the index refers to in the original tensor, -1 means it does not appear
std::vector<int64_t>& current_op_input_idxes = input_op_idxes[op];
int64_t dim = 0; // there are two dimension indices: dim is after taking diagonals, i is in input
for (size_t i = 0; i < current_op_input_idxes.size(); i++) {
auto idx = current_op_input_idxes[i];
auto dim_out = idxes_to_preprocessed_dims[idx];
if (idx_to_dim[dim_out] == -1) { // first appearance
idx_to_dim[dim_out] = dim;
if (size_of_dims[idx] == -1) { // keep track of sizes
size_of_dims[idx] = preprocessed_op.size(dim);
}
else {
AT_CHECK(size_of_dims[idx] == preprocessed_op.size(dim), "size of dimension does not match previous size, operand ", op, ", dim ", i);
}
dim++;
} else { // duplicate dimension in tensor --> take the diagonal over idx_to_dim[dim_out] and dim, and move the diagonal dimension back to idx_to_dim[dim_out]
AT_CHECK(size_of_dims[idx] == preprocessed_op.size(dim), "size of dimension does not match previous size, operand ", op, ", dim ", i);
preprocessed_op = preprocessed_op.diagonal(0, idx_to_dim[dim_out], dim);
// diagonal moves the diagonal dimension to the back
// now we permute the last dim back to idx_to_dim[dim_out]
std::vector<int64_t> perm(preprocessed_op.dim(), 0);
for (int64_t d = 0; d < preprocessed_op.dim(); d++) {
if (d == idx_to_dim[dim_out]) {
perm[d] = preprocessed_op.dim() - 1;
} else {
perm[d] = d - (d > idx_to_dim[dim_out]);
}
}
preprocessed_op = preprocessed_op.permute(perm);
}
}
for (auto &c : position_labels) {
if (axes[c] > -1) {
permutation.push_back(axes[c]);
// now we permute the dimensions in the right order
std::vector<int64_t> permutation; // permutation for this tensor
for (auto &d : idx_to_dim) {
if (d > -1) {
permutation.push_back(d);
}
}
permuted_ops.push_back(tensors[op].permute(permutation));
for (int64_t dim = 0; dim < (int64_t) position_labels.size(); dim++) {
auto c = position_labels[dim];
if (axes[c] == -1) {
permuted_ops.back().unsqueeze_(dim);
preprocessed_op = preprocessed_op.permute(permutation);
// finally, we insert dimensions for idxes not in the operand
for (size_t dim = 0; dim < idx_to_dim.size(); dim++) {
if (idx_to_dim[dim] == -1) {
preprocessed_op.unsqueeze_(dim);
}
}
preprocessed_operands.push_back(preprocessed_op);
}

// now we reduce the indices from left to right
// numpy allows optimizing the path using various
// algorithms (see einsum_path in the numpy docs)
// we start with the leftmost operator and reduce indices that
// appear only there
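// e.g. for "ij,jk->k" the index i occurs only in the first operand, so it is summed out
// of that operand here, before any pairwise contraction takes place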
Tensor result = permuted_ops[0];
for (size_t idx = 0; idx < number_of_letters; idx++) {
if ((last_occurrence[idx] == 0)
&& (sorted_position[idx]>=num_output_dims)) {
result = result.sum(sorted_position[idx], true);
Tensor result = preprocessed_operands[0];
for (int64_t idx = 0; idx < num_total_idxes; idx++) {
if ((last_idx_occurrence[idx] == 0)
&& (idxes_to_preprocessed_dims[idx]>=num_output_dims)) {
result = result.sum(idxes_to_preprocessed_dims[idx], true);
}
}

// now we process each tensor using sumproduct_pair
for (int64_t i = 1; i < (int64_t) permuted_ops.size(); i++) {
for (int64_t i = 1; i < (int64_t) preprocessed_operands.size(); i++) {
std::vector<int64_t> sum_dims;
for (size_t idx = 0; idx < number_of_letters; idx++) {
if ((last_occurrence[idx] == i)
&& (sorted_position[idx]>=num_output_dims)) {
sum_dims.push_back(sorted_position[idx]);
for (int64_t idx = 0; idx < num_total_idxes; idx++) {
if ((last_idx_occurrence[idx] == i)
&& (idxes_to_preprocessed_dims[idx]>=num_output_dims)) {
sum_dims.push_back(idxes_to_preprocessed_dims[idx]);
}
}
result = at::native::sumproduct_pair(result, permuted_ops[i], sum_dims, true);
result = at::native::sumproduct_pair(result, preprocessed_operands[i], sum_dims, true);
}
// finally, we squeeze out all non-result dimensions
for (int64_t dim = position_labels.size()-1; dim >= num_output_dims; dim--)
for (int64_t dim = num_total_idxes-1; dim >= num_output_dims; dim--)
result.squeeze_(dim);
return result;
}
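
For reference, a minimal Python-level sketch (not part of this diff) of what the new diagonal and ellipsis handling enables, cross-checked against numpy in the same way as the test below; the tensor names mirror the test additions:

import numpy as np
import torch

H = torch.randn(4, 4)
I = torch.randn(3, 4, 4)

# repeated indices on one operand give the trace / diagonal / batch diagonal
for eqn, t in [("ii", H), ("ii->i", H), ("...ii->...i", I)]:
    actual = torch.einsum(eqn, (t,))
    expected = np.einsum(eqn, t.numpy())
    assert np.allclose(expected, actual.numpy()), eqn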
14 changes: 12 additions & 2 deletions test/test_torch.py
@@ -1408,6 +1408,8 @@ def test_einsum(self):
E = torch.randn(7, 9)
F = torch.randn(2, 3, 5, 7)
G = torch.randn(7, 11, 13)
H = torch.randn(4, 4)
I = torch.randn(3, 4, 4)
l = torch.randn(5, 10)
r = torch.randn(5, 20)
w = torch.randn(30, 10, 20)
@@ -1434,14 +1436,22 @@
("ijk,jk->ij", C, A), # tensor matrix contraction with double indices
("ijk,ik->j", C, B), # non contiguous
("ijk,ik->jk", C, B), # non contiguous with double indices
# -- Diagonal
("ii", H), # trace
("ii->i", H), # diagonal
# -- Ellipsis
("i...->...", H),
("ki,...k->i...", A.t(), B),
("k...,jk", A.t(), B),
("...ii->...i", I), # batch diagonal
# -- Other
("bn,anm,bm->ba", l, w, r), # as torch.bilinear
]
for test in test_list:
actual = torch.einsum(test[0], test[1:])
expected = np.einsum(test[0], *[t.numpy() for t in test[1:]])
self.assertEqual(expected.shape, actual.shape)
self.assertTrue(np.allclose(expected, actual.numpy()))
self.assertEqual(expected.shape, actual.shape, test[0])
self.assertTrue(np.allclose(expected, actual.numpy()), test[0])

def do_einsum(*args):
return torch.einsum(test[0], args)