Commit dc882ed

pearu authored and pytorchmergebot committed
Add Sparse Compressed tensor support to torch.clone
Pull Request resolved: #77512
Approved by: https://github.com/cpuhrsch
1 parent 98a20eb commit dc882ed
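
This change makes Tensor.clone() work for all sparse compressed layouts (CSR, CSC, BSR, BSC) instead of CSR only. A minimal usage sketch of the newly enabled behavior (illustration only, not part of the commit; assumes a PyTorch build that includes it):

import torch

# Cloning a CSC tensor; before this commit, the clone kernel handled only the CSR layout.
ccol_indices = torch.tensor([0, 2, 4])
row_indices = torch.tensor([0, 1, 0, 1])
values = torch.tensor([1., 2., 3., 4.])
x = torch.sparse_csc_tensor(ccol_indices, row_indices, values, size=(2, 2))
y = x.clone()
assert y.layout == torch.sparse_csc
assert torch.equal(x.to_dense(), y.to_dense())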

4 files changed: +51 -53 lines


aten/src/ATen/native/native_functions.yaml

Lines changed: 1 addition & 1 deletion
@@ -5290,7 +5290,7 @@
   dispatch:
     CompositeExplicitAutograd: clone
     SparseCPU, SparseCUDA: clone_sparse
-    SparseCsrCPU, SparseCsrCUDA: clone_sparse_csr
+    SparseCsrCPU, SparseCsrCUDA: clone_sparse_compressed
     MkldnnCPU: mkldnn_clone
     QuantizedCPU, QuantizedCUDA: quantized_clone

aten/src/ATen/native/sparse/SparseCsrTensor.cpp

Lines changed: 20 additions & 12 deletions
@@ -584,23 +584,31 @@ const SparseCsrTensor& resize_as_sparse_csr_(
   return self;
 }
 
-SparseCsrTensor clone_sparse_csr(
-    const SparseCsrTensor& self,
-    c10::optional<c10::MemoryFormat> optional_memory_format) {
+SparseCsrTensor clone_sparse_compressed(
+    const SparseCsrTensor& self,
+    c10::optional<c10::MemoryFormat> optional_memory_format) {
   TORCH_CHECK(
       !optional_memory_format.has_value(),
       "unsupported memory format option ",
       optional_memory_format.value());
   TensorOptions options = self.options();
-  return at::native::_sparse_csr_tensor_unsafe(
-      self.crow_indices().clone(),
-      self.col_indices().clone(),
-      self.values().clone(),
-      self.sizes(),
-      optTypeMetaToScalarType(options.dtype_opt()),
-      options.layout_opt(),
-      options.device_opt(),
-      options.pinned_memory_opt());
+  auto compressed_indices = AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(self.layout(),
+                                                                      "clone_sparse_compressed",
+                                                                      [&]{ return self.crow_indices(); },
+                                                                      [&]{ return self.ccol_indices(); });
+  auto plain_indices = AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(self.layout(),
+                                                                 "clone_sparse_compressed",
+                                                                 [&]{ return self.col_indices(); },
+                                                                 [&]{ return self.row_indices(); });
+  return at::native::_sparse_compressed_tensor_unsafe(
+      compressed_indices.clone(),
+      plain_indices.clone(),
+      self.values().clone(),
+      self.sizes(),
+      optTypeMetaToScalarType(options.dtype_opt()),
+      options.layout_opt(),
+      options.device_opt(),
+      options.pinned_memory_opt());
 }
 
 Tensor empty_like_sparse_csr(
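
The new clone_sparse_compressed selects the index accessors per layout: row-compressed layouts (CSR, BSR) carry crow_indices/col_indices, while column-compressed layouts (CSC, BSC) carry ccol_indices/row_indices, which is what the two AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS calls above choose between. A short Python-side sketch of that distinction using the public tensor methods (illustration only, not part of the commit):

import torch

# Row-compressed layout (CSR): compressed indices are crow_indices, plain indices are col_indices.
csr = torch.sparse_csr_tensor([0, 2, 4], [0, 1, 0, 1], [1., 2., 3., 4.], size=(2, 2))
print(csr.crow_indices(), csr.col_indices())

# Column-compressed layout (CSC): compressed indices are ccol_indices, plain indices are row_indices.
csc = torch.sparse_csc_tensor([0, 2, 4], [0, 1, 0, 1], [1., 2., 3., 4.], size=(2, 2))
print(csc.ccol_indices(), csc.row_indices())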

test/test_sparse_csr.py

Lines changed: 26 additions & 24 deletions
@@ -167,7 +167,8 @@ def _generate_small_inputs(self, layout, device, dtype, index_dtype):
         The input is defined as a 4-tuple:
         compressed_indices, plain_indices, values, expected_size_from_shape_inference
         """
-        batch_shape = (2, 3)
+        from operator import mul
+        from functools import reduce
         if layout in {torch.sparse_csr, torch.sparse_csc}:
             yield (torch.tensor([0, 2, 4], device=device, dtype=index_dtype),
                    torch.tensor([0, 1, 0, 1], device=device, dtype=index_dtype),
@@ -177,10 +178,12 @@ def _generate_small_inputs(self, layout, device, dtype, index_dtype):
                    torch.tensor([], device=device, dtype=index_dtype),
                    torch.tensor([], device=device, dtype=dtype),
                    (0, 0))
-            yield (torch.tensor([0, 2, 4], device=device, dtype=index_dtype).repeat(6, 1).reshape(*batch_shape, -1),
-                   torch.tensor([0, 1, 0, 1], device=device, dtype=index_dtype).repeat(6, 1).reshape(*batch_shape, -1),
-                   torch.tensor([1, 2, 3, 4], device=device, dtype=dtype).repeat(6, 1).reshape(*batch_shape, -1),
-                   (*batch_shape, 2, 2))
+            for batch_shape in {(2, 3), (2,)}:
+                prod = reduce(mul, batch_shape, 1)
+                yield (torch.tensor([0, 2, 4], device=device, dtype=index_dtype).repeat(prod, 1).reshape(*batch_shape, -1),
+                       torch.tensor([0, 1, 0, 1], device=device, dtype=index_dtype).repeat(prod, 1).reshape(*batch_shape, -1),
+                       torch.tensor([1, 2, 3, 4], device=device, dtype=dtype).repeat(prod, 1).reshape(*batch_shape, -1),
+                       (*batch_shape, 2, 2))
         else:
             assert layout in {torch.sparse_bsr, torch.sparse_bsc}
             yield (torch.tensor([0, 2, 4], device=device, dtype=index_dtype),
@@ -191,11 +194,13 @@ def _generate_small_inputs(self, layout, device, dtype, index_dtype):
                    torch.tensor([], device=device, dtype=index_dtype),
                    torch.tensor([], device=device, dtype=dtype).reshape(1, 0, 0),
                    (0, 0))
-            yield (torch.tensor([0, 2, 4], device=device, dtype=index_dtype).repeat(6, 1).reshape(*batch_shape, -1),
-                   torch.tensor([0, 1, 0, 1], device=device, dtype=index_dtype).repeat(6, 1).reshape(*batch_shape, -1),
-                   torch.tensor([[[1, 11]], [[2, 22]], [[3, 33]], [[4, 44]]],
-                                device=device, dtype=dtype).repeat(6, 1, 1).reshape(*batch_shape, 4, 1, 2),
-                   (*batch_shape, 2, 2))
+            for batch_shape in {(2, 3), (2,)}:
+                prod = reduce(mul, batch_shape, 1)
+                yield (torch.tensor([0, 2, 4], device=device, dtype=index_dtype).repeat(prod, 1).reshape(*batch_shape, -1),
+                       torch.tensor([0, 1, 0, 1], device=device, dtype=index_dtype).repeat(prod, 1).reshape(*batch_shape, -1),
+                       torch.tensor([[[1, 11]], [[2, 22]], [[3, 33]], [[4, 44]]],
+                                    device=device, dtype=dtype).repeat(prod, 1, 1).reshape(*batch_shape, 4, 1, 2),
+                       (*batch_shape, 2, 2))
 
     @all_sparse_compressed_layouts()
     @onlyCPU
@@ -312,6 +317,17 @@ def test_empty_errors(self, layout, device, dtype):
                                 ", but got size"):
             torch.empty((5,), dtype=dtype, device=device, layout=layout)
 
+    @skipMeta
+    @all_sparse_compressed_layouts()
+    @dtypes(*all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16))
+    def test_clone(self, layout, device, dtype):
+        for compressed_indices, plain_indices, values, size in self._generate_small_inputs(
+                layout, device, dtype, index_dtype=torch.int32):
+            sparse = torch.sparse_compressed_tensor(compressed_indices, plain_indices, values, size,
+                                                    dtype=dtype, layout=layout, device=device)
+            cloned_sparse = sparse.clone()
+            self.assertEqual(sparse, cloned_sparse)
+
 
 class TestSparseCSR(TestCase):
@@ -384,20 +400,6 @@ def test_sparse_csr_select(self, device, dtype):
         with self.assertRaisesRegex(TypeError, "Cannot assign to a sparse tensor"):
             sparse[0, 0, 0, 0] = 99.0
 
-    @skipMeta
-    @dtypes(*all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16))
-    def test_clone(self, device, dtype):
-        from operator import mul
-        from functools import reduce
-        for batch_shape in ((), (2,), (2, 3)):
-            prod = reduce(mul, batch_shape, 1)
-            crow_indices = torch.tensor([0, 2, 4], device=device).repeat(prod, 1).reshape(*batch_shape, -1)
-            col_indices = torch.tensor([0, 1, 0, 1], device=device).repeat(prod, 1).reshape(*batch_shape, -1)
-            values = torch.tensor([1, 2, 3, 4], device=device, dtype=dtype).repeat(prod, 1).reshape(*batch_shape, -1)
-            sparse = torch.sparse_csr_tensor(crow_indices, col_indices, values, dtype=dtype, device=device)
-            cloned_sparse = sparse.clone()
-            self.assertEqual(sparse, cloned_sparse)
-
     @skipMeta
     @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
     def test_copy(self, device, dtype):
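
The CSR-only test_clone removed above is superseded by the layout-parameterized test_clone added earlier in this file. A standalone sketch of the batched CSR case it covers (illustration only; variable names and values here are chosen for the example, not taken verbatim from the test suite):

import torch

# Build a batched (batch_shape == (2,)) CSR tensor the same way the new
# input generator does, clone it, and compare the components.
batch_shape = (2,)
crow = torch.tensor([0, 2, 4], dtype=torch.int32).repeat(2, 1).reshape(*batch_shape, -1)
col = torch.tensor([0, 1, 0, 1], dtype=torch.int32).repeat(2, 1).reshape(*batch_shape, -1)
vals = torch.tensor([1., 2., 3., 4.]).repeat(2, 1).reshape(*batch_shape, -1)
x = torch.sparse_compressed_tensor(crow, col, vals, (*batch_shape, 2, 2), layout=torch.sparse_csr)
y = x.clone()
assert torch.equal(x.crow_indices(), y.crow_indices())
assert torch.equal(x.col_indices(), y.col_indices())
assert torch.equal(x.values(), y.values())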

test/test_testing.py

Lines changed: 4 additions & 16 deletions
@@ -1091,10 +1091,7 @@ def test_matching(self):
         col_indices = (1, 0)
         values = (1, 2)
         actual = torch.sparse_csr_tensor(crow_indices, col_indices, values, size=(2, 2))
-        # TODO: replace this by actual.clone() after https://github.com/pytorch/pytorch/issues/59285 is fixed
-        expected = torch.sparse_csr_tensor(
-            actual.crow_indices(), actual.col_indices(), actual.values(), size=actual.size(), device=actual.device
-        )
+        expected = actual.clone()
 
         for fn in assert_close_with_inputs(actual, expected):
             fn()
@@ -1152,10 +1149,7 @@ def test_matching(self):
         row_indices = (1, 0)
         values = (1, 2)
         actual = torch.sparse_csc_tensor(ccol_indices, row_indices, values, size=(2, 2))
-        # TODO: replace this by actual.clone() after https://github.com/pytorch/pytorch/issues/59285 is fixed
-        expected = torch.sparse_csc_tensor(
-            actual.ccol_indices(), actual.row_indices(), actual.values(), size=actual.size(), device=actual.device
-        )
+        expected = actual.clone()
 
         for fn in assert_close_with_inputs(actual, expected):
             fn()
@@ -1213,10 +1207,7 @@ def test_matching(self):
         col_indices = (1, 0)
         values = ([[1]], [[2]])
         actual = torch.sparse_bsr_tensor(crow_indices, col_indices, values, size=(2, 2))
-        # TODO: replace this by actual.clone() after https://github.com/pytorch/pytorch/issues/59285 is fixed
-        expected = torch.sparse_bsr_tensor(
-            actual.crow_indices(), actual.col_indices(), actual.values(), size=actual.size(), device=actual.device
-        )
+        expected = actual.clone()
 
         for fn in assert_close_with_inputs(actual, expected):
             fn()
@@ -1274,10 +1265,7 @@ def test_matching(self):
         row_indices = (1, 0)
         values = ([[1]], [[2]])
         actual = torch.sparse_bsc_tensor(ccol_indices, row_indices, values, size=(2, 2))
-        # TODO: replace this by actual.clone() after https://github.com/pytorch/pytorch/issues/59285 is fixed
-        expected = torch.sparse_bsc_tensor(
-            actual.ccol_indices(), actual.row_indices(), actual.values(), size=actual.size(), device=actual.device
-        )
+        expected = actual.clone()
 
         for fn in assert_close_with_inputs(actual, expected):
             fn()
