24 changes: 0 additions & 24 deletions aten/doc/Functions.h
@@ -411,12 +411,6 @@ static inline Tensor & glu_forward_out(Tensor & output, const Tensor & self, int
static inline Tensor glu_forward(const Tensor & self, int64_t dim);
static inline Tensor & glu_backward_out(Tensor & grad_input, const Tensor & grad_output, const Tensor & self, int64_t dim);
static inline Tensor glu_backward(const Tensor & grad_output, const Tensor & self, int64_t dim);
static inline Tensor & hardshrink_out(Tensor & output, const Tensor & self, Scalar lambd=0.5);
static inline Tensor hardshrink(const Tensor & self, Scalar lambd=0.5);
static inline Tensor & hardshrink_forward_out(Tensor & output, const Tensor & self, Scalar lambd);
static inline Tensor hardshrink_forward(const Tensor & self, Scalar lambd);
static inline Tensor & hardshrink_backward_out(Tensor & grad_input, const Tensor & grad_output, const Tensor & self, Scalar lambd);
static inline Tensor hardshrink_backward(const Tensor & grad_output, const Tensor & self, Scalar lambd);
static inline Tensor & hardtanh_out(Tensor & output, const Tensor & self, Scalar min_val=-1, Scalar max_val=1);
static inline Tensor hardtanh(const Tensor & self, Scalar min_val=-1, Scalar max_val=1);
static inline Tensor & hardtanh_forward_out(Tensor & output, const Tensor & self, Scalar min_val, Scalar max_val);
@@ -2008,24 +2002,6 @@ static inline Tensor & glu_backward_out(Tensor & grad_input, const Tensor & grad
static inline Tensor glu_backward(const Tensor & grad_output, const Tensor & self, int64_t dim) {
return infer_type(self).glu_backward(grad_output, self, dim);
}
static inline Tensor & hardshrink_out(Tensor & output, const Tensor & self, Scalar lambd) {
return infer_type(self).hardshrink_out(output, self, lambd);
}
static inline Tensor hardshrink(const Tensor & self, Scalar lambd) {
return infer_type(self).hardshrink(self, lambd);
}
static inline Tensor & hardshrink_forward_out(Tensor & output, const Tensor & self, Scalar lambd) {
return infer_type(self).hardshrink_forward_out(output, self, lambd);
}
static inline Tensor hardshrink_forward(const Tensor & self, Scalar lambd) {
return infer_type(self).hardshrink_forward(self, lambd);
}
static inline Tensor & hardshrink_backward_out(Tensor & grad_input, const Tensor & grad_output, const Tensor & self, Scalar lambd) {
return infer_type(self).hardshrink_backward_out(grad_input, grad_output, self, lambd);
}
static inline Tensor hardshrink_backward(const Tensor & grad_output, const Tensor & self, Scalar lambd) {
return infer_type(self).hardshrink_backward(grad_output, self, lambd);
}
static inline Tensor & hardtanh_out(Tensor & output, const Tensor & self, Scalar min_val, Scalar max_val) {
return infer_type(self).hardtanh_out(output, self, min_val, max_val);
}
6 changes: 0 additions & 6 deletions aten/doc/Type.h
@@ -747,12 +747,6 @@ struct AT_API Type {
virtual Tensor glu_forward(const Tensor & self, int64_t dim) const;
virtual Tensor & glu_backward_out(Tensor & grad_input, const Tensor & grad_output, const Tensor & self, int64_t dim) const;
virtual Tensor glu_backward(const Tensor & grad_output, const Tensor & self, int64_t dim) const;
virtual Tensor & hardshrink_out(Tensor & output, const Tensor & self, Scalar lambd=0.5) const;
virtual Tensor hardshrink(const Tensor & self, Scalar lambd=0.5) const;
virtual Tensor & hardshrink_forward_out(Tensor & output, const Tensor & self, Scalar lambd) const;
virtual Tensor hardshrink_forward(const Tensor & self, Scalar lambd) const;
virtual Tensor & hardshrink_backward_out(Tensor & grad_input, const Tensor & grad_output, const Tensor & self, Scalar lambd) const;
virtual Tensor hardshrink_backward(const Tensor & grad_output, const Tensor & self, Scalar lambd) const;
virtual Tensor & hardtanh_out(Tensor & output, const Tensor & self, Scalar min_val=-1, Scalar max_val=1) const;
virtual Tensor hardtanh(const Tensor & self, Scalar min_val=-1, Scalar max_val=1) const;
virtual Tensor & hardtanh_forward_out(Tensor & output, const Tensor & self, Scalar min_val, Scalar max_val) const;
39 changes: 39 additions & 0 deletions aten/src/ATen/native/Activation.cpp
@@ -1,5 +1,8 @@
#include "ATen/ATen.h"
#include "ATen/NativeFunctions.h"
#include "ATen/Dispatch.h"
#include "ATen/CPUApplyUtils.h"
#include "ATen/Half.h"

namespace at { namespace native {

@@ -30,4 +33,40 @@ Tensor & rrelu_(Tensor & self, Scalar lower, Scalar upper, bool training, Genera
return at::rrelu_with_noise_(self, self.type().tensor(), lower, upper, training, generator);
}

Tensor hardshrink_cpu(const Tensor & self, Scalar lambd) {
auto lambd_tensor = lambd.toTensor().toType(self.type().scalarType()).toBackend(self.is_cuda() ? Backend::CUDA : Backend::CPU);
auto out_tensor = at::empty_like(self);
AT_DISPATCH_FLOATING_TYPES(self.type(), "hardshrink_cpu", [&] {
scalar_t* lambd_tensor_d = lambd_tensor.data<scalar_t>();
at::CPU_tensor_apply2<scalar_t, scalar_t>(
self,
out_tensor,
[lambd_tensor_d](
scalar_t& self_val,
scalar_t& out_tensor_val) {
out_tensor_val = (self_val >= -*lambd_tensor_d && self_val <= *lambd_tensor_d) ? convert<scalar_t, int>(0) : self_val;
});
});
return out_tensor;
}

Tensor hardshrink_backward_cpu(const Tensor & grad, const Tensor & self, Scalar lambd) {
auto lambd_tensor = lambd.toTensor().toType(self.type().scalarType()).toBackend(self.is_cuda() ? Backend::CUDA : Backend::CPU);
auto out_tensor = at::empty_like(self);
AT_DISPATCH_FLOATING_TYPES(self.type(), "hardshrink_backward_cpu", [&] {
scalar_t* lambd_tensor_d = lambd_tensor.data<scalar_t>();
at::CPU_tensor_apply3<scalar_t, scalar_t, scalar_t>(
self,
grad,
out_tensor,
[lambd_tensor_d](
scalar_t& self_val,
scalar_t& grad_val,
scalar_t& out_tensor_val) {
out_tensor_val = (self_val >= -*lambd_tensor_d && self_val <= *lambd_tensor_d) ? convert<scalar_t, int>(0) : grad_val;
});
});
return out_tensor;
}

}} // namespace at::native
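
For reference, the CPU path above zeroes every element whose value lies in [-lambd, lambd] and passes everything else through unchanged. Below is a minimal pure-Python sketch of that element-wise rule; it is an illustration only, and hardshrink_reference is a hypothetical name, not part of this PR.

# Hypothetical reference implementation of the rule applied by hardshrink_cpu:
# zero out values inside [-lambd, lambd], keep the rest as-is.
def hardshrink_reference(values, lambd=0.5):
    return [0.0 if -lambd <= v <= lambd else v for v in values]

print(hardshrink_reference([1.0, 0.5, 0.3, 0.6], lambd=0.3))  # [1.0, 0.5, 0.0, 0.6]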
58 changes: 58 additions & 0 deletions aten/src/ATen/native/cuda/Activation.cu
@@ -0,0 +1,58 @@
#include "ATen/ATen.h"
#include "ATen/NativeFunctions.h"
#include "ATen/Dispatch.h"
#include "ATen/cuda/CUDAApplyUtils.cuh"
#include "ATen/cuda/CUDATensorMethods.cuh"
#include "ATen/cuda/CUDATypeConversion.cuh"
#include "THCUNN/THCHalfAutoNumerics.cuh"

namespace at { namespace native {

template <typename scalar_t>
void hardshrink_cuda_kernel(const Tensor& self, Tensor& out_tensor, scalar_t* lambd) {
at::cuda::CUDA_tensor_apply2<scalar_t, scalar_t>(
self,
out_tensor,
[lambd] __device__ (
scalar_t& self_val,
scalar_t& out_tensor_val,
bool early_exit) {
out_tensor_val = (self_val >= -*lambd && self_val <= *lambd) ? scalar_cast<scalar_t>(0) : self_val;
});
}

template <typename scalar_t>
void hardshrink_backward_cuda_kernel(Tensor& out_tensor, scalar_t* lambd, const Tensor& self, const Tensor& grad) {
at::cuda::CUDA_tensor_apply3<scalar_t, scalar_t, scalar_t>(
self,
grad,
out_tensor,
[lambd] __device__ (
scalar_t& self_val,
scalar_t& grad_val,
scalar_t& out_tensor_val) {
out_tensor_val = (self_val >= -*lambd && self_val <= *lambd) ? scalar_cast<scalar_t>(0) : grad_val;
});
}

Tensor hardshrink_cuda(const Tensor & self, Scalar lambd) {
auto lambd_tensor = lambd.toTensor().toType(self.type().scalarType()).toBackend(self.is_cuda() ? Backend::CUDA : Backend::CPU);
auto out_tensor = at::empty_like(self);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.type(), "hardshrink_cuda", [&] {
using cuda_scalar_t = cuda::into_type<scalar_t>;
hardshrink_cuda_kernel<cuda_scalar_t>(self, out_tensor, lambd_tensor.data<cuda_scalar_t>());
});
return out_tensor;
}

Tensor hardshrink_backward_cuda(const Tensor & grad, const Tensor & self, Scalar lambd) {
auto lambd_tensor = lambd.toTensor().toType(self.type().scalarType()).toBackend(self.is_cuda() ? Backend::CUDA : Backend::CPU);
auto out_tensor = at::empty_like(grad);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.type(), "hardshrink_backward_cuda", [&] {
using cuda_scalar_t = cuda::into_type<scalar_t>;
hardshrink_backward_cuda_kernel<cuda_scalar_t>(out_tensor, lambd_tensor.data<cuda_scalar_t>(), self, grad);
});
return out_tensor;
}

}} // namespace at::native
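
The CUDA kernels above are dispatched for float, double, and half. A hedged usage sketch follows, assuming a CUDA-capable build; it is not part of the PR's test suite.

import torch

x = torch.tensor([1.0, 0.5, 0.3, 0.6])
if torch.cuda.is_available():
    # Dispatches to hardshrink_cuda; half precision is covered by
    # AT_DISPATCH_FLOATING_TYPES_AND_HALF above.
    y = torch.hardshrink(x.cuda().half(), 0.3)
    print(y.float().cpu())  # tensor([1.0, 0.5, 0.0, 0.6])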
12 changes: 12 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -845,6 +845,18 @@

- func: relu_(Tensor self) -> Tensor

# TODO: there is a bug in float values of Scalar, after bug fixes, set the default value of lambd=0.5
# for now, getting around this with default values defined in torch/nn/functional.py
- func: hardshrink(Tensor self, Scalar lambd) -> Tensor
dispatch:
CPU: hardshrink_cpu
CUDA: hardshrink_cuda

- func: hardshrink_backward(Tensor grad_out, Tensor self, Scalar lambd) -> Tensor
dispatch:
CPU: hardshrink_backward_cpu
CUDA: hardshrink_backward_cuda

- func: rsqrt(Tensor self) -> Tensor

- func: rsqrt_(Tensor self) -> Tensor
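As the TODO in this file notes, a Scalar bug currently prevents float default values here, so lambd has no default in native_functions.yaml and the 0.5 default is supplied by the Python wrapper in torch/nn/functional.py instead. A sketch of what that means for callers, consistent with the test_torch.py change further down:

import torch
import torch.nn.functional as F

x = torch.tensor([[1.0, 0.5], [0.3, 0.6]])
F.hardshrink(x)            # OK: lambd defaults to 0.5 in torch/nn/functional.py
torch.hardshrink(x, 0.5)   # OK: explicit lambd
# x.hardshrink()           # TypeError until the Scalar default bug is fixed
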
5 changes: 0 additions & 5 deletions aten/src/ATen/nn.yaml
@@ -72,11 +72,6 @@
scalar_check:
output: 'false'

- name: hardshrink(Tensor self, Scalar lambd=0.5)
cname: HardShrink
scalar_check:
output: self_->isScalar()

- name: hardtanh(Tensor self, Scalar min_val=-1, Scalar max_val=1)
cname: HardTanh
has_inplace: True
13 changes: 0 additions & 13 deletions aten/src/THNN/generic/THNN.h
@@ -134,19 +134,6 @@ TH_API void THNN_(GatedLinear_updateGradInput)(
THTensor *gradInput, // [OUT] gradient w.r.t input
int dim); // dimension for halving operation

// HardShink outputs 0 on interval of (-lambda; lambda) or original value otherwise.
TH_API void THNN_(HardShrink_updateOutput)(
THNNState *state, // library's state
THTensor *input, // input tensor
THTensor *output, // [OUT] output tensor
accreal lambda); // HardShrink parameter
TH_API void THNN_(HardShrink_updateGradInput)(
THNNState *state, // library's state
THTensor *input, // input tensor
THTensor *gradOutput, // gradient w.r.t. module's output
THTensor *gradInput, // [OUT] gradient w.r.t. input
accreal lambda); // HardShrink parameter

// HardTanh clamps the values to the interval [min_val; max_val].
TH_API void THNN_(HardTanh_updateOutput)(
THNNState *state, // library's state
3 changes: 0 additions & 3 deletions aten/src/THNN/init.cpp
@@ -89,9 +89,6 @@
#include "generic/ELU.c"
#include "THGenerateFloatTypes.h"

#include "generic/HardShrink.c"
#include "THGenerateFloatTypes.h"

#include "generic/HardTanh.c"
#include "THGenerateFloatTypes.h"

7 changes: 7 additions & 0 deletions test/test_legacy_nn.py
@@ -618,13 +618,20 @@ def add_test(test):
test_params = deepcopy(test_params)
name = test_params.pop('module_name')
name = name_remap.get(name, name)
# hardshrink is deprecated in nn
if name == "HardShrink":
continue

test_params['constructor'] = getattr(nn, name)
test = OldModuleTest(**test_params)
add_test(test)
for test_params in criterion_tests:
test_params = deepcopy(test_params)
name = test_params.pop('module_name')
name = name_remap.get(name, name.replace('Loss', 'Criterion'))
# hardshrink is deprecated in nn
if name == "HardShrink":
continue

# nn.NLLLoss2d is deprecated, but there is a NLLLoss test for 2d
if name == 'ClassNLLCriterion' and 'desc' in test_params.keys() and '2d' in test_params['desc']:
29 changes: 13 additions & 16 deletions test/test_nn.py
@@ -304,8 +304,7 @@ def _do_test(self, test_case, module, input):
for p in module.parameters():
test_case.assertIsInstance(p, torch.DoubleTensor)

# TODO: Hardshrink is lacking a CUDA implementation
if TEST_CUDA and self.should_test_cuda and type(module) != nn.Hardshrink:
if TEST_CUDA and self.should_test_cuda:
# check that cuda() moves module parameters to correct GPU device,
# and that float() casts parameters correctly

@@ -363,7 +362,7 @@ def _do_test(self, test_case, module, input):
input = input.half().cuda()
module.half().cuda()
module(input)
for o in module.parameters():
for p in module.parameters():
test_case.assertIsInstance(p, torch.cuda.HalfTensor)
test_case.assertEqual(p.get_device(), 0)

@@ -5523,19 +5522,17 @@ def add(test_name, fn):

test_name = test.get_name()
add(test_name, lambda self, test=test: test(self))
# Hardshrink is not implemented in CUDA, so we must not test it.
if not test_name.startswith("test_Hardshrink"):
cuda_test_name = test_name + '_cuda'
# With dtype enable, it's good enough to test against three floating types
if 'dtype' in get_function_arglist(test.test_cuda):
add(cuda_test_name + '_float', lambda self,
test=test: test.test_cuda(self, dtype=torch.float))
add(cuda_test_name + '_double', lambda self,
test=test: test.test_cuda(self, dtype=torch.double))
add(cuda_test_name + '_half', lambda self,
test=test: test.test_cuda(self, dtype=torch.half))
else:
add(cuda_test_name, lambda self, test=test: test.test_cuda(self))
cuda_test_name = test_name + '_cuda'
# With dtype enable, it's good enough to test against three floating types
if 'dtype' in get_function_arglist(test.test_cuda):
add(cuda_test_name + '_float', lambda self,
test=test: test.test_cuda(self, dtype=torch.float))
add(cuda_test_name + '_double', lambda self,
test=test: test.test_cuda(self, dtype=torch.double))
add(cuda_test_name + '_half', lambda self,
test=test: test.test_cuda(self, dtype=torch.half))
else:
add(cuda_test_name, lambda self, test=test: test.test_cuda(self))


def wrap_functional(fn, **kwargs):
18 changes: 18 additions & 0 deletions test/test_torch.py
@@ -5547,6 +5547,24 @@ def test_abs(self):
res = torch.LongTensor((-bignumber,))
self.assertGreater(res.abs()[0], 0)

def test_hardshrink(self):
data_original = torch.tensor([1, 0.5, 0.3, 0.6]).view(2, 2)
float_types = [
'torch.DoubleTensor',
'torch.FloatTensor'
]
for t in float_types:
data = data_original.type(t)
self.assertEqual(torch.tensor([1, 0.5, 0, 0.6]).view(2, 2), torch.nn.Hardshrink(0.3)(data))
self.assertEqual(torch.tensor([1, 0, 0, 0.6]).view(2, 2), torch.nn.Hardshrink(0.5)(data))
self.assertEqual(torch.tensor([1, 0, 0, 0.6]).view(2, 2), torch.nn.Hardshrink()(data))

# test non-contiguous case
self.assertEqual(torch.tensor([1, 0.3, 0.5, 0.6]).view(2, 2), torch.nn.Hardshrink(0.1)(data.t()))

# not supporting default lambd value for torch.hardshrink() due to a Scalar bug
self.assertRaises(TypeError, lambda: data.hardshrink())

def test_unbiased(self):
tensor = torch.randn(100)
self.assertEqual(tensor.var(0), tensor.var(0, unbiased=True))
4 changes: 4 additions & 0 deletions test/test_utils.py
@@ -491,9 +491,13 @@ def init(cls):
long_size = 8 if sys.platform == 'win32' else None
tests = load_lua(path, long_size=long_size)
for name, test in tests['modules'].items():
if name == "HardShrink":
continue
test_name = 'test_' + name.replace('nn.', '')
setattr(cls, test_name, cls._module_test(name, test))
for name, test in tests['criterions'].items():
if name == "HardShrink":
continue
test_name = 'test_' + name.replace('nn.', '')
setattr(cls, test_name, cls._criterion_test(name, test))

10 changes: 5 additions & 5 deletions tools/autograd/derivatives.yaml
@@ -752,9 +752,13 @@
- name: glu_forward(Tensor self, int64_t dim)
self: glu_backward(grad, self, dim)

- name: hardshrink_forward(Tensor self, Scalar lambd)
- name: hardshrink(Tensor self, Scalar lambd)
self: hardshrink_backward(grad, self, lambd)

- name: hardshrink_backward(Tensor grad_out, Tensor self, Scalar lambd)
grad_out: hardshrink_backward(grad, self, lambd)
self: zeros_like(grad)

- name: hardtanh_forward(Tensor self, Scalar min_val, Scalar max_val)
self: hardtanh_backward(grad, self, min_val, max_val)

@@ -952,10 +956,6 @@
grad_output: glu_double_backward_grad_output(grad, self, dim)
self: glu_double_backward(grad, grad_output, self, dim)

- name: hardshrink_backward(Tensor grad_output, Tensor self, Scalar lambd)
grad_output: hardshrink_backward(grad, self, lambd)
self: zeros_like(grad)

- name: hardtanh_backward(Tensor grad_output, Tensor self, Scalar min_val, Scalar max_val)
grad_output: hardtanh_backward(grad, self, min_val, max_val)
self: zeros_like(grad)
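The derivative entries above follow from the definition: hardshrink is the identity outside [-lambda, lambda] and zero inside, so its derivative is a 0/1 mask. In LaTeX:

\operatorname{hardshrink}_{\lambda}(x) = \begin{cases} x, & |x| > \lambda \\ 0, & |x| \le \lambda \end{cases}
\qquad
\frac{d}{dx}\,\operatorname{hardshrink}_{\lambda}(x) = \begin{cases} 1, & |x| > \lambda \\ 0, & |x| \le \lambda \end{cases}

Hence hardshrink_backward applies that mask to grad, and because the mask does not depend on self almost everywhere, the double-backward reuses the same masking for grad_out and returns zeros_like(grad) for self.
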
13 changes: 8 additions & 5 deletions torch/nn/functional.py
@@ -777,13 +777,16 @@ def rrelu(input, lower=1. / 8, upper=1. / 3, training=False, inplace=False):
See :class:`~torch.nn.LogSigmoid` for more details.
""")

hardshrink = _add_docstr(torch._C._nn.hardshrink, r"""
hardshrink(input, lambd=0.5) -> Tensor

Applies the hard shrinkage function element-wise
def hardshrink(input, lambd=0.5):
r"""
hardshrink(input, lambd=0.5) -> Tensor

See :class:`~torch.nn.Hardshrink` for more details.
""")
Applies the hard shrinkage function element-wise

See :class:`~torch.nn.Hardshrink` for more details.
"""
return torch.hardshrink(input, lambd)


def tanhshrink(input):
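
An end-to-end sketch mirroring the new test_hardshrink case in test_torch.py, covering the module wrapper, the functional default, and a non-contiguous input:

import torch
import torch.nn as nn
import torch.nn.functional as F

data = torch.tensor([[1.0, 0.5], [0.3, 0.6]])
print(nn.Hardshrink(0.3)(data))      # tensor([[1.0, 0.5], [0.0, 0.6]])
print(F.hardshrink(data))            # default lambd=0.5 -> tensor([[1.0, 0.0], [0.0, 0.6]])
print(nn.Hardshrink(0.1)(data.t()))  # non-contiguous (transposed) input works too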