Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions aten/src/ATen/native/sparse/SparseTensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@ SparseTensor new_with_dims_sparse(int64_t sparse_dim, int64_t dense_dim, ArrayRe
return self;
}

// Does NOT make copies of indices and values
SparseTensor new_with_dims_and_tensor_sparse(
int64_t sparse_dim,
int64_t dense_dim,
Expand All @@ -101,7 +100,16 @@ SparseTensor new_with_dims_and_tensor_sparse(
const TensorOptions& options) {
SparseTensor self = new_sparse(options);
get_sparse_impl(self)->resize_(sparse_dim, dense_dim, size);
alias_into_sparse(self, indices, values);
// NOTE: There is no guarantee that `indices` and `values` don't contain AutogradMeta. However,
// we want to maintain the invariant that `indices_` and `values_` of a sparse tensor don't
// contain AutogradMeta, and to achieve that we shallow-copy `indices` and `values` here.
auto indices_shallow_copy = LongTensor(indices.unsafeGetTensorImpl()->shallow_copy_and_detach(
/*version_counter=*/indices.unsafeGetTensorImpl()->version_counter(),
/*allow_tensor_metadata_change=*/true));
auto values_shallow_copy = Tensor(values.unsafeGetTensorImpl()->shallow_copy_and_detach(
/*version_counter=*/values.unsafeGetTensorImpl()->version_counter(),
/*allow_tensor_metadata_change=*/true));
alias_into_sparse(self, indices_shallow_copy, values_shallow_copy);
return self;
}

Expand Down
41 changes: 41 additions & 0 deletions test/test_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -1957,6 +1957,47 @@ def do_test(t):
do_test(self.sparse_empty(3, 0).data)
do_test(self.sparse_empty(3, 0).detach())

def test_change_tensor_metadata(self):
i = self.index_tensor([[0], [1]])
v = self.value_tensor([[3, 4, 5]])
t = torch.sparse_coo_tensor(i, v, torch.Size([1, 2, 3]))
i.resize_(2, 3)
v.resize_(4, 5)
self.assertEqual(list(t.coalesce().indices().size()), [2, 1])
self.assertEqual(list(t.coalesce().values().size()), [1, 3])

i = self.index_tensor([[0], [1]])
v = self.value_tensor([[3, 4, 5]])
t = torch.sparse_coo_tensor(i, v, torch.Size([1, 2, 3]))
i.resize_as_(self.index_tensor([0, 1]))
v.resize_as_(self.value_tensor([3, 4, 5]))
self.assertEqual(list(t.coalesce().indices().size()), [2, 1])
self.assertEqual(list(t.coalesce().values().size()), [1, 3])

i = self.index_tensor([[0], [1]])
v = self.value_tensor([[3, 4, 5]])
t = torch.sparse_coo_tensor(i, v, torch.Size([1, 2, 3]))
i.as_strided_((2, 1), (1, 1))
v.as_strided_((1, 3), (1, 1))
self.assertEqual(list(t.coalesce().indices().size()), [2, 1])
self.assertEqual(list(t.coalesce().values().size()), [1, 3])

i = self.index_tensor([[0], [1]])
v = self.value_tensor([[3, 4, 5]])
t = torch.sparse_coo_tensor(i, v, torch.Size([1, 2, 3]))
i.set_(self.index_tensor([0, 1]))
v.set_(self.value_tensor([3, 4, 5]))
self.assertEqual(list(t.coalesce().indices().size()), [2, 1])
self.assertEqual(list(t.coalesce().values().size()), [1, 3])

i = self.index_tensor([[0], [1]])
v = self.value_tensor([[3, 4, 5]])
t = torch.sparse_coo_tensor(i, v, torch.Size([1, 2, 3]))
i.transpose_(0, 1)
v.transpose_(0, 1)
self.assertEqual(list(t.coalesce().indices().size()), [2, 1])
self.assertEqual(list(t.coalesce().values().size()), [1, 3])


class TestUncoalescedSparse(TestSparse):
def setUp(self):
Expand Down