Skip to content

Commit 8a306d2

Browse files
committed
1. Pass in lambd as a scalar for CPU/CUDA_apply*; 2. Remove tests for hardshrink in test_legacy_nn
1 parent b3bb062 commit 8a306d2

File tree

4 files changed

+51
-55
lines changed

4 files changed

+51
-55
lines changed

aten/src/ATen/native/Activation.cpp

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -34,36 +34,36 @@ Tensor & rrelu_(Tensor & self, Scalar lower, Scalar upper, bool training, Genera
3434
}
3535

3636
Tensor hardshrink_cpu(const Tensor & self, Scalar lambd) {
37-
auto lambd_tensor = at::zeros_like(self).fill_(lambd.toTensor());
38-
auto out_tensor = self.clone();
37+
auto lambd_tensor = lambd.toTensor().toType(self.type().scalarType()).toBackend(self.is_cuda() ? Backend::CUDA : Backend::CPU);
38+
auto out_tensor = at::zeros_like(self);
3939
AT_DISPATCH_FLOATING_TYPES(self.type(), "hardshrink_cpu", [&] {
40+
scalar_t* lambd_tensor_d = lambd_tensor.data<scalar_t>();
4041
at::CPU_tensor_apply2<scalar_t, scalar_t>(
41-
out_tensor,
42-
lambd_tensor,
43-
[](scalar_t& out_tensor_val,
44-
scalar_t& lambd_tensor_val) {
45-
if (out_tensor_val >= -lambd_tensor_val && out_tensor_val <= lambd_tensor_val) {
46-
out_tensor_val = convert<scalar_t, double>(0.0);
47-
}
42+
self,
43+
out_tensor,
44+
[lambd_tensor_d](
45+
scalar_t& self_val,
46+
scalar_t& out_tensor_val) {
47+
out_tensor_val = (self_val >= -*lambd_tensor_d && self_val <= *lambd_tensor_d) ? convert<scalar_t, double>(0.0) : self_val;
4848
});
4949
});
5050
return out_tensor;
5151
}
5252

5353
Tensor hardshrink_backward_cpu(const Tensor & grad, const Tensor & self, Scalar lambd) {
54-
auto lambd_tensor = at::zeros_like(self).fill_(lambd);
55-
auto out_tensor = grad.clone();
54+
auto lambd_tensor = lambd.toTensor().toType(self.type().scalarType()).toBackend(self.is_cuda() ? Backend::CUDA : Backend::CPU);
55+
auto out_tensor = at::zeros_like(self);
5656
AT_DISPATCH_FLOATING_TYPES(self.type(), "hardshrink_backward_cpu", [&] {
57+
scalar_t* lambd_tensor_d = lambd_tensor.data<scalar_t>();
5758
at::CPU_tensor_apply3<scalar_t, scalar_t, scalar_t>(
58-
out_tensor,
59-
lambd_tensor,
60-
self,
61-
[](scalar_t& out_tensor_val,
62-
scalar_t& lambd_tensor_val,
63-
scalar_t& self_val) {
64-
if (self_val >= -lambd_tensor_val && self_val <= lambd_tensor_val) {
65-
out_tensor_val = convert<scalar_t, double>(0.0);
66-
}
59+
self,
60+
grad,
61+
out_tensor,
62+
[lambd_tensor_d](
63+
scalar_t& self_val,
64+
scalar_t& grad_val,
65+
scalar_t& out_tensor_val) {
66+
out_tensor_val = (self_val >= -*lambd_tensor_d && self_val <= *lambd_tensor_d) ? convert<scalar_t, double>(0.0) : grad_val;
6767
});
6868
});
6969
return out_tensor;

aten/src/ATen/native/cuda/Activation.cu

Lines changed: 23 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -9,60 +9,48 @@
99
namespace at { namespace native {
1010

1111
template <typename scalar_t>
12-
void hardshrink_cuda_kernel(const Tensor& self, Tensor& out_tensor, Tensor& lambd_tensor) {
13-
at::cuda::CUDA_tensor_apply3<scalar_t, scalar_t, scalar_t>(
14-
self,
15-
out_tensor,
16-
lambd_tensor,
17-
[] __device__ (scalar_t& self_val,
18-
scalar_t& out_tensor_val,
19-
scalar_t& lambd_tensor_val) {
20-
if (self_val >= -lambd_tensor_val && self_val <= lambd_tensor_val) {
21-
out_tensor_val = ScalarConvert<double, scalar_t>::to(0.0);
22-
}
23-
else {
24-
out_tensor_val = self_val;
25-
}
12+
void hardshrink_cuda_kernel(const Tensor& self, Tensor& out_tensor, scalar_t* lambd) {
13+
at::cuda::CUDA_tensor_apply2<scalar_t, scalar_t>(
14+
self,
15+
out_tensor,
16+
[lambd] __device__ (
17+
scalar_t& self_val,
18+
scalar_t& out_tensor_val,
19+
bool early_exit) {
20+
out_tensor_val = (self_val >= -*lambd && self_val <= *lambd) ? ScalarConvert<double, scalar_t>::to(0.0) : self_val;
2621
});
2722
}
2823

2924
template <typename scalar_t>
30-
void hardshrink_backward_cuda_kernel(Tensor& out_tensor, Tensor& lambd_tensor, const Tensor& self, const Tensor& grad) {
31-
at::cuda::CUDA_tensor_apply4<scalar_t, scalar_t, scalar_t, scalar_t>(
32-
out_tensor,
33-
lambd_tensor,
34-
self,
35-
grad,
36-
[] __device__ (scalar_t& out_tensor_val,
37-
scalar_t& lambd_tensor_val,
38-
scalar_t& self_val,
39-
scalar_t& grad_val) {
40-
if (self_val >= -lambd_tensor_val && self_val <= lambd_tensor_val) {
41-
out_tensor_val = ScalarConvert<double, scalar_t>::to(0.0);
42-
}
43-
else {
44-
out_tensor_val = grad_val;
45-
}
25+
void hardshrink_backward_cuda_kernel(Tensor& out_tensor, scalar_t* lambd, const Tensor& self, const Tensor& grad) {
26+
at::cuda::CUDA_tensor_apply3<scalar_t, scalar_t, scalar_t>(
27+
self,
28+
grad,
29+
out_tensor,
30+
[lambd] __device__ (
31+
scalar_t& self_val,
32+
scalar_t& grad_val,
33+
scalar_t& out_tensor_val) {
34+
out_tensor_val = (self_val >= -*lambd && self_val <= *lambd) ? ScalarConvert<double, scalar_t>::to(0.0) : grad_val;
4635
});
4736
}
4837

4938
Tensor hardshrink_cuda(const Tensor & self, Scalar lambd) {
50-
auto lambd_tensor = at::zeros_like(self).fill_(lambd);
39+
auto lambd_tensor = lambd.toTensor().toType(self.type().scalarType()).toBackend(self.is_cuda() ? Backend::CUDA : Backend::CPU);
5140
auto out_tensor = at::zeros_like(self);
5241
AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.type(), "hardshrink_cuda", [&] {
5342
using cuda_scalar_t = cuda::into_type<scalar_t>;
54-
hardshrink_cuda_kernel<cuda_scalar_t>(self, out_tensor, lambd_tensor);
43+
hardshrink_cuda_kernel<cuda_scalar_t>(self, out_tensor, lambd_tensor.data<cuda_scalar_t>());
5544
});
5645
return out_tensor;
5746
}
5847

5948
Tensor hardshrink_backward_cuda(const Tensor & grad, const Tensor & self, Scalar lambd) {
60-
auto lambd_tensor = at::zeros_like(self).fill_(lambd);
61-
// auto lambd_tensor = lambd.toTensor().toType(grad.type().scalarType()).toBackend(grad.is_cuda() ? Backend::CUDA : Backend::CPU);
49+
auto lambd_tensor = lambd.toTensor().toType(self.type().scalarType()).toBackend(self.is_cuda() ? Backend::CUDA : Backend::CPU);
6250
auto out_tensor = at::zeros_like(grad);
6351
AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.type(), "hardshrink_backward_cuda", [&] {
6452
using cuda_scalar_t = cuda::into_type<scalar_t>;
65-
hardshrink_backward_cuda_kernel<cuda_scalar_t>(out_tensor, lambd_tensor, self, grad);
53+
hardshrink_backward_cuda_kernel<cuda_scalar_t>(out_tensor, lambd_tensor.data<cuda_scalar_t>(), self, grad);
6654
});
6755
return out_tensor;
6856
}

test/test_legacy_nn.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -618,13 +618,20 @@ def add_test(test):
618618
test_params = deepcopy(test_params)
619619
name = test_params.pop('module_name')
620620
name = name_remap.get(name, name)
621+
# hardshrink is deprecated in nn
622+
if name == "HardShrink":
623+
continue
624+
621625
test_params['constructor'] = getattr(nn, name)
622626
test = OldModuleTest(**test_params)
623627
add_test(test)
624628
for test_params in criterion_tests:
625629
test_params = deepcopy(test_params)
626630
name = test_params.pop('module_name')
627631
name = name_remap.get(name, name.replace('Loss', 'Criterion'))
632+
# hardshrink is deprecated in nn
633+
if name == "HardShrink":
634+
continue
628635

629636
# nn.NLLLoss2d is deprecated, but there is a NLLLoss test for 2d
630637
if name == 'ClassNLLCriterion' and 'desc' in test_params.keys() and '2d' in test_params['desc']:

test/test_nn.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5534,6 +5534,7 @@ def add(test_name, fn):
55345534
else:
55355535
add(cuda_test_name, lambda self, test=test: test.test_cuda(self))
55365536

5537+
55375538
def wrap_functional(fn, **kwargs):
55385539
class FunctionalModule(nn.Module):
55395540
def forward(self, *args):

0 commit comments

Comments (0)