Commit 75a4862

Added SiLU activation function (#41034)
Summary: Implemented the SiLU activation function as discussed in #3169.

Pull Request resolved: #41034
Reviewed By: glaringlee
Differential Revision: D22465203
Pulled By: heitorschueroff
fbshipit-source-id: b27d064529fc99600c586ad49b594b52b718b0d2
1 parent: f6eb92a · commit: 75a4862

File tree

11 files changed: +124 -2 lines changed

aten/src/ATen/core/aten_interned_strings.h

Lines changed: 1 addition & 0 deletions
@@ -607,6 +607,7 @@ _(aten, selu) \
 _(aten, set) \
 _(aten, sigmoid) \
 _(aten, sign) \
+_(aten, silu) \
 _(aten, sin) \
 _(aten, sinh) \
 _(aten, size) \

aten/src/ATen/native/Activation.cpp

Lines changed: 16 additions & 0 deletions
@@ -190,6 +190,22 @@ Tensor & celu_(Tensor & self, Scalar alpha) {
   return at::elu_(self, alpha, Scalar(1.0), Scalar(inv_alpha));
 }
 
+Tensor silu(const Tensor& self) {
+  return self * at::sigmoid(self);
+}
+
+Tensor& silu_(Tensor& self) {
+  return self.mul_(at::sigmoid(self));
+}
+
+Tensor& silu_out(Tensor& result, const Tensor& self) {
+  return at::mul_out(result, self, at::sigmoid(self));
+}
+
+Tensor silu_backward(const Tensor& grad, const Tensor& self) {
+  auto self_sigmoid = at::sigmoid(self);
+  return grad * (self_sigmoid * (1 + self * (1 - self_sigmoid)));
+}
 
 template <typename scalar_t>
 inline void _rrelu_with_noise_train(
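
The backward formula follows from the product rule: d/dx [x * sigmoid(x)] = sigmoid(x) + x * sigmoid(x) * (1 - sigmoid(x)) = sigmoid(x) * (1 + x * (1 - sigmoid(x))), which is the factored form used in silu_backward above. A minimal Python sketch (variable names illustrative, assuming a build that includes this commit) checking the analytic gradient against autograd:

import torch

x = torch.randn(8, dtype=torch.double, requires_grad=True)
grad_out = torch.ones(8, dtype=torch.double)

# Gradient computed by autograd through the new op.
(auto_grad,) = torch.autograd.grad(torch.nn.functional.silu(x), x, grad_out)

# Analytic gradient in the same factored form as silu_backward.
xd = x.detach()
s = torch.sigmoid(xd)
manual_grad = grad_out * (s * (1 + xd * (1 - s)))

print(torch.allclose(auto_grad, manual_grad))  # expected: True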

aten/src/ATen/native/native_functions.yaml

Lines changed: 14 additions & 0 deletions
@@ -2428,6 +2428,20 @@
 
 - func: celu_(Tensor(a!) self, Scalar alpha=1.0) -> Tensor(a!)
 
+- func: silu(Tensor self) -> Tensor
+  use_c10_dispatcher: full
+  python_module: nn
+
+- func: silu_(Tensor(a!) self) -> Tensor(a!)
+  python_module: nn
+
+- func: silu.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: nn
+
+- func: silu_backward(Tensor grad_output, Tensor self) -> Tensor
+  use_c10_dispatcher: full
+  python_module: nn
+
 - func: sigmoid(Tensor self) -> Tensor
   use_c10_dispatcher: full
   variants: function, method
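
Because each schema carries python_module: nn, the generated Python bindings land under torch._C._nn, which is what the functional wrapper added below calls into. A small sketch (assuming a build with these schemas registered) exercising the binding directly against the definition x * sigmoid(x):

import torch

x = torch.randn(4)

# Binding generated from the silu schema above; torch.nn.functional.silu
# dispatches to this same entry point.
y = torch._C._nn.silu(x)

print(torch.allclose(y, x * torch.sigmoid(x)))  # expected: True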

docs/source/nn.functional.rst

Lines changed: 5 additions & 0 deletions
@@ -278,6 +278,11 @@ Non-linear activation functions
 
 .. autofunction:: hardsigmoid
 
+:hidden:`silu`
+~~~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: silu
+
 
 Normalization functions
 -----------------------

docs/source/nn.rst

Lines changed: 1 addition & 0 deletions
@@ -125,6 +125,7 @@ Non-linear Activations (weighted sum, nonlinearity)
     nn.CELU
     nn.GELU
     nn.Sigmoid
+    nn.SiLU
     nn.Softplus
     nn.Softshrink
     nn.Softsign

test/test_torch.py

Lines changed: 17 additions & 0 deletions
@@ -16577,6 +16577,23 @@ def test_hardsigmoid(self, device, dtype):
                          torch.tensor(expectedOutput, dtype=dtype, device=device),
                          atol=precision_4dps, rtol=0)
 
+    @dtypes(torch.float, torch.double)
+    def test_silu(self, device, dtype):
+        inputValues = [-1000, -1, 0, 0.5, 1, 2, 1000]
+        expectedOutput = [0.0000, -0.2689, 0, 0.3112, 0.7312, 1.7616, 1000]
+        precision_4dps = 0.0002
+
+        input_tensor = torch.tensor(inputValues, dtype=dtype, device=device)
+        expected_output_tensor = torch.tensor(expectedOutput, dtype=dtype, device=device)
+
+        self.assertEqual(torch.nn.functional.silu(input_tensor),
+                         expected_output_tensor,
+                         atol=precision_4dps, rtol=0)
+
+        self.assertEqual(torch.nn.functional.silu(input_tensor, inplace=True),
+                         expected_output_tensor,
+                         atol=precision_4dps, rtol=0)
+
     @onlyCPU
     @dtypes(torch.float)
     def test_diag_embed(self, device, dtype):
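
The expected outputs are just x * sigmoid(x) evaluated at the listed inputs: for example silu(-1) = -1 / (1 + e), roughly -0.2689, while the saturating ends give roughly 0 at -1000 and roughly 1000 at 1000. A quick sketch (illustrative, needing only a stock torch install) reproducing the reference values the test compares against:

import torch

inputs = torch.tensor([-1000., -1, 0, 0.5, 1, 2, 1000], dtype=torch.double)

# Reference definition of SiLU; agrees with expectedOutput above to within
# the test's tolerance of 2e-4.
reference = inputs * torch.sigmoid(inputs)
print(reference)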

tools/autograd/derivatives.yaml

Lines changed: 6 additions & 0 deletions
@@ -1181,6 +1181,12 @@
 - name: relu_(Tensor(a!) self) -> Tensor(a!)
   self: threshold_backward(grad, result, 0)
 
+- name: silu(Tensor self) -> Tensor
+  self: silu_backward(grad, self)
+
+- name: silu_(Tensor(a!) self) -> Tensor(a!)
+  self: not_implemented("silu_ cannot compute gradient of inplace version, use silu instead")
+
 - name: elu(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> Tensor
   self: elu_backward(grad, alpha, scale, input_scale, result)
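
The first entry routes autograd through silu_backward, while the in-place variant is deliberately marked not_implemented, so backpropagating through silu_ is expected to fail loudly rather than produce wrong gradients. A brief sketch of the intended behavior (assuming a build with this commit; the exact exception type and message may differ):

import torch

x = torch.randn(3, requires_grad=True)

# Out-of-place silu: differentiable via silu_backward.
torch.nn.functional.silu(x).sum().backward()
print(x.grad is not None)  # expected: True

# In-place silu_: its derivative is declared not_implemented above, so the
# backward pass through it should raise instead of silently succeeding.
y = x.clone()  # non-leaf tensor, safe target for an in-place op
try:
    torch.nn.functional.silu(y, inplace=True).sum().backward()
except (RuntimeError, NotImplementedError) as err:
    print("backward through silu_ raised:", err)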

torch/_overrides.py

Lines changed: 1 addition & 0 deletions
@@ -528,6 +528,7 @@ def get_testing_overrides():
         torch.nn.functional.relu6: lambda input, inplace=False: -1,
         torch.nn.functional.rrelu: lambda input, lower=0.125, upper=0.3333333333333333, training=False, inplace=False: -1,
         torch.nn.functional.selu: lambda input, inplace=False: -1,
+        torch.nn.functional.silu: lambda input, inplace=False: -1,
         torch.nn.functional.smooth_l1_loss: lambda input, target, size_average=None, reduce=None, reduction='mean': -1,
         torch.nn.functional.soft_margin_loss: lambda input, target, size_average=None, reduce=None, reduction='mean': -1,
         torch.nn.functional.softmax: lambda input, dim=None, _stacklevel=3, dtype=None: -1,

torch/nn/functional.py

Lines changed: 23 additions & 0 deletions
@@ -1700,6 +1700,29 @@ def bilinear(input1, input2, weight, bias=None):
     """
     return torch.bilinear(input1, input2, weight, bias)
 
+def silu(input, inplace=False):
+    # type: (Tensor, bool) -> Tensor
+    r"""Applies the silu function, element-wise.
+
+    .. math::
+        \text{silu}(x) = x * \sigma(x), \text{where } \sigma(x) \text{ is the logistic sigmoid.}
+
+    .. note::
+        See `Gaussian Error Linear Units (GELUs) <https://arxiv.org/abs/1606.08415>`_
+        where the SiLU (Sigmoid Linear Unit) was originally coined, and see
+        `Sigmoid-Weighted Linear Units for Neural Network Function Approximation
+        in Reinforcement Learning <https://arxiv.org/abs/1702.03118>`_ and `Swish:
+        a Self-Gated Activation Function <https://arxiv.org/abs/1710.05941v1>`_
+        where the SiLU was experimented with later.
+
+    See :class:`~torch.nn.SiLU` for more details.
+    """
+    if not torch.jit.is_scripting():
+        if type(input) is not Tensor and has_torch_function((input,)):
+            return handle_torch_function(silu, (input,), input, inplace=inplace)
+    if inplace:
+        return torch._C._nn.silu_(input)
+    return torch._C._nn.silu(input)
 
 def hardswish(input, inplace=False):
     # type: (Tensor, bool) -> Tensor
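
Typical usage of the new functional interface, as a small sketch (tensor shapes illustrative, assuming a build that includes this commit):

import torch
import torch.nn.functional as F

x = torch.randn(2, 3)

y = F.silu(x)              # out-of-place, equivalent to x * torch.sigmoid(x)
F.silu(x, inplace=True)    # in-place, overwrites x with silu(x)

print(torch.allclose(y, x))  # expected: True, x now holds the same values as y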

torch/nn/modules/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -5,7 +5,7 @@
 from .activation import Threshold, ReLU, Hardtanh, ReLU6, Sigmoid, Tanh, \
     Softmax, Softmax2d, LogSoftmax, ELU, SELU, CELU, GELU, Hardshrink, LeakyReLU, LogSigmoid, \
     Softplus, Softshrink, MultiheadAttention, PReLU, Softsign, Softmin, Tanhshrink, RReLU, GLU, \
-    Hardsigmoid, Hardswish
+    Hardsigmoid, Hardswish, SiLU
 from .loss import L1Loss, NLLLoss, KLDivLoss, MSELoss, BCELoss, BCEWithLogitsLoss, NLLLoss2d, \
     CosineEmbeddingLoss, CTCLoss, HingeEmbeddingLoss, MarginRankingLoss, \
     MultiLabelMarginLoss, MultiLabelSoftMarginLoss, MultiMarginLoss, \
@@ -54,5 +54,5 @@
     'ConstantPad3d', 'Bilinear', 'CosineSimilarity', 'Unfold', 'Fold',
     'AdaptiveLogSoftmaxWithLoss', 'TransformerEncoder', 'TransformerDecoder',
     'TransformerEncoderLayer', 'TransformerDecoderLayer', 'Transformer',
-    'Flatten', 'Hardsigmoid', 'Hardswish',
+    'Flatten', 'Hardsigmoid', 'Hardswish', 'SiLU',
 ]
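
With SiLU re-exported from torch.nn.modules, the module form can be dropped into a model like any other activation layer. A brief sketch (layer sizes illustrative, assuming a build that includes this commit):

import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(16, 32),
    nn.SiLU(),          # module wrapper around the new functional silu
    nn.Linear(32, 4),
)

x = torch.randn(8, 16)
print(model(x).shape)  # expected: torch.Size([8, 4])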
