
Commit bd7890b

pearu authored and facebook-github-bot committed
Support copy_ for Sparse Compressed tensors. (#77605)
Summary: Pull Request resolved: #77605
Approved by: https://github.com/cpuhrsch
Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/8b5f11c61eecd58214c631056a634f2eedc6455a
Reviewed By: seemethere
Differential Revision: D36494385
Pulled By: seemethere
fbshipit-source-id: 103bffbddfecce3aaa728f06c8f5c2d16f0a0667
1 parent c1c1e32 commit bd7890b
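
For orientation, a minimal sketch of what the change enables at the Python level, mirroring the new test_copy test further below; the tensors are illustrative, not taken from the PR, and a build containing this commit is assumed:

import torch

# Two CSR tensors with the same sparsity pattern, shape, and number of specified elements.
crow_indices = torch.tensor([0, 2, 4])
col_indices = torch.tensor([0, 1, 0, 1])
src = torch.sparse_csr_tensor(crow_indices, col_indices, torch.tensor([1., 2., 3., 4.]), size=(2, 2))
dst = torch.sparse_csr_tensor(crow_indices, col_indices, torch.zeros(4), size=(2, 2))

# copy_ now dispatches to copy_sparse_compressed_ for sparse compressed layouts,
# copying the index tensors and the values of src into dst in place.
dst.copy_(src)
assert torch.equal(dst.to_dense(), src.to_dense())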

5 files changed: 109 additions & 48 deletions


aten/src/ATen/native/native_functions.yaml

Lines changed: 1 addition & 1 deletion
@@ -1419,7 +1419,7 @@
     MkldnnCPU: copy_mkldnn_
     SparseCPU, SparseCUDA: copy_sparse_wrapper_
     CompositeExplicitAutograd: copy_
-    SparseCsrCPU, SparseCsrCUDA: copy_sparse_csr_
+    SparseCsrCPU, SparseCsrCUDA: copy_sparse_compressed_

 - func: _copy_from(Tensor self, Tensor dst, bool non_blocking=False) -> Tensor
   dispatch:
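
The SparseCsrCPU/SparseCsrCUDA dispatch keys cover all four compressed layouts (CSR, CSC, BSR, BSC), so the renamed kernel also serves blocked tensors. A hedged sketch, assuming the torch.sparse_bsr_tensor factory is available in your build:

import torch

# A 4x6 BSR tensor with 2x3 blocks: values has shape (nnz, block_rows, block_cols).
crow_indices = torch.tensor([0, 1, 2])
col_indices = torch.tensor([0, 1])
src = torch.sparse_bsr_tensor(crow_indices, col_indices, torch.ones(2, 2, 3), size=(4, 6))
dst = torch.sparse_bsr_tensor(crow_indices, col_indices, torch.zeros(2, 2, 3), size=(4, 6))

dst.copy_(src)  # routed through the same copy_sparse_compressed_ kernel
assert torch.equal(dst.to_dense(), src.to_dense())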

aten/src/ATen/native/sparse/SparseCsrTensor.cpp

Lines changed: 46 additions & 12 deletions
@@ -170,7 +170,7 @@ void _validate_sparse_compressed_tensor_args_worker(const Tensor& compressed_ind
     Tensor compressed_indices_cpu = compressed_indices.to(kCPU);
     auto compressed_indices_data_ptr = compressed_indices_cpu.data_ptr<index_t>();
     auto batch_stride = compressed_indices_cpu.dim() >= 2 ? compressed_indices_cpu.stride(-2) : 0;
-
+    auto compressed_dims = size[compressedDimension(layout, size)];
     for (const auto batch_id : c10::irange(batchCount(compressed_indices_cpu))) {
       TORCH_CHECK(
           compressed_indices_data_ptr[batch_id*batch_stride] == 0,

@@ -180,7 +180,7 @@ void _validate_sparse_compressed_tensor_args_worker(const Tensor& compressed_ind
           compressed_indices_data_ptr[batch_id*batch_stride + compressed_indices.size(-1) - 1] == plain_indices.size(-1),
           "(Batch element ", batch_id, ") ",
           "last value of ", compressed_indices_name, " should be equal to the length of ", plain_indices_name, ".");
-      for (int i = 1; i <= size[size.size() - 2]; i++) {
+      for (int i = 1; i <= compressed_dims; i++) {
         TORCH_CHECK(
             compressed_indices_data_ptr[batch_id*batch_stride + i - 1] <= compressed_indices_data_ptr[batch_id*batch_stride + i],
             "(Batch element ", batch_id, ") ",

@@ -513,18 +513,52 @@ const Tensor& resize_sparse_csr_(
   return self;
 }

-Tensor& copy_sparse_csr_(Tensor& self, const Tensor& src, bool non_blocking) {
+Tensor& copy_sparse_compressed_(Tensor& self, const Tensor& src, bool non_blocking) {
+  AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(self.layout(), "copy_sparse_compressed_", [&]{});
   TORCH_CHECK(
-      self.is_sparse_csr() && src.is_sparse_csr(),
-      "copy_sparse_csr_: copy between different layouts is not supported. Found self type = ",
-      self.toString(),
-      " and src type = ",
-      src.toString());
+      self.layout() == src.layout(),
+      "torch.copy_: copy of sparse compressed tensors having different layouts is not supported.",
+      " self layout is ", self.layout(), " and src layout is ", src.layout());
   TORCH_CHECK(
-      self._nnz() == src._nnz(),
-      "copy_sparse_csr_: only tensors with the same number of specified elements are supported.");
-  self.crow_indices().copy_(src.crow_indices(), non_blocking);
-  self.col_indices().copy_(src.col_indices(), non_blocking);
+      self._nnz() == src._nnz(), // actually, values copy allows different shapes as long as operands are broadcastable
+      "torch.copy_: only sparse compressed tensors with the same number of specified elements are supported.");
+  auto self_compressed_dim = compressedDimension(self.layout(), self.sizes());
+  auto src_compressed_dim = compressedDimension(src.layout(), src.sizes());
+  auto self_compressed_dims = self.size(self_compressed_dim);
+  auto src_compressed_dims = src.size(compressedDimension(src.layout(), src.sizes()));
+  if (self_compressed_dim == src_compressed_dim) {
+    TORCH_CHECK(self_compressed_dims == src_compressed_dims,
+                "torch.copy_: expected shapes of self and src to match along dimension ",
+                self_compressed_dim, " for ",
+                self.layout(), " layout but the corresponding dimensions of self and src are ",
+                self_compressed_dims, " and ", src_compressed_dims, ", respecitvely.");
+  } else {
+    TORCH_CHECK(self_compressed_dims == src_compressed_dims,
+                "torch.copy_: expected shapes of self and src to match along dimensions ",
+                self_compressed_dim, " and ", src_compressed_dim, ", respectively, for ",
+                self.layout(), " layout but the corresponding dimensions of self and src are ",
+                self_compressed_dims, " and ", src_compressed_dims, ", respecitvely.");
+  }
+  AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS(self.layout(), "copy_sparse_compressed_",
+      [&]{},
+      [&]{
+        auto self_values = self.values();
+        auto src_values = src.values();
+        auto self_block_size = DimVector(self_values.sizes().slice(self_values.dim()-2, 2));
+        auto src_block_size = DimVector(src_values.sizes().slice(src_values.dim()-2, 2));
+        TORCH_CHECK(self_block_size == src_block_size,
+                    "torch.copy_: copy of sparse compressed tensors having different block sizes is not supported.",
+                    " self and src block sizes are ", self_block_size, " and ", src_block_size, ", respectivly.");
+      });
+  AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(self.layout(), "copy_sparse_compressed_",
+      [&]{
+        self.crow_indices().copy_(src.crow_indices(), non_blocking);
+        self.col_indices().copy_(src.col_indices(), non_blocking);
+      },
+      [&]{
+        self.ccol_indices().copy_(src.ccol_indices(), non_blocking);
+        self.row_indices().copy_(src.row_indices(), non_blocking);
+      });
   self.values().copy_(src.values(), non_blocking);
   return self;
 }
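
The new checks surface as RuntimeError on the Python side; the shape check, for example, fires when the compressed dimensions of self and src disagree, which test_copy_errors below also exercises. A small illustrative sketch (the tensors are not taken from the PR):

import torch

values = torch.tensor([1., 2.])
# A 2x3 CSR tensor with 2 specified elements.
a = torch.sparse_csr_tensor(torch.tensor([0, 1, 2]), torch.tensor([0, 2]), values, size=(2, 3))
# A 3x2 CSR tensor, also with 2 specified elements, so only the shape check fires.
b = torch.sparse_csr_tensor(torch.tensor([0, 1, 1, 2]), torch.tensor([0, 1]), values.clone(), size=(3, 2))

try:
    a.copy_(b)
except RuntimeError as e:
    print(e)  # "... expected shapes of self and src to match along dimension ..."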

aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp

Lines changed: 2 additions & 2 deletions
@@ -148,9 +148,9 @@ Tensor& unary_op_out(F op_out, const Tensor& self, Tensor& result) {
     if (result.numel() == 0) {
       at::native::resize_as_sparse_csr_(result, self);
     }
-    // copy_sparse_csr_ internally checks the sizes of result and self tensors
+    // copy_sparse_compressed_ internally checks the sizes of result and self tensors
     // Hence no external size check required
-    at::native::copy_sparse_csr_(result, self);
+    at::native::copy_sparse_compressed_(result, self);
   }

   auto self_values = self.values();
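
The unary_op_out helper appears to back the out= variants of elementwise ops on CSR tensors, so those now funnel through the renamed copy as well. A hedged sketch, assuming torch.abs registers such an out= overload for CSR inputs in this build:

import torch

a = torch.tensor([[-1., 0.], [0., 2.]]).to_sparse_csr()
out = a.clone()        # same layout, shape, and nnz as a

# Per the snippet above, the helper resizes out if it is empty and then
# copies a into out via copy_sparse_compressed_ before applying the op.
torch.abs(a, out=out)
print(out.values())    # tensor([1., 2.])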

test/test_sparse_csr.py

Lines changed: 58 additions & 32 deletions
@@ -363,6 +363,64 @@ def test_print(self, layout, device):
             self.maxDiff = orig_maxDiff
             raise

+    @skipMeta
+    @all_sparse_compressed_layouts()
+    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
+    def test_copy(self, layout, device, dtype):
+
+        def run_test(shape, nnz, index_type):
+            block_size = (2, 3) if layout in {torch.sparse_bsr, torch.sparse_bsc} else ()
+            a = self.genSparseCompressedTensor(shape, nnz, dtype=dtype, layout=layout, device=device,
+                                               index_dtype=index_dtype, block_size=block_size)
+            b = self.genSparseCompressedTensor(shape, nnz, dtype=dtype, layout=layout, device=device,
+                                               index_dtype=index_dtype, block_size=block_size)
+
+            a.copy_(b)
+
+            self.assertEqual(a, b)
+
+        ns = [5, 2, 0]
+        batch_shapes = [(), (2,), (2, 3)]
+        for (m, n, b), index_dtype in zip(itertools.product(ns, ns, batch_shapes), [torch.int32, torch.int64]):
+            run_test((*b, m, n), 0, index_dtype)
+            run_test((*b, m, n), m * n, index_dtype)
+
+    @skipMeta
+    @all_sparse_compressed_layouts()
+    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
+    def test_copy_errors(self, layout, device, dtype):
+        block_size = (2, 3) if layout in {torch.sparse_bsr, torch.sparse_bsc} else ()
+        for index_dtype in [torch.int32, torch.int64]:
+            shape1 = (2, 3)
+            a = self.genSparseCompressedTensor(shape1, 0, dtype=dtype, layout=layout, device=device,
+                                               index_dtype=index_dtype, block_size=block_size)
+
+            with self.assertRaisesRegex(RuntimeError,
+                                        "copy of sparse compressed tensors having different layouts is not supported."):
+                a.copy_(torch.empty(a.shape, dtype=dtype, device=device))
+
+            b = self.genSparseCompressedTensor(shape1, 1, dtype=dtype, layout=layout, device=device,
+                                               index_dtype=index_dtype, block_size=block_size)
+            with self.assertRaisesRegex(RuntimeError,
+                                        "only sparse compressed tensors with the same number of specified elements are supported."):
+                a.copy_(b)
+
+            shape2 = tuple(reversed(shape1))
+            c = self.genSparseCompressedTensor(shape2, 1, dtype=dtype, layout=layout, device=device,
+                                               index_dtype=index_dtype, block_size=block_size)
+            with self.assertRaisesRegex(
+                    RuntimeError,
+                    "expected shapes of self and src to match along dimension"):
+                b.copy_(c)
+
+            if block_size:
+                block_size1 = tuple(reversed(block_size))
+                d = self.genSparseCompressedTensor(shape1, 1, dtype=dtype, layout=layout, device=device,
+                                                   index_dtype=index_dtype, block_size=block_size1)
+                with self.assertRaisesRegex(RuntimeError,
+                                            "copy of sparse compressed tensors having different block sizes is not supported"):
+                    b.copy_(d)
+

 class TestSparseCSR(TestCase):


@@ -435,38 +493,6 @@ def test_sparse_csr_select(self, device, dtype):
         with self.assertRaisesRegex(TypeError, "Cannot assign to a sparse tensor"):
             sparse[0, 0, 0, 0] = 99.0

-    @skipMeta
-    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
-    def test_copy(self, device, dtype):
-
-        def run_test(shape, nnz, index_type):
-            a = self.genSparseCSRTensor(shape, nnz, dtype=dtype, device=device, index_dtype=index_dtype)
-            b = self.genSparseCSRTensor(shape, nnz, dtype=dtype, device=device, index_dtype=index_dtype)
-
-            a.copy_(b)
-
-            self.assertEqual(a, b)
-
-        ns = [5, 2, 0]
-        batch_shapes = [(), (2,), (2, 3)]
-        for (m, n, b), index_dtype in zip(itertools.product(ns, ns, batch_shapes), [torch.int32, torch.int64]):
-            run_test((*b, m, n), 0, index_dtype)
-            run_test((*b, m, n), m * n, index_dtype)
-
-    @skipMeta
-    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
-    def test_copy_errors(self, device, dtype):
-        for index_dtype in [torch.int32, torch.int64]:
-            shape1 = (2, 3)
-            a = self.genSparseCSRTensor(shape1, 0, dtype=dtype, device=device, index_dtype=index_dtype)
-
-            with self.assertRaisesRegex(RuntimeError, "copy between different layouts is not supported."):
-                a.copy_(torch.empty(a.shape, dtype=dtype, device=device))
-
-            b = self.genSparseCSRTensor(shape1, 1, dtype=dtype, device=device, index_dtype=index_dtype)
-            with self.assertRaisesRegex(RuntimeError, "only tensors with the same number of specified elements are supported."):
-                a.copy_(b)
-
     @skipMeta
     @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
     def test_resize(self, device, dtype):
torch/testing/_internal/common_utils.py

Lines changed: 2 additions & 1 deletion
@@ -2061,7 +2061,8 @@ def random_sparse_compressed(n_compressed_dims, n_plain_dims, nnz):
         n_compressed_dims, n_plain_dims = size[-1], size[-2]
         sparse_tensors = [random_sparse_compressed(n_compressed_dims, n_plain_dims, nnz) for _ in range(n_batch)]
         sparse_tensors_it = map(list, zip(*sparse_tensors))
-        values = torch.stack(next(sparse_tensors_it)).reshape(*batch_shape, -1)
+
+        values = torch.stack(next(sparse_tensors_it)).reshape(*batch_shape, nnz, *block_size)
         compressed_indices = torch.stack(next(sparse_tensors_it)).reshape(*batch_shape, -1)
         plain_indices = torch.stack(next(sparse_tensors_it)).reshape(*batch_shape, -1)
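
The fixed reshape keeps the trailing block dimensions of the stacked per-batch values instead of flattening them into a single axis. A small sketch of the resulting shape, with hypothetical sizes standing in for the helper's batch_shape, nnz, and block_size locals:

import torch

batch_shape, nnz, block_size = (2, 3), 4, (2, 3)
n_batch = 2 * 3

# One values tensor per batch element, holding nnz blocks of shape block_size.
per_batch_values = [torch.randn(nnz, *block_size) for _ in range(n_batch)]

values = torch.stack(per_batch_values).reshape(*batch_shape, nnz, *block_size)
print(values.shape)  # torch.Size([2, 3, 4, 2, 3])

For non-blocked layouts block_size is empty, so the reshape reduces to (*batch_shape, nnz).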
