Commit 090ff9e

Support copy_ for Sparse Compressed tensors.
ghstack-source-id: 5d17a70 Pull Request resolved: #77605
1 parent 38f03f6 commit 090ff9e

5 files changed: +85 −47 lines
aten/src/ATen/native/native_functions.yaml

Lines changed: 1 addition & 1 deletion

@@ -1419,7 +1419,7 @@
     MkldnnCPU: copy_mkldnn_
     SparseCPU, SparseCUDA: copy_sparse_wrapper_
     CompositeExplicitAutograd: copy_
-    SparseCsrCPU, SparseCsrCUDA: copy_sparse_csr_
+    SparseCsrCPU, SparseCsrCUDA: copy_sparse_compressed_
 
 - func: _copy_from(Tensor self, Tensor dst, bool non_blocking=False) -> Tensor
   dispatch:
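With the dispatch entry above, Tensor.copy_ on a sparse CSR tensor is routed to copy_sparse_compressed_. A minimal usage sketch, not part of the diff:

import torch

# Sketch: in-place copy between two CSR tensors with the same sparsity
# structure now goes through copy_sparse_compressed_.
src = torch.tensor([[0., 1.], [2., 0.]]).to_sparse_csr()
dst = torch.tensor([[0., 3.], [4., 0.]]).to_sparse_csr()
dst.copy_(src)            # copies crow_indices, col_indices and values from src
print(dst.values())       # tensor([1., 2.])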

aten/src/ATen/native/sparse/SparseCsrTensor.cpp

Lines changed: 22 additions & 11 deletions

@@ -170,7 +170,7 @@ void _validate_sparse_compressed_tensor_args_worker(const Tensor& compressed_ind
   Tensor compressed_indices_cpu = compressed_indices.to(kCPU);
   auto compressed_indices_data_ptr = compressed_indices_cpu.data_ptr<index_t>();
   auto batch_stride = compressed_indices_cpu.dim() >= 2 ? compressed_indices_cpu.stride(-2) : 0;
-
+  auto compressed_dims = size[compressedDimension(layout, size)];
   for (const auto batch_id : c10::irange(batchCount(compressed_indices_cpu))) {
     TORCH_CHECK(
         compressed_indices_data_ptr[batch_id*batch_stride] == 0,
@@ -180,7 +180,7 @@ void _validate_sparse_compressed_tensor_args_worker(const Tensor& compressed_ind
         compressed_indices_data_ptr[batch_id*batch_stride + compressed_indices.size(-1) - 1] == plain_indices.size(-1),
         "(Batch element ", batch_id, ") ",
         "last value of ", compressed_indices_name, " should be equal to the length of ", plain_indices_name, ".");
-    for (int i = 1; i <= size[size.size() - 2]; i++) {
+    for (int i = 1; i <= compressed_dims; i++) {
       TORCH_CHECK(
           compressed_indices_data_ptr[batch_id*batch_stride + i - 1] <= compressed_indices_data_ptr[batch_id*batch_stride + i],
           "(Batch element ", batch_id, ") ",
@@ -513,18 +513,29 @@ const Tensor& resize_sparse_csr_(
   return self;
 }
 
-Tensor& copy_sparse_csr_(Tensor& self, const Tensor& src, bool non_blocking) {
+Tensor& copy_sparse_compressed_(Tensor& self, const Tensor& src, bool non_blocking) {
+  AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(self.layout(), "copy_sparse_compressed_", [&]{});
   TORCH_CHECK(
-      self.is_sparse_csr() && src.is_sparse_csr(),
-      "copy_sparse_csr_: copy between different layouts is not supported. Found self type = ",
-      self.toString(),
-      " and src type = ",
-      src.toString());
+      self.layout() == src.layout(),
+      "torch.copy_: copy of sparse compressed tensors having different layouts is not supported.",
+      " self layout is ", self.layout(), " and src layout is ", src.layout());
   TORCH_CHECK(
       self._nnz() == src._nnz(),
-      "copy_sparse_csr_: only tensors with the same number of specified elements are supported.");
-  self.crow_indices().copy_(src.crow_indices(), non_blocking);
-  self.col_indices().copy_(src.col_indices(), non_blocking);
+      "torch.copy_: only sparse compressed tensors with the same number of specified elements are supported.");
+  TORCH_CHECK(self.size(compressedDimension(self.layout(), self.sizes())) == src.size(compressedDimension(src.layout(), src.sizes())),
+              "torch.copy_: only sparse compressed tensors with the same number of compressed dimensions are supported.");
+  TORCH_CHECK(self.values().sizes() == src.values().sizes(),
+              "torch.copy_: only sparse compressed tensors with the same values shape are supported.");
+
+  AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(self.layout(), "copy_sparse_compressed_",
+                                            [&]{
+                                              self.crow_indices().copy_(src.crow_indices(), non_blocking);
+                                              self.col_indices().copy_(src.col_indices(), non_blocking);
+                                            },
+                                            [&]{
+                                              self.ccol_indices().copy_(src.ccol_indices(), non_blocking);
+                                              self.row_indices().copy_(src.row_indices(), non_blocking);
+                                            });
   self.values().copy_(src.values(), non_blocking);
   return self;
 }
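For reference, here is how the new checks surface from Python. This is an illustrative sketch mirroring the error cases exercised by test_copy_errors below, not code from the diff:

import torch

a = torch.eye(3).to_sparse_csr()           # 3 specified elements
b = (2 * torch.eye(3)).to_sparse_csr()     # same layout, same nnz, same values shape
a.copy_(b)                                 # passes all checks; indices and values are copied
assert torch.equal(a.values(), b.values())

try:
    a.copy_(torch.zeros(3, 3))             # strided src: layouts differ
except RuntimeError as err:
    print(err)  # "torch.copy_: copy of sparse compressed tensors having different layouts is not supported. ..."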

aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp

Lines changed: 2 additions & 2 deletions

@@ -148,9 +148,9 @@ Tensor& unary_op_out(F op_out, const Tensor& self, Tensor& result) {
     if (result.numel() == 0) {
       at::native::resize_as_sparse_csr_(result, self);
     }
-    // copy_sparse_csr_ internally checks the sizes of result and self tensors
+    // copy_sparse_compressed_ internally checks the sizes of result and self tensors
     // Hence no external size check required
-    at::native::copy_sparse_csr_(result, self);
+    at::native::copy_sparse_compressed_(result, self);
   }
 
   auto self_values = self.values();
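unary_op_out appears to be the helper behind out= variants of elementwise unary ops on CSR tensors; after resizing the result it hands the data movement to the copy kernel renamed above. A small sketch of a unary op on a CSR input (whether a given op is wired through this exact helper is an assumption, as is abs support for sparse CSR in this build):

import torch

csr = torch.tensor([[0., -1.], [2., 0.]]).to_sparse_csr()
res = torch.abs(csr)                 # unary op keeps the CSR layout and sparsity pattern
print(res.layout, res.values())      # torch.sparse_csr tensor([1., 2.])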

test/test_sparse_csr.py

Lines changed: 58 additions & 32 deletions

@@ -363,6 +363,64 @@ def test_print(self, layout, device):
             self.maxDiff = orig_maxDiff
             raise
 
+    @skipMeta
+    @all_sparse_compressed_layouts()
+    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
+    def test_copy(self, layout, device, dtype):
+
+        def run_test(shape, nnz, index_type):
+            block_size = (2, 3) if layout in {torch.sparse_bsr, torch.sparse_bsc} else ()
+            a = self.genSparseCompressedTensor(shape, nnz, dtype=dtype, layout=layout, device=device,
+                                               index_dtype=index_dtype, block_size=block_size)
+            b = self.genSparseCompressedTensor(shape, nnz, dtype=dtype, layout=layout, device=device,
+                                               index_dtype=index_dtype, block_size=block_size)
+
+            a.copy_(b)
+
+            self.assertEqual(a, b)
+
+        ns = [5, 2, 0]
+        batch_shapes = [(), (2,), (2, 3)]
+        for (m, n, b), index_dtype in zip(itertools.product(ns, ns, batch_shapes), [torch.int32, torch.int64]):
+            run_test((*b, m, n), 0, index_dtype)
+            run_test((*b, m, n), m * n, index_dtype)
+
+    @skipMeta
+    @all_sparse_compressed_layouts()
+    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
+    def test_copy_errors(self, layout, device, dtype):
+        block_size = (2, 3) if layout in {torch.sparse_bsr, torch.sparse_bsc} else ()
+        for index_dtype in [torch.int32, torch.int64]:
+            shape1 = (2, 3)
+            a = self.genSparseCompressedTensor(shape1, 0, dtype=dtype, layout=layout, device=device,
+                                               index_dtype=index_dtype, block_size=block_size)
+
+            with self.assertRaisesRegex(RuntimeError,
+                                        "copy of sparse compressed tensors having different layouts is not supported."):
+                a.copy_(torch.empty(a.shape, dtype=dtype, device=device))
+
+            b = self.genSparseCompressedTensor(shape1, 1, dtype=dtype, layout=layout, device=device,
+                                               index_dtype=index_dtype, block_size=block_size)
+            with self.assertRaisesRegex(RuntimeError,
+                                        "only sparse compressed tensors with the same number of specified elements are supported."):
+                a.copy_(b)
+
+            shape2 = tuple(reversed(shape1))
+            c = self.genSparseCompressedTensor(shape2, 1, dtype=dtype, layout=layout, device=device,
+                                               index_dtype=index_dtype, block_size=block_size)
+            with self.assertRaisesRegex(
+                    RuntimeError,
+                    "only sparse compressed tensors with the same number of compressed dimensions are supported."):
+                b.copy_(c)
+
+            if block_size:
+                block_size1 = tuple(reversed(block_size))
+                d = self.genSparseCompressedTensor(shape1, 1, dtype=dtype, layout=layout, device=device,
+                                                   index_dtype=index_dtype, block_size=block_size1)
+                with self.assertRaisesRegex(RuntimeError,
+                                            "only sparse compressed tensors with the same values shape are supported."):
+                    b.copy_(d)
+
 
 class TestSparseCSR(TestCase):
 
@@ -435,38 +493,6 @@ def test_sparse_csr_select(self, device, dtype):
         with self.assertRaisesRegex(TypeError, "Cannot assign to a sparse tensor"):
             sparse[0, 0, 0, 0] = 99.0
 
-    @skipMeta
-    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
-    def test_copy(self, device, dtype):
-
-        def run_test(shape, nnz, index_type):
-            a = self.genSparseCSRTensor(shape, nnz, dtype=dtype, device=device, index_dtype=index_dtype)
-            b = self.genSparseCSRTensor(shape, nnz, dtype=dtype, device=device, index_dtype=index_dtype)
-
-            a.copy_(b)
-
-            self.assertEqual(a, b)
-
-        ns = [5, 2, 0]
-        batch_shapes = [(), (2,), (2, 3)]
-        for (m, n, b), index_dtype in zip(itertools.product(ns, ns, batch_shapes), [torch.int32, torch.int64]):
-            run_test((*b, m, n), 0, index_dtype)
-            run_test((*b, m, n), m * n, index_dtype)
-
-    @skipMeta
-    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
-    def test_copy_errors(self, device, dtype):
-        for index_dtype in [torch.int32, torch.int64]:
-            shape1 = (2, 3)
-            a = self.genSparseCSRTensor(shape1, 0, dtype=dtype, device=device, index_dtype=index_dtype)
-
-            with self.assertRaisesRegex(RuntimeError, "copy between different layouts is not supported."):
-                a.copy_(torch.empty(a.shape, dtype=dtype, device=device))
-
-            b = self.genSparseCSRTensor(shape1, 1, dtype=dtype, device=device, index_dtype=index_dtype)
-            with self.assertRaisesRegex(RuntimeError, "only tensors with the same number of specified elements are supported."):
-                a.copy_(b)
-
     @skipMeta
     @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
     def test_resize(self, device, dtype):

torch/testing/_internal/common_utils.py

Lines changed: 2 additions & 1 deletion

@@ -2059,7 +2059,8 @@ def random_sparse_compressed(n_compressed_dims, n_plain_dims, nnz):
         n_compressed_dims, n_plain_dims = size[-1], size[-2]
         sparse_tensors = [random_sparse_compressed(n_compressed_dims, n_plain_dims, nnz) for _ in range(n_batch)]
         sparse_tensors_it = map(list, zip(*sparse_tensors))
-        values = torch.stack(next(sparse_tensors_it)).reshape(*batch_shape, -1)
+
+        values = torch.stack(next(sparse_tensors_it)).reshape(*batch_shape, nnz, *block_size)
         compressed_indices = torch.stack(next(sparse_tensors_it)).reshape(*batch_shape, -1)
         plain_indices = torch.stack(next(sparse_tensors_it)).reshape(*batch_shape, -1)
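The reshape change above makes the generated values carry an explicit block shape, matching the (nnz, *block_size) values layout of blocked compressed tensors (for plain CSR/CSC, block_size is empty and the shape reduces to (nnz,)). A small sketch of that layout for BSR; torch.sparse_bsr_tensor availability is assumed here:

import torch

crow_indices = torch.tensor([0, 1, 2])     # one block per block-row
col_indices = torch.tensor([1, 0])
values = torch.arange(12, dtype=torch.float32).reshape(2, 2, 3)   # (nnz, *block_size)
bsr = torch.sparse_bsr_tensor(crow_indices, col_indices, values, size=(4, 6))
print(bsr.values().shape)   # torch.Size([2, 2, 3])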
