@@ -42,166 +42,188 @@ static void threshold_kernel(
 
 #if AT_MKL_ENABLED()
 
-// TODO(yangxm): Consider to use TensorIterator here.
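+// Thin wrappers that dispatch to MKL's single/double precision vector math
+// routines (vsCdfNorm/vdCdfNorm, vsMul/vdMul, vsExp/vdExp).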
 template <typename T>
-void GeluKernelMKLImpl(const Tensor& X, Tensor* Y);
-
-#define DELEGATE_GELU_KERNEL_MKL_IMPL(T, CdfNormFunc, MulFunc) \
-  template <>                                                  \
-  void GeluKernelMKLImpl<T>(const Tensor& X, Tensor* Y) {      \
-    const int64_t N = X.numel();                               \
-    const T* X_data = X.data_ptr<T>();                         \
-    T* Y_data = Y->data_ptr<T>();                              \
-    CdfNormFunc(N, X_data, Y_data);                            \
-    MulFunc(N, X_data, Y_data, Y_data);                        \
-  }
-DELEGATE_GELU_KERNEL_MKL_IMPL(float, vsCdfNorm, vsMul)
-DELEGATE_GELU_KERNEL_MKL_IMPL(double, vdCdfNorm, vdMul)
-#undef DELEGATE_GELU_KERNEL_MKL_IMPL
+void MKLCdfNorm(int64_t N, const T* X, T* Y);
 
-#else // AT_MKL_ENABLED()
+template <>
+void MKLCdfNorm<float>(int64_t N, const float* X, float* Y) {
+  vsCdfNorm(N, X, Y);
+}
+
+template <>
+void MKLCdfNorm<double>(int64_t N, const double* X, double* Y) {
+  vdCdfNorm(N, X, Y);
+}
 
 template <typename T>
-void GeluKernelMKLImpl(const Tensor& X, Tensor* Y) {
-  AT_ASSERTM(false, "ATen not compiled with MKL");
+void MKLMul(int64_t N, const T* A, const T* B, T* Y);
+
+template <>
+void MKLMul<float>(int64_t N, const float* A, const float* B, float* Y) {
+  vsMul(N, A, B, Y);
 }
 
-#endif // AT_MKL_ENABLED()
+template <>
+void MKLMul<double>(int64_t N, const double* A, const double* B, double* Y) {
+  vdMul(N, A, B, Y);
+}
 
 template <typename T>
-void GeluKernelImplInternal(const Tensor& X, Tensor* Y) {
-  const int64_t N = X.numel();
-  const T* X_data = X.data_ptr<T>();
-  T* Y_data = Y->data_ptr<T>();
-  for (int64_t i = 0; i < N; ++i) {
-    Y_data[i] = X_data[i] * M_SQRT1_2;
-  }
-  Y->erf_();
-  for (int64_t i = 0; i < N; ++i) {
-    Y_data[i] = (Y_data[i] + T(1)) * X_data[i] * T(0.5);
-  }
+void MKLExp(int64_t N, const T* X, T* Y);
+
+template <>
+void MKLExp<float>(int64_t N, const float* X, float* Y) {
+  vsExp(N, X, Y);
 }
 
-// TODO(yangxm): Add another fast kernel using formula
-// y = 0.5x * (1 + tanh(sqrt(2/Pi) * (x + 0.044715x^3)))
-// and the fast tanh impl from Eigen.
-void GeluKernelImpl(const Tensor& X, Tensor* Y) {
-  if (at::hasMKL()) {
-    AT_DISPATCH_FLOATING_TYPES(X.scalar_type(), "GeluKernelImpl", [&]() {
-      GeluKernelMKLImpl<scalar_t>(X, Y);
-    });
-  } else {
-    AT_DISPATCH_FLOATING_TYPES(X.scalar_type(), "GeluKernelImpl", [&]() {
-      GeluKernelImplInternal<scalar_t>(X, Y);
-    });
-  }
+template <>
+void MKLExp<double>(int64_t N, const double* X, double* Y) {
+  vdExp(N, X, Y);
 }
 
-#if AT_MKL_ENABLED()
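+// Forward GELU via MKL: GELU(x) = x * Phi(x), where Phi is the standard
+// normal CDF. Operand 0 of the TensorIterator is the output and operand 1 the
+// input; Phi(X) is computed with CdfNorm, then multiplied elementwise by X.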
+template <typename T>
+void GeluMKLKernelImpl(TensorIterator* it) {
+  if (!it->can_use_32bit_indexing()) {
+    for (auto& sub_it : it->with_32bit_indexing()) {
+      GeluMKLKernelImpl<T>(&sub_it);
+    }
+    return;
+  }
+  const int64_t N = it->numel();
+  const T* X_data = static_cast<T*>(it->data_ptr(1));
+  T* Y_data = static_cast<T*>(it->data_ptr(0));
+  MKLCdfNorm<T>(N, X_data, Y_data);
+  MKLMul<T>(N, X_data, Y_data, Y_data);
+}
 
 template <typename T>
-void GeluBackwardKernelMKLImpl(const Tensor& dY, const Tensor& X, Tensor* dX);
-
-// TODO(yangxm): Implement this by using template functions.
-#define DELEGATE_GELU_BACKWARD_KERNEL_MKL_IMPL(T, CdfNormFunc, ExpFunc)      \
-  template <>                                                                \
-  void GeluBackwardKernelMKLImpl<T>(                                         \
-      const Tensor& dY, const Tensor& X, Tensor* dX) {                       \
-    constexpr T kAlpha = M_2_SQRTPI * M_SQRT1_2 * T(0.5);                    \
-    Tensor scratch = at::native::empty_like(X);                              \
-    const int64_t N = X.numel();                                             \
-    const T* dY_data = dY.data_ptr<T>();                                     \
-    const T* X_data = X.data_ptr<T>();                                       \
-    T* dX_data = dX->data_ptr<T>();                                          \
-    T* scratch_data = scratch.data_ptr<T>();                                 \
-    CdfNormFunc(N, X_data, scratch_data);                                    \
-    for (int64_t i = 0; i < N; ++i) {                                        \
-      dX_data[i] = -T(0.5) * X_data[i] * X_data[i];                          \
-    }                                                                        \
-    ExpFunc(N, dX_data, dX_data);                                            \
-    for (int64_t i = 0; i < N; ++i) {                                        \
-      dX_data[i] =                                                           \
-          dY_data[i] * (scratch_data[i] + X_data[i] * dX_data[i] * kAlpha);  \
-    }                                                                        \
+void GeluBackwardMKLKernelImpl(TensorIterator* it) {
+  if (!it->can_use_32bit_indexing()) {
+    for (auto& sub_it : it->with_32bit_indexing()) {
+      GeluBackwardMKLKernelImpl<T>(&sub_it);
+    }
+    return;
+  }
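+  // dGELU/dx = Phi(x) + x * phi(x), with Phi/phi the standard normal CDF/PDF.
+  // kBeta = (2/sqrt(pi)) * (1/sqrt(2)) * 0.5 = 1/sqrt(2*pi) is the PDF's
+  // normalization constant; exp(-x^2 / 2) is accumulated in dX below.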
+  constexpr T kBeta = M_2_SQRTPI * M_SQRT1_2 * T(0.5);
+  const int64_t N = it->numel();
+  const T* dY_data = static_cast<T*>(it->data_ptr(1));
+  const T* X_data = static_cast<T*>(it->data_ptr(2));
+  T* dX_data = static_cast<T*>(it->data_ptr(0));
+  Tensor cdf = at::empty({N}, it->input(1).options());
+  T* cdf_data = cdf.template data_ptr<T>();
+  MKLCdfNorm<T>(N, X_data, cdf_data);
+  for (int64_t i = 0; i < N; ++i) {
+    dX_data[i] = T(-0.5) * X_data[i] * X_data[i];
+  }
+  MKLExp(N, dX_data, dX_data);
+  for (int64_t i = 0; i < N; ++i) {
+    dX_data[i] = dY_data[i] * (cdf_data[i] + kBeta * X_data[i] * dX_data[i]);
   }
-DELEGATE_GELU_BACKWARD_KERNEL_MKL_IMPL(float, vsCdfNorm, vsExp)
-DELEGATE_GELU_BACKWARD_KERNEL_MKL_IMPL(double, vdCdfNorm, vdExp)
-#undef DELEGATE_GELU_BACKWARD_KERNEL_MKL_IMPL
+}
 
 #else // AT_MKL_ENABLED()
 
 template <typename T>
-void GeluBackwardKernelMKLImpl(const Tensor& dY, const Tensor& X, Tensor* dX) {
+void GeluMKLKernelImpl(TensorIterator* /* it */) {
+  AT_ASSERTM(false, "ATen not compiled with MKL");
+}
+
+template <typename T>
+void GeluBackwardMKLKernelImpl(TensorIterator* /* it */) {
   AT_ASSERTM(false, "ATen not compiled with MKL");
 }
 
 #endif // AT_MKL_ENABLED()
 
-template <typename T>
-void GeluBackwardKernelImplInternal(
-    const Tensor& dY,
-    const Tensor& X,
-    Tensor* dX) {
-  constexpr T kAlpha = M_2_SQRTPI * M_SQRT1_2 * T(0.5);
-  Tensor scratch = at::native::empty_like(X);
-  const int64_t N = X.numel();
-  const T* dY_data = dY.data_ptr<T>();
-  const T* X_data = X.data_ptr<T>();
-  T* dX_data = dX->data_ptr<T>();
-  T* scratch_data = scratch.data_ptr<T>();
-  for (int64_t i = 0; i < N; ++i) {
-    scratch_data[i] = X_data[i] * M_SQRT1_2;
-    dX_data[i] = -T(0.5) * X_data[i] * X_data[i];
-  }
-  // TODO(yangxm): Consider let forward pass preserve CdfNorm(X) in training
-  // pass to reduce this extra tensor.
-  scratch.erf_();
-  dX->exp_();
-  for (int64_t i = 0; i < N; ++i) {
-    dX_data[i] = dY_data[i] *
-        (T(0.5) * (T(1) + scratch_data[i]) + X_data[i] * dX_data[i] * kAlpha);
+// TODO(yangxm): Add another fast kernel using formula
+// y = 0.5x * (1 + tanh(sqrt(2/Pi) * (x + 0.044715x^3)))
+// and the fast tanh impl from Eigen.
+void GeluKernelImpl(TensorIterator& it) {
+  if (at::hasMKL() && it.is_contiguous()) {
+    AT_DISPATCH_FLOATING_TYPES(it.dtype(), "GeluKernelImpl", [&]() {
+      GeluMKLKernelImpl<scalar_t>(&it);
+    });
+  } else {
+    AT_DISPATCH_FLOATING_TYPES(it.dtype(), "GeluKernelImpl", [&]() {
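+      // Non-MKL path: GELU(x) = 0.5 * x * (1 + erf(x / sqrt(2))), written as
+      // a scalar kernel plus a Vec256 kernel for cpu_kernel_vec.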
+      using Vec = vec256::Vec256<scalar_t>;
+      const Vec kAlphaVec(M_SQRT1_2);
+      const Vec kOneVec(1);
+      const Vec kPointFiveVec(0.5);
+      cpu_kernel_vec(
+          it,
+          [](scalar_t x) {
+            constexpr scalar_t kAlpha = M_SQRT1_2;
+            return x * scalar_t(0.5) * (scalar_t(1) + std::erf(x * kAlpha));
+          },
+          [&](Vec x_vec) {
+            return x_vec * kPointFiveVec *
+                (kOneVec + (x_vec * kAlphaVec).erf());
+          });
+    });
   }
 }
 
-void GeluBackwardKernelImpl(const Tensor& dY, const Tensor& X, Tensor* dX) {
-  if (hasMKL()) {
-    AT_DISPATCH_FLOATING_TYPES(
-        X.scalar_type(), "GeluBackwardKernelImpl", [&]() {
-          GeluBackwardKernelMKLImpl<scalar_t>(dY, X, dX);
-        });
+void GeluBackwardKernelImpl(TensorIterator& it) {
+  if (hasMKL() && it.is_contiguous()) {
+    AT_DISPATCH_FLOATING_TYPES(it.dtype(), "GeluBackwardKernelImpl", [&]() {
+      GeluBackwardMKLKernelImpl<scalar_t>(&it);
+    });
   } else {
-    AT_DISPATCH_FLOATING_TYPES(
-        X.scalar_type(), "GeluBackwardKernelImpl", [&]() {
-          GeluBackwardKernelImplInternal<scalar_t>(dY, X, dX);
-        });
+    AT_DISPATCH_FLOATING_TYPES(it.dtype(), "GeluBackwardKernelImpl", [&]() {
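+      // Non-MKL path: dGELU/dx = Phi(x) + x * phi(x), with the CDF built from
+      // erf and the PDF from exp, in scalar and Vec256 variants.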
+      using Vec = vec256::Vec256<scalar_t>;
+      const Vec kAlphaVec(M_SQRT1_2);
+      const Vec kBetaVec(M_2_SQRTPI * M_SQRT1_2 * 0.5);
+      const Vec kOneVec(1);
+      const Vec kPointFiveVec(0.5);
+      const Vec kMinusPointFiveVec(-0.5);
+      cpu_kernel_vec(
+          it,
+          [](scalar_t dy, scalar_t x) {
+            constexpr scalar_t kAlpha = M_SQRT1_2;
+            constexpr scalar_t kBeta = M_2_SQRTPI * M_SQRT1_2 * 0.5;
+            const scalar_t cdf =
+                scalar_t(0.5) * (scalar_t(1) + std::erf(x * kAlpha));
+            const scalar_t pdf = kBeta * std::exp(x * x * scalar_t(-0.5));
+            return dy * (cdf + x * pdf);
+          },
+          [&](Vec dy_vec, Vec x_vec) {
+            const Vec cdf_vec =
+                kPointFiveVec * (kOneVec + (x_vec * kAlphaVec).erf());
+            const Vec pdf_vec =
+                kBetaVec * (x_vec * x_vec * kMinusPointFiveVec).exp();
+            return dy_vec * (cdf_vec + x_vec * pdf_vec);
+          });
+    });
   }
 }
 
 void hardshrink_cpu_kernel(TensorIterator& iter, Scalar lambd) {
   AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "hardshrink_cpu", [&] {
     auto lambd_val = lambd.to<scalar_t>();
-    cpu_kernel_vec(iter,
-      [=](scalar_t self_val) {
-        return (self_val >= -lambd_val && self_val <= lambd_val) ? scalar_t(0) : self_val;
-      },
-      [=](Vec256<scalar_t> self_val) {
-        return ((self_val < -lambd_val) | (self_val > lambd_val)) & self_val;
-      }
-    );
+    cpu_kernel_vec(
+        iter,
+        [=](scalar_t self_val) {
+          return (self_val >= -lambd_val && self_val <= lambd_val) ? scalar_t(0)
+                                                                   : self_val;
+        },
+        [=](Vec256<scalar_t> self_val) {
+          return ((self_val < -lambd_val) | (self_val > lambd_val)) & self_val;
+        });
   });
 }
 
 void hardshrink_backward_cpu_kernel(TensorIterator& iter, Scalar lambd) {
   AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "hardshrink_backward_cpu", [&] {
     auto lambd_val = lambd.to<scalar_t>();
-    cpu_kernel_vec(iter,
-      [=](scalar_t grad_val, scalar_t self_val) {
-        return (self_val >= -lambd_val && self_val <= lambd_val) ? scalar_t(0) : grad_val;
-      },
-      [=](Vec256<scalar_t> grad_val, Vec256<scalar_t> self_val) {
-        return ((self_val < -lambd_val) | (self_val > lambd_val)) & grad_val;
-      }
-    );
+    cpu_kernel_vec(
+        iter,
+        [=](scalar_t grad_val, scalar_t self_val) {
+          return (self_val >= -lambd_val && self_val <= lambd_val) ? scalar_t(0)
+                                                                   : grad_val;
+        },
+        [=](Vec256<scalar_t> grad_val, Vec256<scalar_t> self_val) {
+          return ((self_val < -lambd_val) | (self_val > lambd_val)) & grad_val;
+        });
   });
 }
 
@@ -211,7 +233,9 @@ REGISTER_DISPATCH(threshold_stub, &threshold_kernel);
 REGISTER_DISPATCH(GeluKernel, &GeluKernelImpl);
 REGISTER_DISPATCH(GeluBackwardKernel, &GeluBackwardKernelImpl);
 REGISTER_DISPATCH(hardshrink_cpu_stub, &hardshrink_cpu_kernel);
-REGISTER_DISPATCH(hardshrink_backward_cpu_stub, &hardshrink_backward_cpu_kernel);
+REGISTER_DISPATCH(
+    hardshrink_backward_cpu_stub,
+    &hardshrink_backward_cpu_kernel);
 
 } // namespace native
 } // namespace at