Skip to content

Commit e1ea8e9

Browse files
committed
Update on "Support copy_ for Sparse Compressed tensors."
It also fixes bugs in validating sparse block compressed indices and generating sparse block tensors. [ghstack-poisoned]
2 parents f8fcefd + f5b505b commit e1ea8e9

File tree

2 files changed

+21
-8
lines changed

2 files changed

+21
-8
lines changed

aten/src/ATen/native/sparse/SparseCsrTensor.cpp

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -522,17 +522,30 @@ Tensor& copy_sparse_compressed_(Tensor& self, const Tensor& src, bool non_blocki
522522
TORCH_CHECK(
523523
self._nnz() == src._nnz(), // actually, values copy allows different shapes as long as operands are broadcastable
524524
"torch.copy_: only sparse compressed tensors with the same number of specified elements are supported.");
525-
auto self_compressed_dims = self.size(compressedDimension(self.layout(), self.sizes()));
525+
auto self_compressed_dim = compressedDimension(self.layout(), self.sizes());
526+
auto src_compressed_dim = compressedDimension(src.layout(), src.sizes());
527+
auto self_compressed_dims = self.size(self_compressed_dim);
526528
auto src_compressed_dims = src.size(compressedDimension(src.layout(), src.sizes()));
527-
TORCH_CHECK(self_compressed_dims == src_compressed_dims,
528-
"torch.copy_: only sparse compressed tensors with the same number of compressed dimensions are supported.",
529-
" self and src compressed dimensions are ",
530-
self_compressed_dims, " and ", src_compressed_dims, ", respectively.");
529+
if (self_compressed_dim == src_compressed_dim) {
530+
TORCH_CHECK(self_compressed_dims == src_compressed_dims,
531+
"torch.copy_: expected shapes of self and src to match along dimension ",
532+
self_compressed_dim, " for ",
533+
self.layout(), " layout but the corresponding dimensions of self and src are ",
534+
self_compressed_dims, " and ", src_compressed_dims, ", respecitvely.");
535+
} else {
536+
TORCH_CHECK(self_compressed_dims == src_compressed_dims,
537+
"torch.copy_: expected shapes of self and src to match along dimensions ",
538+
self_compressed_dim, " and ", src_compressed_dim, ", respectively, for ",
539+
self.layout(), " layout but the corresponding dimensions of self and src are ",
540+
self_compressed_dims, " and ", src_compressed_dims, ", respecitvely.");
541+
}
531542
AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS(self.layout(), "copy_sparse_compressed_",
532543
[&]{},
533544
[&]{
534-
auto self_block_size = DimVector(self.values().sizes().slice(self.values().dim()-2, 2));
535-
auto src_block_size = DimVector(src.values().sizes().slice(src.values().dim()-2, 2));
545+
auto self_values = self.values();
546+
auto src_values = src.values();
547+
auto self_block_size = DimVector(self_values.sizes().slice(self_values.dim()-2, 2));
548+
auto src_block_size = DimVector(src_values.sizes().slice(src_values.dim()-2, 2));
536549
TORCH_CHECK(self_block_size == src_block_size,
537550
"torch.copy_: copy of sparse compressed tensors having different block sizes is not supported.",
538551
" self and src block sizes are ", self_block_size, " and ", src_block_size, ", respectivly.");

test/test_sparse_csr.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -410,7 +410,7 @@ def test_copy_errors(self, layout, device, dtype):
410410
index_dtype=index_dtype, block_size=block_size)
411411
with self.assertRaisesRegex(
412412
RuntimeError,
413-
"only sparse compressed tensors with the same number of compressed dimensions are supported."):
413+
"expected shapes of self and src to match along dimension"):
414414
b.copy_(c)
415415

416416
if block_size:

0 commit comments

Comments
 (0)