aten/src/ATen/native/TensorShape.cpp (97 additions & 46 deletions)

@@ -61,61 +61,112 @@ static bool sizes_match_except(IntList s1, IntList s2, int64_t dim_except /* should already be wrapped */)
  return true;
}
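The body of sizes_match_except is collapsed in this view. As used below, its contract amounts to the following Python sketch (a hypothetical equivalent for illustration, not the actual implementation):

# Hypothetical Python equivalent of sizes_match_except, for illustration only.
# True iff s1 and s2 have the same length and agree in every dimension
# except (possibly) dim_except, which must already be wrapped.
def sizes_match_except(s1, s2, dim_except):
    return len(s1) == len(s2) and all(
        a == b for i, (a, b) in enumerate(zip(s1, s2)) if i != dim_except)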

// Check that the shapes of the tensors are compatible
// for concatenation along a given dimension.
static void check_cat_sparse_dims(Tensor const &t,
  int64_t pos /* used only for debug messages */,
  IntList sizes,
  int64_t wrapped,
  int64_t sparse_dim,
  int64_t dense_dim) {
    AT_CHECK(t.is_sparse(),
        "Concatenating sparse tensors, but a dense tensor was found at position ", pos, ".");
    AT_CHECK(sizes_match_except(sizes, t.sizes(), wrapped),
        "All tensors must have the same shape: ", sizes, " (except in the concatenating dimension),"
        " but found shape: ", t.sizes(), " at position ", pos, ".");
    AT_CHECK(t.sparse_dim() == sparse_dim && t.dense_dim() == dense_dim,
        "All tensors must have the same sparse_dim and dense_dim: ", sparse_dim, ", ", dense_dim,
        ", but tensor at position ", pos, " has ", t.sparse_dim(), ", ", t.dense_dim(), ".");
}
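For context, a quick sketch of what these checks reject at the Python level (assuming a build that includes this change; tensor contents are illustrative):

import torch

s = torch.sparse_coo_tensor(torch.tensor([[0, 2]]), torch.tensor([1., 2.]), (3,))
d = torch.ones(3)
# torch.cat((s, d))  # RuntimeError: ...a dense tensor was found at position 1.
s2 = torch.sparse_coo_tensor(torch.tensor([[1]]), torch.tensor([3.]), (4,))
out = torch.cat((s, s2), 0)  # OK: shapes differ only in the cat dimension -> (7,)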

static Tensor cat_sparse(TensorList tensors, int64_t dim) {
  std::vector<Tensor> indices;
  std::vector<Tensor> values;
  int64_t wrapped = maybe_wrap_dim(dim, tensors[0].dim());
  int64_t sparse_dim = tensors[0].sparse_dim();
  int64_t dense_dim = tensors[0].dense_dim();
  IntList sizes = tensors[0].sizes();
  if (wrapped < sparse_dim) {
    for (size_t i = 0; i < tensors.size(); ++i) {
      auto const &t = tensors[i];
      check_cat_sparse_dims(t, i, sizes, wrapped, sparse_dim, dense_dim);
      indices.push_back(t._indices());
      values.push_back(t._values());
    }
    Tensor idxs = at::cat(indices, 1);
    Tensor vals = at::cat(values, 0);

    // We now need to move the indices of each
    // input tensor up along `dim` by an appropriate amount.
    // E.g., if t1 has indices [[2,3,4],[5,6,7]],
    // and sizes [10, 7]
    // then torch.cat((t1,t1,t1),1) should have indices
    // [[2,3,4,2,3,4,2,3,4],[5,6,7,12,13,14,19,20,21]],
    // so we need to increase idxs[1][3:6] by 7
    // and idxs[1][6:9] by 14.
    int64_t col = 0;
    int64_t cumulative_offset = 0;
    for (size_t i = 0; i < tensors.size(); ++i) {
      auto const &t = tensors[i];
      int64_t this_piece_size = t._nnz();
      // cumulative_offset is zero for the first piece, so
      // don't waste time doing this operation unless i > 0.
      if (i > 0) {
        idxs[wrapped].narrow(0, col, this_piece_size) += cumulative_offset;
      }
      cumulative_offset += t.size(wrapped);
      col += this_piece_size;
    }
    auto sizes_copy = sizes.vec();
    sizes_copy[wrapped] = cumulative_offset;
    return native::sparse_coo_tensor(idxs, vals, sizes_copy, tensors[0].options());
  }
  else {
    // Catting along a dense dimension requires us to create new values.
    // For illustration, consider the sparse 3d tensors t1 and t2,
    // given by t1 = [[[1,2],[3,4]], ... (zeros) ..., [[5,6],[7,8]]]
    // and t2 = [... (zeros) ..., [[9, 10], [11,12]], ... (zeros) ...].
    // Their concatenation along dimension 2 is:
    // [[[1,2,0,0],[3,4,0,0]], ... (zeros) ..., [[0,0,9,10],[0,0,11,12]], ... (zeros) ..., [[5,6,0,0],[7,8,0,0]]]
    //
    // Their values tensors are, respectively,
    // [[[1,2],[3,4]],[[5,6],[7,8]]] and [[[9,10],[11,12]]],
    //
    // and so the values tensor of their concatenation along dim 2 will be:
    // [[[1,2,0,0],[3,4,0,0]],[[5,6,0,0],[7,8,0,0]],[[0,0,9,10],[0,0,11,12]]]
    //
    // which we can get by taking the values tensor of each tensor,
    // catting it with zeros of the appropriate size on the left and right,
    // and then catting all those results together.

    // The dimension in each tensor's values object that corresponds
    // to the overall dimension along which we're catting.
    int64_t values_dim = wrapped - sparse_dim + 1;
    // The final size along the catted dimension.
    int64_t total_size = std::accumulate(tensors.begin(), tensors.end(), static_cast<int64_t>(0),
      [values_dim](int64_t l, Tensor const &r) {
        return l + r._values().size(values_dim);
      });
    auto zeros_sizes = tensors[0]._values().sizes().vec();
    int64_t cumulative_size = 0;
    std::vector<Tensor> vals_pieces;
    std::vector<Tensor> idxs_pieces;
    for (size_t i = 0; i < tensors.size(); ++i) {
      auto const &t = tensors[i];
      check_cat_sparse_dims(t, i, sizes, wrapped, sparse_dim, dense_dim);
      // dimension 0 of values corresponds to the number of values,
      // rather than to any logical dimension of the sparse tensor.
      zeros_sizes[0] = t._values().size(0);
      zeros_sizes[values_dim] = cumulative_size;
      cumulative_size += t._values().size(values_dim);
      auto z1 = native::zeros(zeros_sizes, t._values().options());
      zeros_sizes[values_dim] = total_size - cumulative_size;
      auto z2 = native::zeros(zeros_sizes, t._values().options());
      vals_pieces.push_back(native::cat({z1, t._values(), z2}, values_dim));
      idxs_pieces.push_back(t._indices());
    }
    auto sizes_copy = sizes.vec();
    sizes_copy[wrapped] = total_size;
    // This can create an uncoalesced tensor.
    return native::sparse_coo_tensor(native::cat(idxs_pieces, 1), native::cat(vals_pieces), sizes_copy, tensors[0].options());
  }
}
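As a usage sketch of both branches (Python; assuming a build with this change, with values chosen to mirror the comments above):

import torch

# Sparse-dim branch: indices in the cat dimension are shifted per input.
i = torch.tensor([[2, 3, 4], [5, 6, 7]])
t1 = torch.sparse_coo_tensor(i, torch.ones(3), (10, 7))
a = torch.cat((t1, t1, t1), 1)  # shape (10, 21)
# a._indices()[1] is [5, 6, 7, 12, 13, 14, 19, 20, 21]: offsets 0, 7, 14.

# Dense-dim branch: each values tensor is zero-padded left and right.
t2 = torch.sparse_coo_tensor(
    torch.tensor([[0, 4]]),
    torch.tensor([[[1., 2.], [3., 4.]], [[5., 6.], [7., 8.]]]),
    (5, 2, 2))
t3 = torch.sparse_coo_tensor(
    torch.tensor([[2]]),
    torch.tensor([[[9., 10.], [11., 12.]]]),
    (5, 2, 2))
b = torch.cat((t2, t3), 2)  # shape (5, 2, 4)
# b._values()[0] is [[1, 2, 0, 0], [3, 4, 0, 0]]; per the comment above,
# the result may be uncoalesced, so call b.coalesce() if canonical form is needed.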

Tensor cat(TensorList tensors, int64_t dim) {
test/test_sparse.py (7 additions & 5 deletions)

@@ -694,15 +694,17 @@ def test_shapes(shapes, dim, fail_message=None):

        # mismatched sizes
        test_shapes([(3, 10, [2, 3, 4]), (3, 10, [2, 1, 4])], 0,
                    "All tensors must have the same shape: \\[2, 3, 4].*\\[2, 1, 4]")
        # hybrid sparse/dense
        test_shapes(
            [(2, 10, [2, 3, 4]), (2, 10, [2, 1, 4]), (2, 10, [2, 4, 4])], 1)
        # cat along dense dim
        test_shapes([(2, 10, [2, 3, 4]), (2, 10, [2, 3, 7])], 2)
        test_shapes([(1, 10, [2, 3, 4]), (1, 10, [2, 3, 4])], 1)
        test_shapes([(1, 10, [2, 3, 4]), (1, 10, [2, 3, 4])], 2)
        # mismatched dimensions
        test_shapes([(2, 10, [2, 3, 4]), (3, 10, [2, 3, 4])], 0,
                    "All tensors must have the same.*2, 1, but tensor at position 1 has 3, 0.")
        # wrapped dimension
        test_shapes(
            [(3, 10, [2, 3, 4]), (3, 10, [2, 1, 4]), (3, 10, [2, 4, 4])], -2)
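For reference, the negative dim in the test above wraps the usual way (a small sketch under the same assumptions):

import torch

t = torch.sparse_coo_tensor(torch.tensor([[0], [1]]), torch.tensor([1.]), (10, 7))
torch.cat((t, t), -2).size()  # torch.Size([20, 7]); -2 wraps to dim 0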
@@ -711,7 +713,7 @@ def test_shapes(shapes, dim, fail_message=None):
        sp = self._gen_sparse(3, 10, [2, 3, 4])[0]
        dn = sp.to_dense()
        with self.assertRaisesRegex(RuntimeError,
                                    "Concatenating sparse tensors, but a dense tensor was found at position 1."):
            torch.cat((sp, dn))

    @skipIfRocm