Commit d68802b

nairbv authored and facebook-github-bot committed
Sparse half embeddings on cuda (#19695)
Summary:

```
import torch
a = torch.nn.Embedding(3, 4, sparse=True).half().cuda()
a(torch.LongTensor([1, 0]).cuda()).sum().backward()
```

gave: `RuntimeError: torch.cuda.sparse.HalfTensor is not enabled`

This PR enables sparse.HalfTensor on cuda. Still won't work for CPU.

Pull Request resolved: #19695
Differential Revision: D15281162
Pulled By: nairbv
fbshipit-source-id: 0d83d946a059393bd53d8b8102e2daa9b4c02588
1 parent 148e90b commit d68802b
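For context, a minimal sketch (not part of the diff) of the behavior this commit enables, mirroring the repro above and the new `_test_embedding_backward` test below; with `sparse=True`, the backward pass should now produce a sparse float16 gradient on CUDA:

```
import torch

# Repro from the summary; before this commit the backward() call raised
# "RuntimeError: torch.cuda.sparse.HalfTensor is not enabled".
a = torch.nn.Embedding(3, 4, sparse=True).half().cuda()
a(torch.LongTensor([1, 0]).cuda()).sum().backward()

grad = a.weight.grad
print(grad.is_sparse)    # True: sparse embeddings produce sparse gradients
print(grad.dtype)        # torch.float16
print(grad._indices())   # the looked-up rows, 1 and 0
```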

File tree: 11 files changed, +180 −89 lines


aten/src/ATen/Declarations.cwrap

Lines changed: 1 addition & 0 deletions
```
@@ -149,6 +149,7 @@
 [[
   name: _th_nonzero
   cname: nonzero
+  cpu_half: True
   cpu_bool: True
   cuda_bool: True
   variants:
```
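As an illustration of what the `cpu_half: True` flag above unlocks, a hedged sketch (not from the diff): the TH `nonzero` kernel is now generated for CPU half tensors, which the `dense_to_sparse` path below relies on to locate non-zero entries; the exact printed output is assumed.

```
import torch

# With the cpu_half flag above, nonzero is generated for CPU float16;
# to_sparse() uses it to find the coordinates of non-zero entries.
t = torch.tensor([0.0, 1.5, 0.0, -2.0], dtype=torch.float16)
print(t.nonzero())  # expected: tensor([[1], [3]])
```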

aten/src/ATen/gen.py

Lines changed: 0 additions & 3 deletions
```
@@ -390,9 +390,6 @@ def legacy_iterate_types():
             for scalar_type in (scalar_types + quantized_scalar_types):
                 if density == 'Mkldnn' and (backend != 'CPU' or scalar_type[0] != 'Float'):
                     continue
-                if density == 'Sparse' and scalar_type[0] == 'Half':
-                    # THS does not do half type yet.
-                    continue
                 else:
                     yield (backend, density, scalar_type)
     for backend in quantized_backends:
```

aten/src/ATen/native/native_functions.yaml

Lines changed: 2 additions & 0 deletions
```
@@ -2523,12 +2523,14 @@
   variants: function, method

 - func: to_sparse(Tensor self, int sparse_dim) -> Tensor
+  cpu_half: True
   variants: method
   dispatch:
     CPU: dense_to_sparse
     CUDA: dense_to_sparse

 - func: to_sparse(Tensor self) -> Tensor
+  cpu_half: True
   variants: method
   dispatch:
     CPU: dense_to_sparse
```

aten/src/ATen/native/sparse/SparseTensor.cpp

Lines changed: 3 additions & 0 deletions
```
@@ -325,6 +325,9 @@ SparseTensor dense_to_sparse(const Tensor& self, int64_t sparse_dim){
 // NB: Dropped the resizeNd variants

 Tensor sparse_to_dense(const SparseTensor& self) {
+  if(self.scalar_type() == ScalarType::Half && self.options().device().is_cpu()) {
+    AT_ERROR("to_dense() not supported for float16 on CPU");
+  }
   Tensor dst = at::zeros(self.sizes(), self.options().layout(kStrided));
   return dst.add_(self);
 }
```
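Putting the new guard together with the `cpu_half: True` entries in native_functions.yaml above, the intended CPU float16 behavior looks roughly like this sketch (assumed, not part of the diff): conversion *to* sparse works, while `to_dense()` trips the new error.

```
import torch

d = torch.tensor([[0.0, 1.0],
                  [2.0, 0.0]], dtype=torch.float16)

s = d.to_sparse()      # now allowed for float16 on CPU
print(s._indices())    # coordinates of the two non-zero entries
print(s._values())     # their float16 values

try:
    s.to_dense()       # hits the guard added above
except RuntimeError as err:
    print(err)         # "to_dense() not supported for float16 on CPU"
```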

aten/src/TH/THTensor.h

Lines changed: 3 additions & 0 deletions
```
@@ -31,6 +31,9 @@
 #include <TH/generic/THTensorMath.h>
 #include <TH/THGenerateBoolType.h>

+#include <TH/generic/THTensorMath.h>
+#include <TH/THGenerateHalfType.h>
+
 /* fill and zero*/
 #include <TH/generic/THTensorFill.h>
 #include <TH/THGenerateAllTypes.h>
```

aten/src/TH/THTensorEvenMoreMath.cpp

Lines changed: 3 additions & 0 deletions
```
@@ -8,3 +8,6 @@

 #include <TH/generic/THTensorEvenMoreMath.cpp>
 #include <TH/THGenerateBoolType.h>
+
+#include <TH/generic/THTensorEvenMoreMath.cpp>
+#include <TH/THGenerateHalfType.h>
```

aten/src/TH/generic/THTensorEvenMoreMath.cpp

Lines changed: 8 additions & 3 deletions
```
@@ -11,7 +11,7 @@ void THTensor_(nonzero)(THLongTensor *subscript, THTensor *tensor)
   int64_t *subscript_data;
   int64_t i = 0;
 #ifdef TH_REAL_IS_HALF
-#define IS_NONZERO(val) ((val.x & 0x7fff) != 0)
+#define IS_NONZERO(val) (c10::Half(0)!=val)
 #else
 #define IS_NONZERO(val) ((val)!=0)
 #endif
@@ -65,17 +65,20 @@ void THTensor_(nonzero)(THLongTensor *subscript, THTensor *tensor)
   );
   delete [] sizes;
   delete [] idx;
+
+#undef IS_NONZERO
 }

+#if !defined(TH_REAL_IS_HALF) /* non half only part */
+
 accreal THTensor_(sumall)(THTensor *tensor)
 {
   accreal sum = 0;
   TH_TENSOR_APPLY_REDUCTION_SUM_PARALLEL(
     scalar_t, tensor, *tensor_data, sum, UNCERTAIN_TH_OMP_OVERHEAD_THRESHOLD);
   return sum;
 }
-
-#if !defined(TH_REAL_IS_BOOL) /* non bool only part */
+#if !defined(TH_REAL_IS_BOOL)

 void THTensor_(maskedFill)(THTensor *tensor, THByteTensor *mask, scalar_t value)
 {
@@ -906,4 +909,6 @@ void THTensor_(bitand)(THTensor *r_, THTensor *t, scalar_t value)

 #endif

+#endif
+
 #endif /* TH_GENERIC_FILE */
```
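Side note on the `IS_NONZERO` rewrite: both the old bit mask and the new comparison against `c10::Half(0)` classify the two IEEE-754 half zero encodings (+0.0 = `0x0000`, -0.0 = `0x8000`) as zero; the new form just says so via ordinary float semantics instead of bit twiddling. A standalone sketch of that equivalence (plain Python, not PyTorch code):

```
import struct

def is_nonzero_bitmask(bits):
    # Old macro: zero iff all non-sign bits are clear, so both
    # +0.0 (0x0000) and -0.0 (0x8000) count as zero.
    return (bits & 0x7FFF) != 0

def is_nonzero_compare(bits):
    # New macro's semantics: compare against zero as a float;
    # IEEE 754 defines -0.0 == 0.0, so the answers agree.
    (value,) = struct.unpack('<e', struct.pack('<H', bits))
    return value != 0.0

for bits in (0x0000, 0x8000, 0x3C00, 0xBC00):  # +0.0, -0.0, 1.0, -1.0
    assert is_nonzero_bitmask(bits) == is_nonzero_compare(bits)
```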

aten/src/TH/generic/THTensorMath.h

Lines changed: 3 additions & 0 deletions
```
@@ -4,6 +4,8 @@

 TH_API void THTensor_(nonzero)(THLongTensor *subscript, THTensor *tensor);

+#ifndef TH_REAL_IS_HALF
+
 TH_API void THTensor_(ltValue)(THByteTensor *r_, THTensor* t, scalar_t value);
 TH_API void THTensor_(leValue)(THByteTensor *r_, THTensor* t, scalar_t value);
 TH_API void THTensor_(gtValue)(THByteTensor *r_, THTensor* t, scalar_t value);
@@ -183,3 +185,4 @@ TH_API void THTensor_(dirichlet_grad)(THTensor *self, THTensor *x, THTensor *alp

 #endif
 #endif
+#endif
```

test/test_nn.py

Lines changed: 60 additions & 20 deletions
```
@@ -2070,24 +2070,61 @@ def test_embedding_dense_grad(self):
     def test_embedding_dense_grad_cuda(self):
         self._test_embedding_dense_grad("cuda")

+    def test_move_sparse_half_embedding(self):
+        embedding = nn.Embedding(10, 3, sparse=True)
+        self.assertEqual(embedding.weight.device.type, 'cpu')
+        self.assertEqual(embedding.weight.dtype, torch.float64)
+        embedding.to(torch.float16)
+        self.assertEqual(embedding.weight.dtype, torch.float16)
+        self.assertEqual(embedding.embedding_dim, 3)
+        self.assertEqual(embedding.num_embeddings, 10)
+
+        if torch.cuda.is_available():
+            embedding.to('cuda')
+            self.assertEqual(embedding.weight.device.type, 'cuda')
+            embedding.to('cpu')
+            self.assertEqual(embedding.weight.device.type, 'cpu')
+
     def test_embedding_sparse_backward(self):
+        self._test_embedding_backward()
+
+    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
+    def test_embedding_sparse_half_backward(self):
+        # same as test_embedding_sparse_backward above but testing half types in
+        # cuda. cpu sum not supported for half types.
+        self._test_embedding_backward('cuda', torch.float16)
+
+    def _test_embedding_backward(self, device='cpu', dtype=torch.float64):
         embedding = nn.Embedding(10, 3, sparse=True)
+        tensor = torch.tensor([[7, 1, 3]])
+        ones = torch.tensor(1.).expand(3, 3)
+        tensorTwice = tensor.repeat(1, 2)
+        onesTwice = torch.cat((ones, ones))
+
+        embedding = embedding.to(dtype=dtype).to(device)
+        tensor = tensor.to(device)
+        ones = ones.to(device)
+        tensorTwice = tensorTwice.to(device)
+        onesTwice = onesTwice.to(device)
+
         embedding.zero_grad()
-        embedding(torch.LongTensor([7, 1, 3])).sum().backward()
-        self.assertEqual(embedding.weight.grad._indices(), torch.LongTensor([[7, 1, 3]]))
-        self.assertEqual(embedding.weight.grad._values(), torch.tensor(1.).expand(3, 3))
+        embedding(tensor[0]).sum().backward()
+        self.assertEqual(embedding.weight.grad._indices(), tensor)
+        self.assertEqual(embedding.weight.grad._values(), ones)

         embedding.zero_grad()
-        embedding(torch.LongTensor([7, 1, 3])).sum().backward()
-        embedding(torch.LongTensor([7, 1, 3])).sum().backward()
-        self.assertEqual(embedding.weight.grad._indices(), torch.LongTensor([[7, 1, 3, 7, 1, 3]]))
-        self.assertEqual(embedding.weight.grad._values(), torch.tensor(1.).expand(6, 3))
+        embedding(tensor[0]).sum().backward()
+        embedding(tensor[0]).sum().backward()
+        self.assertEqual(embedding.weight.grad._indices(), tensorTwice)
+        self.assertEqual(embedding.weight.grad._values(), onesTwice)

         embedding.zero_grad()
-        embedding(torch.LongTensor([7, 1, 3])).sum().backward()
-        embedding(torch.LongTensor([8, 1, 3])).sum().backward()
-        self.assertEqual(embedding.weight.grad._indices(), torch.LongTensor([[7, 1, 3, 8, 1, 3]]))
-        self.assertEqual(embedding.weight.grad._values(), torch.tensor(1.).expand(6, 3))
+        embedding(tensor[0]).sum().backward()
+        tensor[0, 0] = 8
+        embedding(tensor[0]).sum().backward()
+        tensorTwice[0, 3] = 8
+        self.assertEqual(embedding.weight.grad._indices(), tensorTwice)
+        self.assertEqual(embedding.weight.grad._values(), onesTwice)

     def test_embedding_padding_idx(self):
         embedding = nn.Embedding(10, 20, padding_idx=0)
@@ -2377,6 +2414,7 @@ def _test_EmbeddingBag_vs_Embedding(self, N, D, B, L, max_norm=None,
             needed_prec = dtype2prec[dtype] * 2
         else:
             needed_prec = backward_prec
+
         self.assertEqual(es_weight_grad, e.weight.grad, needed_prec)

         if test_per_sample_weights and trainable_per_sample_weights:
@@ -2564,12 +2602,13 @@ def test_contig_wrong_stride_cudnn(self):

     def test_embedding_bag(self):
         for dtype in [torch.double, torch.float]:
-            # TODO: figure out why backward on float breaks
-            test_backward = dtype is not torch.float
-            self._test_EmbeddingBag(False, 'sum', False, test_backward=test_backward, dtype=dtype)
-            self._test_EmbeddingBag(False, 'mean', False, test_backward=test_backward, dtype=dtype)
-            self._test_EmbeddingBag(False, 'max', False, test_backward=test_backward, dtype=dtype)
+            self._test_EmbeddingBag(False, 'sum', False, dtype=dtype)
+            self._test_EmbeddingBag(False, 'mean', False, dtype=dtype)
+            self._test_EmbeddingBag(False, 'max', False, dtype=dtype)

+            # TODO: figure out why precision on sparse embeddings isn't the
+            # same as for dense.
+            test_backward = dtype is not torch.float
             self._test_EmbeddingBag(False, 'sum', True, test_backward=test_backward, dtype=dtype)
             self._test_EmbeddingBag(False, 'mean', True, test_backward=test_backward, dtype=dtype)
@@ -2733,10 +2772,11 @@ def test_embedding_bag_cuda(self, dtype=torch.float):
         self._test_EmbeddingBag(True, 'sum', False, dtype)
         self._test_EmbeddingBag(True, 'mean', False, dtype)
         self._test_EmbeddingBag(True, 'max', False, dtype)
-        if dtype != torch.half:
-            # torch.cuda.sparse.HalfTensor is not enabled.
-            self._test_EmbeddingBag(True, 'sum', True, dtype)
-            self._test_EmbeddingBag(True, 'mean', True, dtype)
+
+        # see 'todo' in test_embedding_bag.
+        test_backward = dtype is not torch.float16
+        self._test_EmbeddingBag(True, 'sum', True, dtype, test_backward=test_backward)
+        self._test_EmbeddingBag(True, 'mean', True, dtype, test_backward=test_backward)

     def test_fractional_max_pool2d(self):
         x = torch.randn(1, 2, 7, 7, requires_grad=True)
```
