
Commit 2460dce

xiaomengy authored and facebook-github-bot committed
Add torch.nn.GELU for GELU activation (#28944)
Summary: Pull Request resolved: #28944

Add torch.nn.GELU for GELU activation

Test Plan: buck test mode/dev-nosan //caffe2/test:nn -- "GELU"

Reviewed By: hl475, houseroad

Differential Revision: D18240946

fbshipit-source-id: 6284b30def9bd4c12bf7fb2ed08b1b2f0310bb78
1 parent 3bffb73 commit 2460dce
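
For context, torch.nn.GELU is a module wrapper around the elementwise gelu op whose CPU kernels change below. A minimal usage sketch (shapes and variable names here are illustrative, not part of the PR):

```python
import torch
import torch.nn as nn

# nn.GELU applies GELU(x) = x * Phi(x) elementwise, where Phi is the standard
# normal CDF; it mirrors the existing functional form torch.nn.functional.gelu.
gelu = nn.GELU()
x = torch.randn(4, 8, requires_grad=True)
y = gelu(x)              # forward: dispatches to gelu_cpu on CPU tensors
y.sum().backward()       # backward: dispatches to gelu_backward_cpu
print(y.shape, x.grad.shape)
```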

23 files changed: +298 -184 lines changed


aten/src/ATen/core/aten_interned_strings.h

Lines changed: 1 addition & 0 deletions
@@ -337,6 +337,7 @@ _(aten, full) \
 _(aten, full_like) \
 _(aten, gather) \
 _(aten, ge) \
+_(aten, gelu) \
 _(aten, geometric) \
 _(aten, geqrf) \
 _(aten, ger) \

aten/src/ATen/native/Activation.cpp

Lines changed: 6 additions & 6 deletions
@@ -353,16 +353,16 @@ Tensor hardshrink_backward_cpu(const Tensor & grad, const Tensor & self, Scalar


 Tensor gelu_cpu(const Tensor& self) {
-  const auto X = self.contiguous();
-  Tensor Y = at::native::empty_like(X);
-  GeluKernel(kCPU, X, &Y);
+  Tensor Y = at::native::empty_like(self);
+  auto it = TensorIterator::unary_op(Y, self);
+  GeluKernel(kCPU, it);
   return Y;
 }

 Tensor gelu_backward_cpu(const Tensor& grad, const Tensor& self) {
-  const auto X = self.contiguous();
-  Tensor dX = at::native::empty_like(X);
-  GeluBackwardKernel(kCPU, grad.contiguous(), X, &dX);
+  Tensor dX = at::native::empty_like(self);
+  auto it = TensorIterator::binary_op(dX, grad, self);
+  GeluBackwardKernel(kCPU, it);
   return dX;
 }
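
As a quick sanity check of these new entry points (illustrative only; the PR's real coverage is the test_nn test referenced in the Test Plan above), the backward formula can be compared against finite differences through the public op:

```python
import torch
import torch.nn.functional as F

# gradcheck differentiates F.gelu numerically and compares the result with the
# analytic gradient produced by gelu_backward_cpu; double precision is required.
x = torch.randn(3, 5, dtype=torch.double, requires_grad=True)
assert torch.autograd.gradcheck(F.gelu, (x,))
```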

aten/src/ATen/native/Activation.h

Lines changed: 2 additions & 3 deletions
@@ -10,10 +10,9 @@ struct TensorIterator;

 namespace native {

+using activation_fn = void (*)(TensorIterator&);
+using activation_backward_fn = void (*)(TensorIterator&);
 using threshold_fn = void (*)(TensorIterator&, Scalar, Scalar);
-using activation_fn = void (*)(const Tensor& /* X */, Tensor* /* Y */);
-using activation_backward_fn =
-    void (*)(const Tensor& /* dY */, const Tensor& /* X */, Tensor* /* dX */);
 using hardshrink_cpu_fn = void (*)(TensorIterator&, Scalar);
 using hardshrink_backward_cpu_fn = void (*)(TensorIterator&, Scalar);

aten/src/ATen/native/cpu/Activation.cpp

Lines changed: 145 additions & 121 deletions
@@ -42,166 +42,188 @@ static void threshold_kernel(

 #if AT_MKL_ENABLED()

-// TODO(yangxm): Consider to use TensorIterator here.
 template <typename T>
-void GeluKernelMKLImpl(const Tensor& X, Tensor* Y);
-
-#define DELEGATE_GELU_KERNEL_MKL_IMPL(T, CdfNormFunc, MulFunc) \
-  template <> \
-  void GeluKernelMKLImpl<T>(const Tensor& X, Tensor* Y) { \
-    const int64_t N = X.numel(); \
-    const T* X_data = X.data_ptr<T>(); \
-    T* Y_data = Y->data_ptr<T>(); \
-    CdfNormFunc(N, X_data, Y_data); \
-    MulFunc(N, X_data, Y_data, Y_data); \
-  }
-DELEGATE_GELU_KERNEL_MKL_IMPL(float, vsCdfNorm, vsMul)
-DELEGATE_GELU_KERNEL_MKL_IMPL(double, vdCdfNorm, vdMul)
-#undef DELEGATE_GELU_KERNEL_MKL_IMPL
+void MKLCdfNorm(int64_t N, const T* X, T* Y);

-#else // AT_MKL_ENABLED()
+template <>
+void MKLCdfNorm<float>(int64_t N, const float* X, float* Y) {
+  vsCdfNorm(N, X, Y);
+}
+
+template <>
+void MKLCdfNorm<double>(int64_t N, const double* X, double* Y) {
+  vdCdfNorm(N, X, Y);
+}

 template <typename T>
-void GeluKernelMKLImpl(const Tensor& X, Tensor* Y) {
-  AT_ASSERTM(false, "ATen not compiled with MKL");
+void MKLMul(int64_t N, const T* A, const T* B, T* Y);
+
+template <>
+void MKLMul<float>(int64_t N, const float* A, const float* B, float* Y) {
+  vsMul(N, A, B, Y);
 }

-#endif // AT_MKL_ENABLED()
+template <>
+void MKLMul<double>(int64_t N, const double* A, const double* B, double* Y) {
+  vdMul(N, A, B, Y);
+}

 template <typename T>
-void GeluKernelImplInternal(const Tensor& X, Tensor* Y) {
-  const int64_t N = X.numel();
-  const T* X_data = X.data_ptr<T>();
-  T* Y_data = Y->data_ptr<T>();
-  for (int64_t i = 0; i < N; ++i) {
-    Y_data[i] = X_data[i] * M_SQRT1_2;
-  }
-  Y->erf_();
-  for (int64_t i = 0; i < N; ++i) {
-    Y_data[i] = (Y_data[i] + T(1)) * X_data[i] * T(0.5);
-  }
+void MKLExp(int64_t N, const T* X, T* Y);
+
+template <>
+void MKLExp<float>(int64_t N, const float* X, float* Y) {
+  vsExp(N, X, Y);
 }

-// TODO(yangxm): Add another fast kernel using formula
-// y = 0.5x * (1 + tanh(sqrt(2/Pi) * (x + 0.044715x^3)))
-// and the fast tanh impl from Eigen.
-void GeluKernelImpl(const Tensor& X, Tensor* Y) {
-  if (at::hasMKL()) {
-    AT_DISPATCH_FLOATING_TYPES(X.scalar_type(), "GeluKernelImpl", [&]() {
-      GeluKernelMKLImpl<scalar_t>(X, Y);
-    });
-  } else {
-    AT_DISPATCH_FLOATING_TYPES(X.scalar_type(), "GeluKernelImpl", [&]() {
-      GeluKernelImplInternal<scalar_t>(X, Y);
-    });
-  }
+template <>
+void MKLExp<double>(int64_t N, const double* X, double* Y) {
+  vdExp(N, X, Y);
 }

-#if AT_MKL_ENABLED()
+template <typename T>
+void GeluMKLKernelImpl(TensorIterator* it) {
+  if (!it->can_use_32bit_indexing()) {
+    for (auto& sub_it : it->with_32bit_indexing()) {
+      GeluMKLKernelImpl<T>(&sub_it);
+    }
+    return;
+  }
+  const int64_t N = it->numel();
+  const T* X_data = static_cast<T*>(it->data_ptr(1));
+  T* Y_data = static_cast<T*>(it->data_ptr(0));
+  MKLCdfNorm<T>(N, X_data, Y_data);
+  MKLMul<T>(N, X_data, Y_data, Y_data);
+}

 template <typename T>
-void GeluBackwardKernelMKLImpl(const Tensor& dY, const Tensor& X, Tensor* dX);
-
-// TODO(yangxm): Implement this by using template functions.
-#define DELEGATE_GELU_BACKWARD_KERNEL_MKL_IMPL(T, CdfNormFunc, ExpFunc) \
-  template <> \
-  void GeluBackwardKernelMKLImpl<T>( \
-      const Tensor& dY, const Tensor& X, Tensor* dX) { \
-    constexpr T kAlpha = M_2_SQRTPI * M_SQRT1_2 * T(0.5); \
-    Tensor scratch = at::native::empty_like(X); \
-    const int64_t N = X.numel(); \
-    const T* dY_data = dY.data_ptr<T>(); \
-    const T* X_data = X.data_ptr<T>(); \
-    T* dX_data = dX->data_ptr<T>(); \
-    T* scratch_data = scratch.data_ptr<T>(); \
-    CdfNormFunc(N, X_data, scratch_data); \
-    for (int64_t i = 0; i < N; ++i) { \
-      dX_data[i] = -T(0.5) * X_data[i] * X_data[i]; \
-    } \
-    ExpFunc(N, dX_data, dX_data); \
-    for (int64_t i = 0; i < N; ++i) { \
-      dX_data[i] = \
-          dY_data[i] * (scratch_data[i] + X_data[i] * dX_data[i] * kAlpha); \
-    } \
+void GeluBackwardMKLKernelImpl(TensorIterator* it) {
+  if (!it->can_use_32bit_indexing()) {
+    for (auto& sub_it : it->with_32bit_indexing()) {
+      GeluBackwardMKLKernelImpl<T>(&sub_it);
+    }
+    return;
+  }
+  constexpr T kBeta = M_2_SQRTPI * M_SQRT1_2 * T(0.5);
+  const int64_t N = it->numel();
+  const T* dY_data = static_cast<T*>(it->data_ptr(1));
+  const T* X_data = static_cast<T*>(it->data_ptr(2));
+  T* dX_data = static_cast<T*>(it->data_ptr(0));
+  Tensor cdf = at::empty({N}, it->input(1).options());
+  T* cdf_data = cdf.template data_ptr<T>();
+  MKLCdfNorm<T>(N, X_data, cdf_data);
+  for (int64_t i = 0; i < N; ++i) {
+    dX_data[i] = T(-0.5) * X_data[i] * X_data[i];
+  }
+  MKLExp(N, dX_data, dX_data);
+  for (int64_t i = 0; i < N; ++i) {
+    dX_data[i] = dY_data[i] * (cdf_data[i] + kBeta * X_data[i] * dX_data[i]);
   }
-DELEGATE_GELU_BACKWARD_KERNEL_MKL_IMPL(float, vsCdfNorm, vsExp)
-DELEGATE_GELU_BACKWARD_KERNEL_MKL_IMPL(double, vdCdfNorm, vdExp)
-#undef DELEGATE_GELU_BACKWARD_KERNEL_MKL_IMPL
+}

 #else // AT_MKL_ENABLED()

 template <typename T>
-void GeluBackwardKernelMKLImpl(const Tensor& dY, const Tensor& X, Tensor* dX) {
+void GeluMKLKernelImpl(TensorIterator* /* it */) {
+  AT_ASSERTM(false, "ATen not compiled with MKL");
+}
+
+template <typename T>
+void GeluBackwardMKLKernelImpl(TensorIterator* /* it */) {
   AT_ASSERTM(false, "ATen not compiled with MKL");
 }

 #endif // AT_MKL_ENABLED()

-template <typename T>
-void GeluBackwardKernelImplInternal(
-    const Tensor& dY,
-    const Tensor& X,
-    Tensor* dX) {
-  constexpr T kAlpha = M_2_SQRTPI * M_SQRT1_2 * T(0.5);
-  Tensor scratch = at::native::empty_like(X);
-  const int64_t N = X.numel();
-  const T* dY_data = dY.data_ptr<T>();
-  const T* X_data = X.data_ptr<T>();
-  T* dX_data = dX->data_ptr<T>();
-  T* scratch_data = scratch.data_ptr<T>();
-  for (int64_t i = 0; i < N; ++i) {
-    scratch_data[i] = X_data[i] * M_SQRT1_2;
-    dX_data[i] = -T(0.5) * X_data[i] * X_data[i];
-  }
-  // TODO(yangxm): Consider let forward pass preserve CdfNorm(X) in training
-  // pass to reduce this extra tensor.
-  scratch.erf_();
-  dX->exp_();
-  for (int64_t i = 0; i < N; ++i) {
-    dX_data[i] = dY_data[i] *
-        (T(0.5) * (T(1) + scratch_data[i]) + X_data[i] * dX_data[i] * kAlpha);
+// TODO(yangxm): Add another fast kernel using formula
+// y = 0.5x * (1 + tanh(sqrt(2/Pi) * (x + 0.044715x^3)))
+// and the fast tanh impl from Eigen.
+void GeluKernelImpl(TensorIterator& it) {
+  if (at::hasMKL() && it.is_contiguous()) {
+    AT_DISPATCH_FLOATING_TYPES(it.dtype(), "GeluKernelImpl", [&]() {
+      GeluMKLKernelImpl<scalar_t>(&it);
+    });
+  } else {
+    AT_DISPATCH_FLOATING_TYPES(it.dtype(), "GeluKernelImpl", [&]() {
+      using Vec = vec256::Vec256<scalar_t>;
+      const Vec kAlphaVec(M_SQRT1_2);
+      const Vec kOneVec(1);
+      const Vec kPointFiveVec(0.5);
+      cpu_kernel_vec(
+          it,
+          [](scalar_t x) {
+            constexpr scalar_t kAlpha = M_SQRT1_2;
+            return x * scalar_t(0.5) * (scalar_t(1) + std::erf(x * kAlpha));
+          },
+          [&](Vec x_vec) {
+            return x_vec * kPointFiveVec *
+                (kOneVec + (x_vec * kAlphaVec).erf());
+          });
+    });
   }
 }

-void GeluBackwardKernelImpl(const Tensor& dY, const Tensor& X, Tensor* dX) {
-  if (hasMKL()) {
-    AT_DISPATCH_FLOATING_TYPES(
-        X.scalar_type(), "GeluBackwardKernelImpl", [&]() {
-          GeluBackwardKernelMKLImpl<scalar_t>(dY, X, dX);
-        });
+void GeluBackwardKernelImpl(TensorIterator& it) {
+  if (hasMKL() && it.is_contiguous()) {
+    AT_DISPATCH_FLOATING_TYPES(it.dtype(), "GeluBackwardKernelImpl", [&]() {
+      GeluBackwardMKLKernelImpl<scalar_t>(&it);
+    });
   } else {
-    AT_DISPATCH_FLOATING_TYPES(
-        X.scalar_type(), "GeluBackwardKernelImpl", [&]() {
-          GeluBackwardKernelImplInternal<scalar_t>(dY, X, dX);
-        });
+    AT_DISPATCH_FLOATING_TYPES(it.dtype(), "GeluBackwardKernelImpl", [&]() {
+      using Vec = vec256::Vec256<scalar_t>;
+      const Vec kAlphaVec(M_SQRT1_2);
+      const Vec kBetaVec(M_2_SQRTPI * M_SQRT1_2 * 0.5);
+      const Vec kOneVec(1);
+      const Vec kPointFiveVec(0.5);
+      const Vec kMinusPointFiveVec(-0.5);
+      cpu_kernel_vec(
+          it,
+          [](scalar_t dy, scalar_t x) {
+            constexpr scalar_t kAlpha = M_SQRT1_2;
+            constexpr scalar_t kBeta = M_2_SQRTPI * M_SQRT1_2 * 0.5;
+            const scalar_t cdf =
+                scalar_t(0.5) * (scalar_t(1) + std::erf(x * kAlpha));
+            const scalar_t pdf = kBeta * std::exp(x * x * scalar_t(-0.5));
+            return dy * (cdf + x * pdf);
+          },
+          [&](Vec dy_vec, Vec x_vec) {
+            const Vec cdf_vec =
+                kPointFiveVec * (kOneVec + (x_vec * kAlphaVec).erf());
+            const Vec pdf_vec =
+                kBetaVec * (x_vec * x_vec * kMinusPointFiveVec).exp();
+            return dy_vec * (cdf_vec + x_vec * pdf_vec);
+          });
+    });
   }
 }

 void hardshrink_cpu_kernel(TensorIterator& iter, Scalar lambd) {
   AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "hardshrink_cpu", [&] {
     auto lambd_val = lambd.to<scalar_t>();
-    cpu_kernel_vec(iter,
-      [=](scalar_t self_val) {
-        return (self_val >= -lambd_val && self_val <= lambd_val) ? scalar_t(0) : self_val;
-      },
-      [=](Vec256<scalar_t> self_val) {
-        return ((self_val < -lambd_val) | (self_val > lambd_val)) & self_val;
-      }
-    );
+    cpu_kernel_vec(
+        iter,
+        [=](scalar_t self_val) {
+          return (self_val >= -lambd_val && self_val <= lambd_val) ? scalar_t(0)
+                                                                   : self_val;
+        },
+        [=](Vec256<scalar_t> self_val) {
+          return ((self_val < -lambd_val) | (self_val > lambd_val)) & self_val;
+        });
   });
 }

 void hardshrink_backward_cpu_kernel(TensorIterator& iter, Scalar lambd) {
   AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "hardshrink_backward_cpu", [&] {
     auto lambd_val = lambd.to<scalar_t>();
-    cpu_kernel_vec(iter,
-      [=](scalar_t grad_val, scalar_t self_val) {
-        return (self_val >= -lambd_val && self_val <= lambd_val) ? scalar_t(0) : grad_val;
-      },
-      [=](Vec256<scalar_t> grad_val, Vec256<scalar_t> self_val) {
-        return ((self_val < -lambd_val) | (self_val > lambd_val)) & grad_val;
-      }
-    );
+    cpu_kernel_vec(
+        iter,
+        [=](scalar_t grad_val, scalar_t self_val) {
+          return (self_val >= -lambd_val && self_val <= lambd_val) ? scalar_t(0)
+                                                                   : grad_val;
+        },
+        [=](Vec256<scalar_t> grad_val, Vec256<scalar_t> self_val) {
+          return ((self_val < -lambd_val) | (self_val > lambd_val)) & grad_val;
+        });
   });
 }

@@ -211,7 +233,9 @@ REGISTER_DISPATCH(threshold_stub, &threshold_kernel);
 REGISTER_DISPATCH(GeluKernel, &GeluKernelImpl);
 REGISTER_DISPATCH(GeluBackwardKernel, &GeluBackwardKernelImpl);
 REGISTER_DISPATCH(hardshrink_cpu_stub, &hardshrink_cpu_kernel);
-REGISTER_DISPATCH(hardshrink_backward_cpu_stub, &hardshrink_backward_cpu_kernel);
+REGISTER_DISPATCH(
+    hardshrink_backward_cpu_stub,
+    &hardshrink_backward_cpu_kernel);

 } // namespace native
 } // namespace at
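
For readers checking the math, the scalar lambdas above reduce to the following plain-Python sketch (illustrative only; constant names such as BETA are local to this snippet, not ATen identifiers):

```python
import math

SQRT1_2 = math.sqrt(0.5)                            # M_SQRT1_2
BETA = 2.0 / math.sqrt(math.pi) * SQRT1_2 * 0.5     # M_2_SQRTPI * M_SQRT1_2 * 0.5 = 1/sqrt(2*pi)

def gelu(x):
    # Forward kernel: x * 0.5 * (1 + erf(x / sqrt(2))) = x * Phi(x)
    return x * 0.5 * (1.0 + math.erf(x * SQRT1_2))

def gelu_backward(dy, x):
    # Backward kernel: dy * (cdf + x * pdf), with pdf = exp(-x^2 / 2) / sqrt(2*pi)
    cdf = 0.5 * (1.0 + math.erf(x * SQRT1_2))
    pdf = BETA * math.exp(-0.5 * x * x)
    return dy * (cdf + x * pdf)

print(gelu(1.0), gelu_backward(1.0, 1.0))
```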
