 namespace at { namespace native {

 template <typename scalar_t>
-void hardshrink_cuda_kernel(const Tensor& self, Tensor& out_tensor, Tensor& lambd_tensor) {
-  at::cuda::CUDA_tensor_apply3<scalar_t, scalar_t, scalar_t>(
-    self,
-    out_tensor,
-    lambd_tensor,
-    [] __device__ (scalar_t& self_val,
-                   scalar_t& out_tensor_val,
-                   scalar_t& lambd_tensor_val) {
-      if (self_val >= -lambd_tensor_val && self_val <= lambd_tensor_val) {
-        out_tensor_val = ScalarConvert<double, scalar_t>::to(0.0);
-      }
-      else {
-        out_tensor_val = self_val;
-      }
+void hardshrink_cuda_kernel(const Tensor& self, Tensor& out_tensor, scalar_t* lambd) {
+  at::cuda::CUDA_tensor_apply2<scalar_t, scalar_t>(
+    self,
+    out_tensor,
+    [lambd] __device__ (
+      scalar_t& self_val,
+      scalar_t& out_tensor_val) {
+        out_tensor_val = (self_val >= -*lambd && self_val <= *lambd) ? ScalarConvert<double, scalar_t>::to(0.0) : self_val;
   });
 }

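For reference, the new lambda implements the standard hardshrink rule: values inside `[-lambd, lambd]` are zeroed, everything else passes through unchanged. A minimal host-side sketch of the same element-wise semantics (the free-standing `hardshrink_ref` helper is illustrative, not part of the patch):

```cpp
// Element-wise hardshrink reference: zero inside the [-lambd, lambd] band,
// identity outside it. Mirrors the device lambda above, minus the pointer
// indirection through the one-element lambd tensor.
template <typename scalar_t>
scalar_t hardshrink_ref(scalar_t x, scalar_t lambd) {
  return (x >= -lambd && x <= lambd) ? scalar_t(0) : x;
}
```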
 template <typename scalar_t>
-void hardshrink_backward_cuda_kernel(Tensor& out_tensor, Tensor& lambd_tensor, const Tensor& self, const Tensor& grad) {
-  at::cuda::CUDA_tensor_apply4<scalar_t, scalar_t, scalar_t, scalar_t>(
-    out_tensor,
-    lambd_tensor,
-    self,
-    grad,
-    [] __device__ (scalar_t& out_tensor_val,
-                   scalar_t& lambd_tensor_val,
-                   scalar_t& self_val,
-                   scalar_t& grad_val) {
-      if (self_val >= -lambd_tensor_val && self_val <= lambd_tensor_val) {
-        out_tensor_val = ScalarConvert<double, scalar_t>::to(0.0);
-      }
-      else {
-        out_tensor_val = grad_val;
-      }
+void hardshrink_backward_cuda_kernel(Tensor& out_tensor, scalar_t* lambd, const Tensor& self, const Tensor& grad) {
+  at::cuda::CUDA_tensor_apply3<scalar_t, scalar_t, scalar_t>(
+    self,
+    grad,
+    out_tensor,
+    [lambd] __device__ (
+      scalar_t& self_val,
+      scalar_t& grad_val,
+      scalar_t& out_tensor_val) {
+        out_tensor_val = (self_val >= -*lambd && self_val <= *lambd) ? ScalarConvert<double, scalar_t>::to(0.0) : grad_val;
   });
 }

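The backward kernel applies the same band test to route the upstream gradient: hardshrink is the identity outside `[-lambd, lambd]` (derivative 1) and constant zero inside it (derivative 0), so the gradient either passes through or is dropped. A matching host-side sketch (again, the helper name is illustrative):

```cpp
// Element-wise hardshrink backward reference: the upstream gradient flows
// through where the forward pass was the identity, and is zeroed where the
// forward pass clamped the input to zero.
template <typename scalar_t>
scalar_t hardshrink_backward_ref(scalar_t grad, scalar_t x, scalar_t lambd) {
  return (x >= -lambd && x <= lambd) ? scalar_t(0) : grad;
}
```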
 Tensor hardshrink_cuda(const Tensor & self, Scalar lambd) {
-  auto lambd_tensor = at::zeros_like(self).fill_(lambd);
+  auto lambd_tensor = lambd.toTensor().toType(self.type().scalarType()).toBackend(self.is_cuda() ? Backend::CUDA : Backend::CPU);
   auto out_tensor = at::zeros_like(self);
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.type(), "hardshrink_cuda", [&] {
     using cuda_scalar_t = cuda::into_type<scalar_t>;
-    hardshrink_cuda_kernel<cuda_scalar_t>(self, out_tensor, lambd_tensor);
+    hardshrink_cuda_kernel<cuda_scalar_t>(self, out_tensor, lambd_tensor.data<cuda_scalar_t>());
   });
   return out_tensor;
 }

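The `lambd_tensor` change replaces a full `zeros_like(self).fill_(lambd)` allocation with a single-element tensor: the Scalar is converted with `toTensor()`, cast to the input's dtype, and moved to the input's backend so that `data<cuda_scalar_t>()` yields a pointer the device lambda can legally dereference. A sketch of that conversion in isolation (the `scalar_like` helper name is hypothetical):

```cpp
// Materialize a host Scalar as a one-element tensor with the same dtype and
// backend as `self`, so a __device__ lambda can read it through a raw pointer.
Tensor scalar_like(Scalar lambd, const Tensor& self) {
  return lambd.toTensor()
              .toType(self.type().scalarType())
              .toBackend(self.is_cuda() ? Backend::CUDA : Backend::CPU);
}
```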
 Tensor hardshrink_backward_cuda(const Tensor & grad, const Tensor & self, Scalar lambd) {
-  auto lambd_tensor = at::zeros_like(self).fill_(lambd);
-  // auto lambd_tensor = lambd.toTensor().toType(grad.type().scalarType()).toBackend(grad.is_cuda() ? Backend::CUDA : Backend::CPU);
+  auto lambd_tensor = lambd.toTensor().toType(self.type().scalarType()).toBackend(self.is_cuda() ? Backend::CUDA : Backend::CPU);
   auto out_tensor = at::zeros_like(grad);
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.type(), "hardshrink_backward_cuda", [&] {
     using cuda_scalar_t = cuda::into_type<scalar_t>;
-    hardshrink_backward_cuda_kernel<cuda_scalar_t>(out_tensor, lambd_tensor, self, grad);
+    hardshrink_backward_cuda_kernel<cuda_scalar_t>(out_tensor, lambd_tensor.data<cuda_scalar_t>(), self, grad);
   });
   return out_tensor;
 }
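A minimal usage sketch, assuming these functions are wired up as the CUDA dispatch targets for `at::hardshrink` and `at::hardshrink_backward` (tensor shapes, the 0.5 threshold, and the TensorOptions-style factory calls are illustrative):

```cpp
#include <ATen/ATen.h>

void hardshrink_example() {
  auto x = at::randn({4, 4}, at::kCUDA);  // input on the GPU
  auto y = at::hardshrink(x, 0.5);        // forward: |x| <= 0.5 becomes 0
  auto grad_out = at::ones_like(y);       // upstream gradient
  // backward: gradient passes through only where |x| > 0.5
  auto grad_in = at::hardshrink_backward(grad_out, x, 0.5);
}
```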