@@ -2070,24 +2070,61 @@ def test_embedding_dense_grad(self):
     def test_embedding_dense_grad_cuda(self):
         self._test_embedding_dense_grad("cuda")
 
+    def test_move_sparse_half_embedding(self):
+        embedding = nn.Embedding(10, 3, sparse=True)
+        self.assertEqual(embedding.weight.device.type, 'cpu')
+        self.assertEqual(embedding.weight.dtype, torch.float64)
+        embedding.to(torch.float16)
+        self.assertEqual(embedding.weight.dtype, torch.float16)
+        self.assertEqual(embedding.embedding_dim, 3)
+        self.assertEqual(embedding.num_embeddings, 10)
+
+        if torch.cuda.is_available():
+            embedding.to('cuda')
+            self.assertEqual(embedding.weight.device.type, 'cuda')
+            embedding.to('cpu')
+            self.assertEqual(embedding.weight.device.type, 'cpu')
+
     def test_embedding_sparse_backward(self):
+        self._test_embedding_backward()
+
+    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
+    def test_embedding_sparse_half_backward(self):
+        # same as test_embedding_sparse_backward above but testing half types in
+        # cuda. cpu sum not supported for half types.
+        self._test_embedding_backward('cuda', torch.float16)
+
+    def _test_embedding_backward(self, device='cpu', dtype=torch.float64):
         embedding = nn.Embedding(10, 3, sparse=True)
+        tensor = torch.tensor([[7, 1, 3]])
+        ones = torch.tensor(1.).expand(3, 3)
+        tensorTwice = tensor.repeat(1, 2)
+        onesTwice = torch.cat((ones, ones))
+
+        embedding = embedding.to(dtype=dtype).to(device)
+        tensor = tensor.to(device)
+        ones = ones.to(device)
+        tensorTwice = tensorTwice.to(device)
+        onesTwice = onesTwice.to(device)
+
         embedding.zero_grad()
-        embedding(torch.LongTensor([7, 1, 3])).sum().backward()
-        self.assertEqual(embedding.weight.grad._indices(), torch.LongTensor([[7, 1, 3]]))
-        self.assertEqual(embedding.weight.grad._values(), torch.tensor(1.).expand(3, 3))
+        embedding(tensor[0]).sum().backward()
+        self.assertEqual(embedding.weight.grad._indices(), tensor)
+        self.assertEqual(embedding.weight.grad._values(), ones)
 
         embedding.zero_grad()
-        embedding(torch.LongTensor([7, 1, 3])).sum().backward()
-        embedding(torch.LongTensor([7, 1, 3])).sum().backward()
-        self.assertEqual(embedding.weight.grad._indices(), torch.LongTensor([[7, 1, 3, 7, 1, 3]]))
-        self.assertEqual(embedding.weight.grad._values(), torch.tensor(1.).expand(6, 3))
+        embedding(tensor[0]).sum().backward()
+        embedding(tensor[0]).sum().backward()
+        self.assertEqual(embedding.weight.grad._indices(), tensorTwice)
+        self.assertEqual(embedding.weight.grad._values(), onesTwice)
 
         embedding.zero_grad()
-        embedding(torch.LongTensor([7, 1, 3])).sum().backward()
-        embedding(torch.LongTensor([8, 1, 3])).sum().backward()
-        self.assertEqual(embedding.weight.grad._indices(), torch.LongTensor([[7, 1, 3, 8, 1, 3]]))
-        self.assertEqual(embedding.weight.grad._values(), torch.tensor(1.).expand(6, 3))
+        embedding(tensor[0]).sum().backward()
+        tensor[0, 0] = 8
+        embedding(tensor[0]).sum().backward()
+        tensorTwice[0, 3] = 8
+        self.assertEqual(embedding.weight.grad._indices(), tensorTwice)
+        self.assertEqual(embedding.weight.grad._values(), onesTwice)
 
     def test_embedding_padding_idx(self):
         embedding = nn.Embedding(10, 20, padding_idx=0)
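For context on what the assertions in this hunk rely on: with sparse=True, the weight gradient is an uncoalesced sparse tensor, so back-to-back backward() calls append duplicate index/value pairs rather than summing them. A minimal standalone sketch (outside the test harness, default float32 on CPU) of inspecting and coalescing such a gradient:

import torch
import torch.nn as nn

embedding = nn.Embedding(10, 3, sparse=True)
idx = torch.tensor([7, 1, 3])

# Two backward passes accumulate by concatenation, not summation.
embedding(idx).sum().backward()
embedding(idx).sum().backward()

grad = embedding.weight.grad
print(grad._indices())         # tensor([[7, 1, 3, 7, 1, 3]])
print(grad._values().shape)    # torch.Size([6, 3]), all ones

# coalesce() sorts and sums duplicates; rows 1, 3 and 7 each become 2s.
summed = grad.coalesce()
print(summed._indices())       # tensor([[1, 3, 7]])
print(summed._values())        # three rows of 2.0

This is also why an optimizer that accepts sparse gradients (e.g. torch.optim.SparseAdam) is normally used with sparse embeddings.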
@@ -2377,6 +2414,7 @@ def _test_EmbeddingBag_vs_Embedding(self, N, D, B, L, max_norm=None,
             needed_prec = dtype2prec[dtype] * 2
         else:
             needed_prec = backward_prec
+
         self.assertEqual(es_weight_grad, e.weight.grad, needed_prec)
 
         if test_per_sample_weights and trainable_per_sample_weights:
@@ -2564,12 +2602,13 @@ def test_contig_wrong_stride_cudnn(self):
 
     def test_embedding_bag(self):
         for dtype in [torch.double, torch.float]:
-            # TODO: figure out why backward on float breaks
-            test_backward = dtype is not torch.float
-            self._test_EmbeddingBag(False, 'sum', False, test_backward=test_backward, dtype=dtype)
-            self._test_EmbeddingBag(False, 'mean', False, test_backward=test_backward, dtype=dtype)
-            self._test_EmbeddingBag(False, 'max', False, test_backward=test_backward, dtype=dtype)
+            self._test_EmbeddingBag(False, 'sum', False, dtype=dtype)
+            self._test_EmbeddingBag(False, 'mean', False, dtype=dtype)
+            self._test_EmbeddingBag(False, 'max', False, dtype=dtype)
 
+            # TODO: figure out why precision on sparse embeddings isn't the
+            # same as for dense.
+            test_backward = dtype is not torch.float
             self._test_EmbeddingBag(False, 'sum', True, test_backward=test_backward, dtype=dtype)
             self._test_EmbeddingBag(False, 'mean', True, test_backward=test_backward, dtype=dtype)
@@ -2733,10 +2772,11 @@ def test_embedding_bag_cuda(self, dtype=torch.float):
         self._test_EmbeddingBag(True, 'sum', False, dtype)
         self._test_EmbeddingBag(True, 'mean', False, dtype)
         self._test_EmbeddingBag(True, 'max', False, dtype)
-        if dtype != torch.half:
-            # torch.cuda.sparse.HalfTensor is not enabled.
-            self._test_EmbeddingBag(True, 'sum', True, dtype)
-            self._test_EmbeddingBag(True, 'mean', True, dtype)
+
+        # see 'todo' in test_embedding_bag.
+        test_backward = dtype is not torch.float16
+        self._test_EmbeddingBag(True, 'sum', True, dtype, test_backward=test_backward)
+        self._test_EmbeddingBag(True, 'mean', True, dtype, test_backward=test_backward)
 
     def test_fractional_max_pool2d(self):
         x = torch.randn(1, 2, 7, 7, requires_grad=True)
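For reference, the path the removed guard used to skip ("torch.cuda.sparse.HalfTensor is not enabled") is exactly half-precision sparse gradients on CUDA, which this patch enables. A small sketch of that end-to-end use, assuming a CUDA device is available:

import torch
import torch.nn as nn

if torch.cuda.is_available():
    embedding = nn.Embedding(10, 3, sparse=True).to('cuda', torch.float16)
    idx = torch.tensor([7, 1, 3], device='cuda')

    # Backward now produces a half-precision sparse gradient on CUDA.
    embedding(idx).sum().backward()
    grad = embedding.weight.grad
    print(grad.is_sparse, grad.dtype)        # True torch.float16

    # Duplicate indices still need coalescing before a manual update.
    print(grad.coalesce()._values().dtype)   # torch.float16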