@@ -167,7 +167,8 @@ def _generate_small_inputs(self, layout, device, dtype, index_dtype):
167167 The input is defined as a 4-tuple:
168168 compressed_indices, plain_indices, values, expected_size_from_shape_inference
169169 """
170- batch_shape = (2 , 3 )
170+ from operator import mul
171+ from functools import reduce
171172 if layout in {torch .sparse_csr , torch .sparse_csc }:
172173 yield (torch .tensor ([0 , 2 , 4 ], device = device , dtype = index_dtype ),
173174 torch .tensor ([0 , 1 , 0 , 1 ], device = device , dtype = index_dtype ),
@@ -177,10 +178,12 @@ def _generate_small_inputs(self, layout, device, dtype, index_dtype):
177178 torch .tensor ([], device = device , dtype = index_dtype ),
178179 torch .tensor ([], device = device , dtype = dtype ),
179180 (0 , 0 ))
180- yield (torch .tensor ([0 , 2 , 4 ], device = device , dtype = index_dtype ).repeat (6 , 1 ).reshape (* batch_shape , - 1 ),
181- torch .tensor ([0 , 1 , 0 , 1 ], device = device , dtype = index_dtype ).repeat (6 , 1 ).reshape (* batch_shape , - 1 ),
182- torch .tensor ([1 , 2 , 3 , 4 ], device = device , dtype = dtype ).repeat (6 , 1 ).reshape (* batch_shape , - 1 ),
183- (* batch_shape , 2 , 2 ))
181+ for batch_shape in {(2 , 3 ), (2 ,)}:
182+ prod = reduce (mul , batch_shape , 1 )
183+ yield (torch .tensor ([0 , 2 , 4 ], device = device , dtype = index_dtype ).repeat (prod , 1 ).reshape (* batch_shape , - 1 ),
184+ torch .tensor ([0 , 1 , 0 , 1 ], device = device , dtype = index_dtype ).repeat (prod , 1 ).reshape (* batch_shape , - 1 ),
185+ torch .tensor ([1 , 2 , 3 , 4 ], device = device , dtype = dtype ).repeat (prod , 1 ).reshape (* batch_shape , - 1 ),
186+ (* batch_shape , 2 , 2 ))
184187 else :
185188 assert layout in {torch .sparse_bsr , torch .sparse_bsc }
186189 yield (torch .tensor ([0 , 2 , 4 ], device = device , dtype = index_dtype ),
@@ -191,11 +194,13 @@ def _generate_small_inputs(self, layout, device, dtype, index_dtype):
191194 torch .tensor ([], device = device , dtype = index_dtype ),
192195 torch .tensor ([], device = device , dtype = dtype ).reshape (1 , 0 , 0 ),
193196 (0 , 0 ))
194- yield (torch .tensor ([0 , 2 , 4 ], device = device , dtype = index_dtype ).repeat (6 , 1 ).reshape (* batch_shape , - 1 ),
195- torch .tensor ([0 , 1 , 0 , 1 ], device = device , dtype = index_dtype ).repeat (6 , 1 ).reshape (* batch_shape , - 1 ),
196- torch .tensor ([[[1 , 11 ]], [[2 , 22 ]], [[3 , 33 ]], [[4 , 44 ]]],
197- device = device , dtype = dtype ).repeat (6 , 1 , 1 ).reshape (* batch_shape , 4 , 1 , 2 ),
198- (* batch_shape , 2 , 2 ))
197+ for batch_shape in {(2 , 3 ), (2 ,)}:
198+ prod = reduce (mul , batch_shape , 1 )
199+ yield (torch .tensor ([0 , 2 , 4 ], device = device , dtype = index_dtype ).repeat (prod , 1 ).reshape (* batch_shape , - 1 ),
200+ torch .tensor ([0 , 1 , 0 , 1 ], device = device , dtype = index_dtype ).repeat (prod , 1 ).reshape (* batch_shape , - 1 ),
201+ torch .tensor ([[[1 , 11 ]], [[2 , 22 ]], [[3 , 33 ]], [[4 , 44 ]]],
202+ device = device , dtype = dtype ).repeat (prod , 1 , 1 ).reshape (* batch_shape , 4 , 1 , 2 ),
203+ (* batch_shape , 2 , 2 ))
199204
200205 @all_sparse_compressed_layouts ()
201206 @onlyCPU
@@ -312,6 +317,17 @@ def test_empty_errors(self, layout, device, dtype):
312317 ", but got size" ):
313318 torch .empty ((5 ,), dtype = dtype , device = device , layout = layout )
314319
320+ @skipMeta
321+ @all_sparse_compressed_layouts ()
322+ @dtypes (* all_types_and_complex_and (torch .bool , torch .half , torch .bfloat16 ))
323+ def test_clone (self , layout , device , dtype ):
324+ for compressed_indices , plain_indices , values , size in self ._generate_small_inputs (
325+ layout , device , dtype , index_dtype = torch .int32 ):
326+ sparse = torch .sparse_compressed_tensor (compressed_indices , plain_indices , values , size ,
327+ dtype = dtype , layout = layout , device = device )
328+ cloned_sparse = sparse .clone ()
329+ self .assertEqual (sparse , cloned_sparse )
330+
315331
316332class TestSparseCSR (TestCase ):
317333
@@ -384,20 +400,6 @@ def test_sparse_csr_select(self, device, dtype):
384400 with self .assertRaisesRegex (TypeError , "Cannot assign to a sparse tensor" ):
385401 sparse [0 , 0 , 0 , 0 ] = 99.0
386402
387- @skipMeta
388- @dtypes (* all_types_and_complex_and (torch .bool , torch .half , torch .bfloat16 ))
389- def test_clone (self , device , dtype ):
390- from operator import mul
391- from functools import reduce
392- for batch_shape in ((), (2 ,), (2 , 3 )):
393- prod = reduce (mul , batch_shape , 1 )
394- crow_indices = torch .tensor ([0 , 2 , 4 ], device = device ).repeat (prod , 1 ).reshape (* batch_shape , - 1 )
395- col_indices = torch .tensor ([0 , 1 , 0 , 1 ], device = device ).repeat (prod , 1 ).reshape (* batch_shape , - 1 )
396- values = torch .tensor ([1 , 2 , 3 , 4 ], device = device , dtype = dtype ).repeat (prod , 1 ).reshape (* batch_shape , - 1 )
397- sparse = torch .sparse_csr_tensor (crow_indices , col_indices , values , dtype = dtype , device = device )
398- cloned_sparse = sparse .clone ()
399- self .assertEqual (sparse , cloned_sparse )
400-
401403 @skipMeta
402404 @dtypes (* all_types_and_complex_and (torch .half , torch .bool , torch .bfloat16 ))
403405 def test_copy (self , device , dtype ):
0 commit comments