@@ -522,17 +522,30 @@ Tensor& copy_sparse_compressed_(Tensor& self, const Tensor& src, bool non_blocki
522522 TORCH_CHECK (
523523 self._nnz () == src._nnz (), // actually, values copy allows different shapes as long as operands are broadcastable
524524 " torch.copy_: only sparse compressed tensors with the same number of specified elements are supported." );
525- auto self_compressed_dims = self.size (compressedDimension (self.layout (), self.sizes ()));
525+ auto self_compressed_dim = compressedDimension (self.layout (), self.sizes ());
526+ auto src_compressed_dim = compressedDimension (src.layout (), src.sizes ());
527+ auto self_compressed_dims = self.size (self_compressed_dim);
526528 auto src_compressed_dims = src.size (compressedDimension (src.layout (), src.sizes ()));
527- TORCH_CHECK (self_compressed_dims == src_compressed_dims,
528- " torch.copy_: only sparse compressed tensors with the same number of compressed dimensions are supported." ,
529- " self and src compressed dimensions are " ,
530- self_compressed_dims, " and " , src_compressed_dims, " , respectively." );
529+ if (self_compressed_dim == src_compressed_dim) {
530+ TORCH_CHECK (self_compressed_dims == src_compressed_dims,
531+ " torch.copy_: expected shapes of self and src to match along dimension " ,
532+ self_compressed_dim, " for " ,
533+ self.layout (), " layout but the corresponding dimensions of self and src are " ,
534+ self_compressed_dims, " and " , src_compressed_dims, " , respectively." );
535+ } else {
536+ TORCH_CHECK (self_compressed_dims == src_compressed_dims,
537+ " torch.copy_: expected shapes of self and src to match along dimensions " ,
538+ self_compressed_dim, " and " , src_compressed_dim, " , respectively, for " ,
539+ self.layout (), " layout but the corresponding dimensions of self and src are " ,
540+ self_compressed_dims, " and " , src_compressed_dims, " , respectively." );
541+ }
531542 AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS (self.layout (), " copy_sparse_compressed_" ,
532543 [&]{},
533544 [&]{
534- auto self_block_size = DimVector (self.values ().sizes ().slice (self.values ().dim ()-2 , 2 ));
535- auto src_block_size = DimVector (src.values ().sizes ().slice (src.values ().dim ()-2 , 2 ));
545+ auto self_values = self.values ();
546+ auto src_values = src.values ();
547+ auto self_block_size = DimVector (self_values.sizes ().slice (self_values.dim ()-2 , 2 ));
548+ auto src_block_size = DimVector (src_values.sizes ().slice (src_values.dim ()-2 , 2 ));
536549 TORCH_CHECK (self_block_size == src_block_size,
537550 " torch.copy_: copy of sparse compressed tensors having different block sizes is not supported." ,
538551 " self and src block sizes are " , self_block_size, " and " , src_block_size, " , respectively." );
0 commit comments