Commit baeb0b8

xiaomengy authored and facebook-github-bot committed
Add gelu activation in pytorch (#20665)
Summary: Pull Request resolved: #20665

Add gelu activation forward on CPU in PyTorch. Compared to the current Python implementation of gelu used in the BERT model, e.g.

    def gelu(self, x):
        return x * 0.5 * (1.0 + torch.erf(x / self.sqrt_two))

torch.nn.functional.gelu reduces the forward time from 333ms to 109ms (with MKL) / 112ms (without MKL) for input size [64, 128, 56, 56] on a devvm.

Reviewed By: zheng-xq

Differential Revision: D15400974

fbshipit-source-id: 78399123aef803376a2459d487d44557126070ac
Parent: aac424a · Commit: baeb0b8
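The comparison in the summary can be sanity-checked with a small timing script. This is a minimal sketch, not part of the PR: the gelu_python helper, the warm-up call, and the 10-iteration timing loop are illustrative choices, F.gelu assumes a PyTorch build that already includes this change, and absolute numbers depend on the machine (the quoted figures came from a devvm).

    import math
    import time

    import torch
    import torch.nn.functional as F

    def gelu_python(x):
        # erf-based formulation quoted in the commit summary
        return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))

    x = torch.rand(64, 128, 56, 56)

    for name, fn in [("python erf gelu", gelu_python), ("F.gelu", F.gelu)]:
        fn(x)  # warm up
        start = time.time()
        for _ in range(10):
            fn(x)
        print(name, "avg forward:", (time.time() - start) / 10, "s")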

File tree

10 files changed: +193, -17 lines

10 files changed

+193
-17
lines changed

aten/src/ATen/native/Activation.cpp

Lines changed: 16 additions & 0 deletions
@@ -371,4 +371,20 @@ Tensor hardshrink_backward_cpu(const Tensor & grad, const Tensor & self, Scalar
   return out_tensor;
 }
 
+
+Tensor gelu_cpu(const Tensor& self) {
+  const auto X = self.contiguous();
+  Tensor Y = at::native::empty_like(X);
+  GeluKernel(kCPU, X, &Y);
+  return Y;
+}
+
+Tensor gelu_cuda(const Tensor& self) {
+  Tensor Y = at::native::empty_like(self);
+  GeluKernel(kCUDA, self, &Y);
+  return Y;
+}
+
+DEFINE_DISPATCH(GeluKernel);
+
 }} // namespace at::native

aten/src/ATen/native/Activation.h

Lines changed: 11 additions & 5 deletions
@@ -1,15 +1,21 @@
 #pragma once
 
-#include <c10/core/Scalar.h>
+#include <ATen/ATen.h>
 #include <ATen/native/DispatchStub.h>
+#include <c10/core/Scalar.h>
+
+namespace at {
 
-namespace at { struct TensorIterator; }
+struct TensorIterator;
 
-namespace at { namespace native {
+namespace native {
 
-using threshold_fn = void(*)(TensorIterator&, Scalar, Scalar);
+using threshold_fn = void (*)(TensorIterator&, Scalar, Scalar);
+using activation_fn = void (*)(const Tensor& /* X */, Tensor* /* Y */);
 
 DECLARE_DISPATCH(threshold_fn, threshold_stub);
+DECLARE_DISPATCH(activation_fn, GeluKernel);
 
+} // namespace native
 
-}} // namespace at::native
+} // namespace at
aten/src/ATen/native/cpu/Activation.cpp

Lines changed: 84 additions & 11 deletions
@@ -1,31 +1,104 @@
+#define _USE_MATH_DEFINES
+
 #include <ATen/native/Activation.h>
 
+#include <math.h>
+
 #include <ATen/ATen.h>
+#include <ATen/Config.h>
 #include <ATen/cpu/vec256/vec256.h>
 #include <ATen/native/TensorIterator.h>
 #include <ATen/native/cpu/Loops.h>
 
-namespace at { namespace native {
+#if AT_MKL_ENABLED()
+#include <mkl.h>
+#endif // AT_MKL_ENABLED()
+
+namespace at {
+namespace native {
+
 namespace {
 
-static void threshold_kernel(TensorIterator& iter, Scalar threshold_scalar, Scalar value_scalar) {
+static void threshold_kernel(
+    TensorIterator& iter,
+    Scalar threshold_scalar,
+    Scalar value_scalar) {
   AT_DISPATCH_ALL_TYPES(iter.dtype(), "threshold_cpu", [&] {
     using Vec = Vec256<scalar_t>;
     scalar_t threshold = threshold_scalar.to<scalar_t>();
     scalar_t value = value_scalar.to<scalar_t>();
     binary_kernel_vec(
-      iter,
-      [&](scalar_t x, scalar_t other) -> scalar_t {
-        return x <= threshold ? value : other;
-      },
-      [&](Vec x, Vec other) -> Vec {
-        return Vec::blendv(other, Vec(value), x <= Vec(threshold));
-      });
+        iter,
+        [&](scalar_t x, scalar_t other) -> scalar_t {
+          return x <= threshold ? value : other;
+        },
+        [&](Vec x, Vec other) -> Vec {
+          return Vec::blendv(other, Vec(value), x <= Vec(threshold));
+        });
   });
 }
 
-} // anonymous namespace
+#if AT_MKL_ENABLED()
+
+// TODO(yangxm): Consider to use TensorIterator here.
+template <typename T>
+void GeluKernelMKLImpl(const Tensor& X, Tensor* Y);
+
+#define DELEGATE_GELU_KERNEL_MKL_IMPL(T, CdfNormFunc, MulFunc) \
+  template <>                                                  \
+  void GeluKernelMKLImpl<T>(const Tensor& X, Tensor* Y) {      \
+    const int64_t N = X.numel();                               \
+    const T* X_data = X.data<T>();                             \
+    T* Y_data = Y->data<T>();                                  \
+    CdfNormFunc(N, X_data, Y_data);                            \
+    MulFunc(N, X_data, Y_data, Y_data);                        \
+  }
+DELEGATE_GELU_KERNEL_MKL_IMPL(float, vsCdfNorm, vsMul)
+DELEGATE_GELU_KERNEL_MKL_IMPL(double, vdCdfNorm, vdMul)
+#undef DELEGATE_GELU_KERNEL_MKL_IMPL
+
+#else // AT_MKL_ENABLED()
+
+template <typename T>
+void GeluKernelMKLImpl(const Tensor& X, Tensor* Y) {
+  AT_ASSERTM(false, "ATen not compiled with MKL");
+}
+
+#endif // AT_MKL_ENABLED()
+
+template <typename T>
+void GeluKernelImplInternal(const Tensor& X, Tensor* Y) {
+  const int64_t N = X.numel();
+  const T* X_data = X.data<T>();
+  T* Y_data = Y->data<T>();
+  for (int64_t i = 0; i < N; ++i) {
+    Y_data[i] = X_data[i] * M_SQRT1_2;
+  }
+  Y->erf_();
+  for (int64_t i = 0; i < N; ++i) {
+    Y_data[i] = (Y_data[i] + T(1)) * X_data[i] * T(0.5);
+  }
+}
+
+// TODO(yangxm): Add another fast kernel using formula
+// y = 0.5x * (1 + tanh(sqrt(2/Pi) * (x + 0.044715x^3)))
+// and the fast tanh impl from Eigen.
+void GeluKernelImpl(const Tensor& X, Tensor* Y) {
+  if (at::hasMKL()) {
+    AT_DISPATCH_FLOATING_TYPES(X.scalar_type(), "GeluKernelImpl", [&]() {
+      GeluKernelMKLImpl<scalar_t>(X, Y);
+    });
+  } else {
+    AT_DISPATCH_FLOATING_TYPES(X.scalar_type(), "GeluKernelImpl", [&]() {
+      GeluKernelImplInternal<scalar_t>(X, Y);
+    });
+  }
+}
+
+} // namespace
 
 REGISTER_DISPATCH(threshold_stub, &threshold_kernel);
+REGISTER_DISPATCH(GeluKernel, &GeluKernelImpl);
 
-}} // namespace at::native
+} // namespace native
+} // namespace at
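The non-MKL CPU path above computes the exact form y = 0.5 * x * (1 + erf(x / sqrt(2))), while the TODO mentions a faster tanh-based approximation. As a side note, here is a small NumPy/SciPy sketch (not part of the commit) comparing the two formulas; the sample grid and the printed max-difference check are illustrative choices.

    import numpy as np
    from scipy.special import erf

    def gelu_exact(x):
        # what GeluKernelImplInternal computes: 0.5 * x * (1 + erf(x / sqrt(2)))
        return 0.5 * x * (1.0 + erf(x / np.sqrt(2.0)))

    def gelu_tanh(x):
        # the approximation from the TODO comment above
        return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x ** 3)))

    x = np.linspace(-6.0, 6.0, 1001)
    print(np.abs(gelu_exact(x) - gelu_tanh(x)).max())  # small; the two curves nearly coincide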

aten/src/ATen/native/cuda/Activation.cu

Lines changed: 21 additions & 1 deletion
@@ -1,10 +1,12 @@
+#include <ATen/native/Activation.h>
+
 #include <ATen/ATen.h>
 #include <ATen/NativeFunctions.h>
 #include <ATen/Dispatch.h>
 #include <ATen/cuda/CUDAApplyUtils.cuh>
 #include <ATen/cuda/detail/IndexUtils.cuh>
-#include <ATen/native/Activation.h>
 #include <ATen/native/cuda/Loops.cuh>
+#include <c10/cuda/CUDAMathCompat.h>
 
 
 namespace at { namespace native {
@@ -291,6 +293,24 @@ static void threshold_kernel(TensorIterator& iter, Scalar threshold, Scalar valu
   });
 }
 
+namespace {
+
+template <typename T>
+void GeluCUDAKernelImplInternal(const Tensor& X, Tensor* Y) {
+  at::cuda::CUDA_tensor_apply2<T, T>(X, *Y, [] __device__(const T& x, T& y) {
+    y = x * c10::cuda::compat::normcdf(x);
+  });
+}
+
+void GeluCUDAKernelImpl(const Tensor& X, Tensor* Y) {
+  AT_DISPATCH_FLOATING_TYPES(X.scalar_type(), "GeluCUDAKernelImpl", [&]() {
+    GeluCUDAKernelImplInternal<scalar_t>(X, Y);
+  });
+}
+
+} // namespace
+
 REGISTER_DISPATCH(threshold_stub, &threshold_kernel);
+REGISTER_DISPATCH(GeluKernel, &GeluCUDAKernelImpl);
 
 }} // namespace at::native

aten/src/ATen/native/native_functions.yaml

Lines changed: 6 additions & 0 deletions
@@ -1587,6 +1587,12 @@
     CPU: prelu_backward_cpu
     CUDA: prelu_backward_cuda
 
+- func: gelu(Tensor self) -> Tensor
+  python_module: nn
+  dispatch:
+    CPU: gelu_cpu
+    CUDA: gelu_cuda
+
 - func: hardshrink(Tensor self, Scalar lambd=0.5) -> Tensor
   variants: function, method
   dispatch:

c10/cuda/CUDAMathCompat.h

Lines changed: 7 additions & 0 deletions
@@ -84,6 +84,13 @@ __MATH_FUNCTIONS_DECL__ double tan(double x) {
   return ::tan(x);
 }
 
+__MATH_FUNCTIONS_DECL__ float normcdf(float x) {
+  return ::normcdff(x);
+}
+__MATH_FUNCTIONS_DECL__ double normcdf(double x) {
+  return ::normcdf(x);
+}
+
 } // namespace compat
 } // namespace cuda
 } // namespace c10
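The CUDA kernel computes y = x * normcdf(x) through the wrappers added above, while the CPU fallback uses erf; the two agree because Phi(x) = 0.5 * (1 + erf(x / sqrt(2))). A quick SciPy sketch of that identity (not part of the commit; np.allclose with its default tolerance is an illustrative check):

    import numpy as np
    from scipy.special import erf
    from scipy.stats import norm

    x = np.linspace(-6.0, 6.0, 1001)
    cuda_style = x * norm.cdf(x)                         # x * normcdf(x), as in GeluCUDAKernelImplInternal
    cpu_style = 0.5 * x * (1.0 + erf(x / np.sqrt(2.0)))  # erf form, as in GeluKernelImplInternal
    assert np.allclose(cuda_style, cpu_style)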

docs/source/nn.rst

Lines changed: 5 additions & 0 deletions
@@ -1046,6 +1046,11 @@ Non-linear activation functions
 
 .. autofunction:: glu
 
+:hidden:`gelu`
+~~~~~~~~~~~~~~~
+
+.. autofunction:: gelu
+
 :hidden:`logsigmoid`
 ~~~~~~~~~~~~~~~~~~~~
 
test/test_nn.py

Lines changed: 27 additions & 0 deletions
@@ -6059,6 +6059,33 @@ def test_PReLU_backward_requires_grad_false(self):
         y.mean().backward()
         self.assertEqual(x.grad, None)
 
+    @unittest.skipIf(
+        not TEST_NUMPY or not TEST_SCIPY, "Numpy or Scipy not found")
+    def test_gelu(self):
+        def _test_gelu(n, m, dtype, contiguous):
+            def _gelu_ref(X):
+                return X * stats.norm.cdf(X)
+
+            if contiguous:
+                X = torch.rand(n, m, dtype=dtype)
+            else:
+                X = torch.rand(n, m, dtype=dtype)[:, ::2]
+            res = F.gelu(X)
+            ref = _gelu_ref(X.numpy())
+            self.assertEqual(res, ref)
+
+            if TEST_CUDA:
+                res_cuda = F.gelu(X.cuda())
+                self.assertEqual(res_cuda.cpu(), ref)
+
+        for n in range(1, 10):
+            for m in range(1, 10):
+                _test_gelu(n, m, torch.float32, True)
+                _test_gelu(n, m, torch.float32, False)
+                _test_gelu(n, m, torch.float64, True)
+                _test_gelu(n, m, torch.float64, False)
+
+
     def test_bce_loss_always_nonnegative(self):
         target = torch.ones(5)
         input = torch.ones(5)

tools/autograd/derivatives.yaml

Lines changed: 3 additions & 0 deletions
@@ -1034,6 +1034,9 @@
 - name: glu(Tensor self, int64_t dim)
   self: glu_backward(grad, self, dim)
 
+- name: gelu(Tensor self)
+  self: not_implemented("gelu")
+
 - name: hardshrink(Tensor self, Scalar lambd)
   self: hardshrink_backward(grad, self, lambd)
 
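Registering the derivative as not_implemented("gelu") means autograd can record the forward op but is expected to error out when a gradient through it is actually requested. The sketch below is an assumption about that behavior at this commit; the exact exception type and message are not taken from the diff.

    import torch
    import torch.nn.functional as F

    x = torch.randn(4, requires_grad=True)
    out = F.gelu(x).sum()
    try:
        out.backward()  # expected to fail: derivative registered as not_implemented
    except RuntimeError as err:
        print("gelu backward not available at this commit:", err)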

torch/nn/functional.py

Lines changed: 13 additions & 0 deletions
@@ -1148,6 +1148,19 @@ def rrelu(input, lower=1. / 8, upper=1. / 3, training=False, inplace=False):
 See :class:`~torch.nn.LogSigmoid` for more details.
 """)
 
+@weak_script
+def gelu(input):
+    r"""gelu(input) -> Tensor
+
+    Applies element-wise the function
+    :math:`\text{GeLU}(x) = x * \Phi(x)`
+
+    where `\Phi(x)` is the Cumulative Distribution Function for Gaussian Distribution.
+
+    See :`Gaussian Error Linear Units (GELUs) <https://arxiv.org/abs/1606.08415>`.
+    """
+    return torch._C._nn.gelu(input)
+
 
 @weak_script
 def hardshrink(input, lambd=0.5):
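A short usage sketch for the new functional (not part of the commit), checking the result against the x * Phi(x) definition from the docstring via SciPy, in the same spirit as the new test above:

    import torch
    import torch.nn.functional as F
    from scipy.stats import norm

    x = torch.randn(3, 5, dtype=torch.float64)
    y = F.gelu(x)
    ref = torch.from_numpy(x.numpy() * norm.cdf(x.numpy()))
    print(torch.allclose(y, ref))  # True up to floating-point tolerance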
