@@ -363,6 +363,64 @@ def test_print(self, layout, device):
363363 self .maxDiff = orig_maxDiff
364364 raise
365365
366+ @skipMeta
367+ @all_sparse_compressed_layouts ()
368+ @dtypes (* all_types_and_complex_and (torch .half , torch .bool , torch .bfloat16 ))
369+ def test_copy (self , layout , device , dtype ):
370+
371+ def run_test (shape , nnz , index_type ):
372+ block_size = (2 , 3 ) if layout in {torch .sparse_bsr , torch .sparse_bsc } else ()
373+ a = self .genSparseCompressedTensor (shape , nnz , dtype = dtype , layout = layout , device = device ,
374+ index_dtype = index_dtype , block_size = block_size )
375+ b = self .genSparseCompressedTensor (shape , nnz , dtype = dtype , layout = layout , device = device ,
376+ index_dtype = index_dtype , block_size = block_size )
377+
378+ a .copy_ (b )
379+
380+ self .assertEqual (a , b )
381+
382+ ns = [5 , 2 , 0 ]
383+ batch_shapes = [(), (2 ,), (2 , 3 )]
384+ for (m , n , b ), index_dtype in zip (itertools .product (ns , ns , batch_shapes ), [torch .int32 , torch .int64 ]):
385+ run_test ((* b , m , n ), 0 , index_dtype )
386+ run_test ((* b , m , n ), m * n , index_dtype )
387+
388+ @skipMeta
389+ @all_sparse_compressed_layouts ()
390+ @dtypes (* all_types_and_complex_and (torch .half , torch .bool , torch .bfloat16 ))
391+ def test_copy_errors (self , layout , device , dtype ):
392+ block_size = (2 , 3 ) if layout in {torch .sparse_bsr , torch .sparse_bsc } else ()
393+ for index_dtype in [torch .int32 , torch .int64 ]:
394+ shape1 = (2 , 3 )
395+ a = self .genSparseCompressedTensor (shape1 , 0 , dtype = dtype , layout = layout , device = device ,
396+ index_dtype = index_dtype , block_size = block_size )
397+
398+ with self .assertRaisesRegex (RuntimeError ,
399+ "copy of sparse compressed tensors having different layouts is not supported." ):
400+ a .copy_ (torch .empty (a .shape , dtype = dtype , device = device ))
401+
402+ b = self .genSparseCompressedTensor (shape1 , 1 , dtype = dtype , layout = layout , device = device ,
403+ index_dtype = index_dtype , block_size = block_size )
404+ with self .assertRaisesRegex (RuntimeError ,
405+ "only sparse compressed tensors with the same number of specified elements are supported." ):
406+ a .copy_ (b )
407+
408+ shape2 = tuple (reversed (shape1 ))
409+ c = self .genSparseCompressedTensor (shape2 , 1 , dtype = dtype , layout = layout , device = device ,
410+ index_dtype = index_dtype , block_size = block_size )
411+ with self .assertRaisesRegex (
412+ RuntimeError ,
413+ "only sparse compressed tensors with the same number of compressed dimensions are supported." ):
414+ b .copy_ (c )
415+
416+ if block_size :
417+ block_size1 = tuple (reversed (block_size ))
418+ d = self .genSparseCompressedTensor (shape1 , 1 , dtype = dtype , layout = layout , device = device ,
419+ index_dtype = index_dtype , block_size = block_size1 )
420+ with self .assertRaisesRegex (RuntimeError ,
421+ "only sparse compressed tensors with the same values shape are supported." ):
422+ b .copy_ (d )
423+
366424
367425class TestSparseCSR (TestCase ):
368426
@@ -435,38 +493,6 @@ def test_sparse_csr_select(self, device, dtype):
435493 with self .assertRaisesRegex (TypeError , "Cannot assign to a sparse tensor" ):
436494 sparse [0 , 0 , 0 , 0 ] = 99.0
437495
438- @skipMeta
439- @dtypes (* all_types_and_complex_and (torch .half , torch .bool , torch .bfloat16 ))
440- def test_copy (self , device , dtype ):
441-
442- def run_test (shape , nnz , index_type ):
443- a = self .genSparseCSRTensor (shape , nnz , dtype = dtype , device = device , index_dtype = index_dtype )
444- b = self .genSparseCSRTensor (shape , nnz , dtype = dtype , device = device , index_dtype = index_dtype )
445-
446- a .copy_ (b )
447-
448- self .assertEqual (a , b )
449-
450- ns = [5 , 2 , 0 ]
451- batch_shapes = [(), (2 ,), (2 , 3 )]
452- for (m , n , b ), index_dtype in zip (itertools .product (ns , ns , batch_shapes ), [torch .int32 , torch .int64 ]):
453- run_test ((* b , m , n ), 0 , index_dtype )
454- run_test ((* b , m , n ), m * n , index_dtype )
455-
456- @skipMeta
457- @dtypes (* all_types_and_complex_and (torch .half , torch .bool , torch .bfloat16 ))
458- def test_copy_errors (self , device , dtype ):
459- for index_dtype in [torch .int32 , torch .int64 ]:
460- shape1 = (2 , 3 )
461- a = self .genSparseCSRTensor (shape1 , 0 , dtype = dtype , device = device , index_dtype = index_dtype )
462-
463- with self .assertRaisesRegex (RuntimeError , "copy between different layouts is not supported." ):
464- a .copy_ (torch .empty (a .shape , dtype = dtype , device = device ))
465-
466- b = self .genSparseCSRTensor (shape1 , 1 , dtype = dtype , device = device , index_dtype = index_dtype )
467- with self .assertRaisesRegex (RuntimeError , "only tensors with the same number of specified elements are supported." ):
468- a .copy_ (b )
469-
470496 @skipMeta
471497 @dtypes (* all_types_and_complex_and (torch .half , torch .bool , torch .bfloat16 ))
472498 def test_resize (self , device , dtype ):
0 commit comments