1 change: 1 addition & 0 deletions aten/src/ATen/core/aten_interned_strings.h
@@ -337,6 +337,7 @@ _(aten, full) \
_(aten, full_like) \
_(aten, gather) \
_(aten, ge) \
_(aten, gelu) \
_(aten, geometric) \
_(aten, geqrf) \
_(aten, ger) \
12 changes: 6 additions & 6 deletions aten/src/ATen/native/Activation.cpp
@@ -353,16 +353,16 @@ Tensor hardshrink_backward_cpu(const Tensor & grad, const Tensor & self, Scalar


Tensor gelu_cpu(const Tensor& self) {
const auto X = self.contiguous();
Tensor Y = at::native::empty_like(X);
GeluKernel(kCPU, X, &Y);
Tensor Y = at::native::empty_like(self);
auto it = TensorIterator::unary_op(Y, self);
GeluKernel(kCPU, it);
return Y;
}

Tensor gelu_backward_cpu(const Tensor& grad, const Tensor& self) {
const auto X = self.contiguous();
Tensor dX = at::native::empty_like(X);
GeluBackwardKernel(kCPU, grad.contiguous(), X, &dX);
Tensor dX = at::native::empty_like(self);
auto it = TensorIterator::binary_op(dX, grad, self);
GeluBackwardKernel(kCPU, it);
return dX;
}
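For context, the refactored CPU path is reached through the usual ATen entry points. A minimal usage sketch (assuming a build where `at::gelu` and `at::gelu_backward` are exposed for these natives; not part of the patch itself):

```cpp
#include <ATen/ATen.h>

int main() {
  // Forward: y = x * Phi(x), with Phi the standard normal CDF.
  at::Tensor x = at::randn({4, 8});
  at::Tensor y = at::gelu(x);

  // Backward: dX = dY * (Phi(x) + x * phi(x)); use an all-ones upstream grad.
  at::Tensor dy = at::ones_like(y);
  at::Tensor dx = at::gelu_backward(dy, x);
  return 0;
}
```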

5 changes: 2 additions & 3 deletions aten/src/ATen/native/Activation.h
@@ -10,10 +10,9 @@ struct TensorIterator;

namespace native {

using activation_fn = void (*)(TensorIterator&);
using activation_backward_fn = void (*)(TensorIterator&);
using threshold_fn = void (*)(TensorIterator&, Scalar, Scalar);
using activation_fn = void (*)(const Tensor& /* X */, Tensor* /* Y */);
using activation_backward_fn =
void (*)(const Tensor& /* dY */, const Tensor& /* X */, Tensor* /* dX */);
using hardshrink_cpu_fn = void (*)(TensorIterator&, Scalar);
using hardshrink_backward_cpu_fn = void (*)(TensorIterator&, Scalar);
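With these TensorIterator-based signatures, a CPU kernel no longer unpacks tensors itself; it only consumes an iterator. A schematic sketch of how the pieces fit together under the new types (the DECLARE_DISPATCH line is an assumption about the surrounding header; the registration and call-site lines mirror the diffs in Activation.cpp and cpu/Activation.cpp):

```cpp
// Header side: a dispatch stub whose kernels take only a TensorIterator.
using activation_fn = void (*)(TensorIterator&);
DECLARE_DISPATCH(activation_fn, GeluKernel);

// CPU side (cpu/Activation.cpp): register the implementation for the stub.
void GeluKernelImpl(TensorIterator& it);  // elementwise GELU via cpu_kernel_vec
REGISTER_DISPATCH(GeluKernel, &GeluKernelImpl);

// Call site (Activation.cpp): build the iterator, then dispatch by device.
Tensor Y = at::native::empty_like(self);
auto it = TensorIterator::unary_op(Y, self);
GeluKernel(kCPU, it);
```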

266 changes: 145 additions & 121 deletions aten/src/ATen/native/cpu/Activation.cpp
@@ -42,166 +42,188 @@ static void threshold_kernel(

#if AT_MKL_ENABLED()

// TODO(yangxm): Consider using TensorIterator here.
template <typename T>
void GeluKernelMKLImpl(const Tensor& X, Tensor* Y);

#define DELEGATE_GELU_KERNEL_MKL_IMPL(T, CdfNormFunc, MulFunc) \
template <> \
void GeluKernelMKLImpl<T>(const Tensor& X, Tensor* Y) { \
const int64_t N = X.numel(); \
const T* X_data = X.data_ptr<T>(); \
T* Y_data = Y->data_ptr<T>(); \
CdfNormFunc(N, X_data, Y_data); \
MulFunc(N, X_data, Y_data, Y_data); \
}
DELEGATE_GELU_KERNEL_MKL_IMPL(float, vsCdfNorm, vsMul)
DELEGATE_GELU_KERNEL_MKL_IMPL(double, vdCdfNorm, vdMul)
#undef DELEGATE_GELU_KERNEL_MKL_IMPL
void MKLCdfNorm(int64_t N, const T* X, T* Y);

#else // AT_MKL_ENABLED()
template <>
void MKLCdfNorm<float>(int64_t N, const float* X, float* Y) {
vsCdfNorm(N, X, Y);
}

template <>
void MKLCdfNorm<double>(int64_t N, const double* X, double* Y) {
vdCdfNorm(N, X, Y);
}

template <typename T>
void GeluKernelMKLImpl(const Tensor& X, Tensor* Y) {
AT_ASSERTM(false, "ATen not compiled with MKL");
void MKLMul(int64_t N, const T* A, const T* B, T* Y);

template <>
void MKLMul<float>(int64_t N, const float* A, const float* B, float* Y) {
vsMul(N, A, B, Y);
}

#endif // AT_MKL_ENABLED()
template <>
void MKLMul<double>(int64_t N, const double* A, const double* B, double* Y) {
vdMul(N, A, B, Y);
}

template <typename T>
void GeluKernelImplInternal(const Tensor& X, Tensor* Y) {
const int64_t N = X.numel();
const T* X_data = X.data_ptr<T>();
T* Y_data = Y->data_ptr<T>();
for (int64_t i = 0; i < N; ++i) {
Y_data[i] = X_data[i] * M_SQRT1_2;
}
Y->erf_();
for (int64_t i = 0; i < N; ++i) {
Y_data[i] = (Y_data[i] + T(1)) * X_data[i] * T(0.5);
}
void MKLExp(int64_t N, const T* X, T* Y);

template <>
void MKLExp<float>(int64_t N, const float* X, float* Y) {
vsExp(N, X, Y);
}

// TODO(yangxm): Add another fast kernel using formula
// y = 0.5x * (1 + tanh(sqrt(2/Pi) * (x + 0.044715x^3)))
// and the fast tanh impl from Eigen.
void GeluKernelImpl(const Tensor& X, Tensor* Y) {
if (at::hasMKL()) {
AT_DISPATCH_FLOATING_TYPES(X.scalar_type(), "GeluKernelImpl", [&]() {
GeluKernelMKLImpl<scalar_t>(X, Y);
});
} else {
AT_DISPATCH_FLOATING_TYPES(X.scalar_type(), "GeluKernelImpl", [&]() {
GeluKernelImplInternal<scalar_t>(X, Y);
});
}
template <>
void MKLExp<double>(int64_t N, const double* X, double* Y) {
vdExp(N, X, Y);
}

#if AT_MKL_ENABLED()
template <typename T>
void GeluMKLKernelImpl(TensorIterator* it) {
if (!it->can_use_32bit_indexing()) {
for (auto& sub_it : it->with_32bit_indexing()) {
GeluMKLKernelImpl<T>(&sub_it);
}
return;
}
const int64_t N = it->numel();
const T* X_data = static_cast<T*>(it->data_ptr(1));
T* Y_data = static_cast<T*>(it->data_ptr(0));
MKLCdfNorm<T>(N, X_data, Y_data);
MKLMul<T>(N, X_data, Y_data, Y_data);
}

template <typename T>
void GeluBackwardKernelMKLImpl(const Tensor& dY, const Tensor& X, Tensor* dX);

// TODO(yangxm): Implement this by using template functions.
#define DELEGATE_GELU_BACKWARD_KERNEL_MKL_IMPL(T, CdfNormFunc, ExpFunc) \
template <> \
void GeluBackwardKernelMKLImpl<T>( \
const Tensor& dY, const Tensor& X, Tensor* dX) { \
constexpr T kAlpha = M_2_SQRTPI * M_SQRT1_2 * T(0.5); \
Tensor scratch = at::native::empty_like(X); \
const int64_t N = X.numel(); \
const T* dY_data = dY.data_ptr<T>(); \
const T* X_data = X.data_ptr<T>(); \
T* dX_data = dX->data_ptr<T>(); \
T* scratch_data = scratch.data_ptr<T>(); \
CdfNormFunc(N, X_data, scratch_data); \
for (int64_t i = 0; i < N; ++i) { \
dX_data[i] = -T(0.5) * X_data[i] * X_data[i]; \
} \
ExpFunc(N, dX_data, dX_data); \
for (int64_t i = 0; i < N; ++i) { \
dX_data[i] = \
dY_data[i] * (scratch_data[i] + X_data[i] * dX_data[i] * kAlpha); \
} \
void GeluBackwardMKLKernelImpl(TensorIterator* it) {
if (!it->can_use_32bit_indexing()) {
for (auto& sub_it : it->with_32bit_indexing()) {
GeluBackwardMKLKernelImpl<T>(&sub_it);
}
return;
}
constexpr T kBeta = M_2_SQRTPI * M_SQRT1_2 * T(0.5);
const int64_t N = it->numel();
const T* dY_data = static_cast<T*>(it->data_ptr(1));
const T* X_data = static_cast<T*>(it->data_ptr(2));
T* dX_data = static_cast<T*>(it->data_ptr(0));
Tensor cdf = at::empty({N}, it->input(1).options());
T* cdf_data = cdf.template data_ptr<T>();
MKLCdfNorm<T>(N, X_data, cdf_data);
for (int64_t i = 0; i < N; ++i) {
dX_data[i] = T(-0.5) * X_data[i] * X_data[i];
}
MKLExp(N, dX_data, dX_data);
for (int64_t i = 0; i < N; ++i) {
dX_data[i] = dY_data[i] * (cdf_data[i] + kBeta * X_data[i] * dX_data[i]);
}
DELEGATE_GELU_BACKWARD_KERNEL_MKL_IMPL(float, vsCdfNorm, vsExp)
DELEGATE_GELU_BACKWARD_KERNEL_MKL_IMPL(double, vdCdfNorm, vdExp)
#undef DELEGATE_GELU_BACKWARD_KERNEL_MKL_IMPL
}

#else // AT_MKL_ENABLED()

template <typename T>
void GeluBackwardKernelMKLImpl(const Tensor& dY, const Tensor& X, Tensor* dX) {
void GeluMKLKernelImpl(TensorIterator* /* it */) {
AT_ASSERTM(false, "ATen not compiled with MKL");
}

template <typename T>
void GeluBackwardMKLKernelImpl(TensorIterator* /* it */) {
AT_ASSERTM(false, "ATen not compiled with MKL");
}

#endif // AT_MKL_ENABLED()

template <typename T>
void GeluBackwardKernelImplInternal(
const Tensor& dY,
const Tensor& X,
Tensor* dX) {
constexpr T kAlpha = M_2_SQRTPI * M_SQRT1_2 * T(0.5);
Tensor scratch = at::native::empty_like(X);
const int64_t N = X.numel();
const T* dY_data = dY.data_ptr<T>();
const T* X_data = X.data_ptr<T>();
T* dX_data = dX->data_ptr<T>();
T* scratch_data = scratch.data_ptr<T>();
for (int64_t i = 0; i < N; ++i) {
scratch_data[i] = X_data[i] * M_SQRT1_2;
dX_data[i] = -T(0.5) * X_data[i] * X_data[i];
}
// TODO(yangxm): Consider letting the forward pass preserve CdfNorm(X) during
// training to avoid this extra tensor.
scratch.erf_();
dX->exp_();
for (int64_t i = 0; i < N; ++i) {
dX_data[i] = dY_data[i] *
(T(0.5) * (T(1) + scratch_data[i]) + X_data[i] * dX_data[i] * kAlpha);
// TODO(yangxm): Add another fast kernel using formula
// y = 0.5x * (1 + tanh(sqrt(2/Pi) * (x + 0.044715x^3)))
// and the fast tanh impl from Eigen.
void GeluKernelImpl(TensorIterator& it) {
if (at::hasMKL() && it.is_contiguous()) {
AT_DISPATCH_FLOATING_TYPES(it.dtype(), "GeluKernelImpl", [&]() {
GeluMKLKernelImpl<scalar_t>(&it);
});
} else {
AT_DISPATCH_FLOATING_TYPES(it.dtype(), "GeluKernelImpl", [&]() {
using Vec = vec256::Vec256<scalar_t>;
const Vec kAlphaVec(M_SQRT1_2);
const Vec kOneVec(1);
const Vec kPointFiveVec(0.5);
cpu_kernel_vec(
it,
[](scalar_t x) {
constexpr scalar_t kAlpha = M_SQRT1_2;
return x * scalar_t(0.5) * (scalar_t(1) + std::erf(x * kAlpha));
},
[&](Vec x_vec) {
return x_vec * kPointFiveVec *
(kOneVec + (x_vec * kAlphaVec).erf());
});
});
}
}
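The TODO above mentions the tanh approximation of GELU. A hedged scalar sketch of what such a kernel body would compute, written out from the formula in the comment (not part of this patch; `ApproxGelu` is a hypothetical helper name):

```cpp
#include <cmath>

// y = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
template <typename T>
T ApproxGelu(T x) {
  const T kBeta = static_cast<T>(M_SQRT2 * M_2_SQRTPI * 0.5);  // sqrt(2/pi)
  const T inner = kBeta * (x + static_cast<T>(0.044715) * x * x * x);
  return T(0.5) * x * (T(1) + std::tanh(inner));
}
```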

void GeluBackwardKernelImpl(const Tensor& dY, const Tensor& X, Tensor* dX) {
if (hasMKL()) {
AT_DISPATCH_FLOATING_TYPES(
X.scalar_type(), "GeluBackwardKernelImpl", [&]() {
GeluBackwardKernelMKLImpl<scalar_t>(dY, X, dX);
});
void GeluBackwardKernelImpl(TensorIterator& it) {
if (hasMKL() && it.is_contiguous()) {
AT_DISPATCH_FLOATING_TYPES(it.dtype(), "GeluBackwardKernelImpl", [&]() {
GeluBackwardMKLKernelImpl<scalar_t>(&it);
});
} else {
AT_DISPATCH_FLOATING_TYPES(
X.scalar_type(), "GeluBackwardKernelImpl", [&]() {
GeluBackwardKernelImplInternal<scalar_t>(dY, X, dX);
});
AT_DISPATCH_FLOATING_TYPES(it.dtype(), "GeluBackwardKernelImpl", [&]() {
using Vec = vec256::Vec256<scalar_t>;
const Vec kAlphaVec(M_SQRT1_2);
const Vec kBetaVec(M_2_SQRTPI * M_SQRT1_2 * 0.5);
const Vec kOneVec(1);
const Vec kPointFiveVec(0.5);
const Vec kMinusPointFiveVec(-0.5);
cpu_kernel_vec(
it,
[](scalar_t dy, scalar_t x) {
constexpr scalar_t kAlpha = M_SQRT1_2;
constexpr scalar_t kBeta = M_2_SQRTPI * M_SQRT1_2 * 0.5;
const scalar_t cdf =
scalar_t(0.5) * (scalar_t(1) + std::erf(x * kAlpha));
const scalar_t pdf = kBeta * std::exp(x * x * scalar_t(-0.5));
return dy * (cdf + x * pdf);
},
[&](Vec dy_vec, Vec x_vec) {
const Vec cdf_vec =
kPointFiveVec * (kOneVec + (x_vec * kAlphaVec).erf());
const Vec pdf_vec =
kBetaVec * (x_vec * x_vec * kMinusPointFiveVec).exp();
return dy_vec * (cdf_vec + x_vec * pdf_vec);
});
});
}
}
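For reference, the identity this backward kernel implements, which also explains the constant `kBeta = M_2_SQRTPI * M_SQRT1_2 * 0.5 = 1/sqrt(2*pi)` (a worked restatement of the code above, not new behavior):

```latex
\frac{d}{dx}\bigl(x\,\Phi(x)\bigr) = \Phi(x) + x\,\varphi(x),
\quad
\Phi(x) = \tfrac{1}{2}\Bigl(1 + \operatorname{erf}\bigl(\tfrac{x}{\sqrt{2}}\bigr)\Bigr),
\quad
\varphi(x) = \tfrac{1}{\sqrt{2\pi}}\, e^{-x^{2}/2},
\qquad
dX = dY \cdot \bigl(\Phi(x) + x\,\varphi(x)\bigr).
```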

void hardshrink_cpu_kernel(TensorIterator& iter, Scalar lambd) {
AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "hardshrink_cpu", [&] {
auto lambd_val = lambd.to<scalar_t>();
cpu_kernel_vec(iter,
[=](scalar_t self_val) {
return (self_val >= -lambd_val && self_val <= lambd_val) ? scalar_t(0) : self_val;
},
[=](Vec256<scalar_t> self_val) {
return ((self_val < -lambd_val) | (self_val > lambd_val)) & self_val;
}
);
cpu_kernel_vec(
iter,
[=](scalar_t self_val) {
return (self_val >= -lambd_val && self_val <= lambd_val) ? scalar_t(0)
: self_val;
},
[=](Vec256<scalar_t> self_val) {
return ((self_val < -lambd_val) | (self_val > lambd_val)) & self_val;
});
});
}

void hardshrink_backward_cpu_kernel(TensorIterator& iter, Scalar lambd) {
AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "hardshrink_backward_cpu", [&] {
auto lambd_val = lambd.to<scalar_t>();
cpu_kernel_vec(iter,
[=](scalar_t grad_val, scalar_t self_val) {
return (self_val >= -lambd_val && self_val <= lambd_val) ? scalar_t(0) : grad_val;
},
[=](Vec256<scalar_t> grad_val, Vec256<scalar_t> self_val) {
return ((self_val < -lambd_val) | (self_val > lambd_val)) & grad_val;
}
);
cpu_kernel_vec(
iter,
[=](scalar_t grad_val, scalar_t self_val) {
return (self_val >= -lambd_val && self_val <= lambd_val) ? scalar_t(0)
: grad_val;
},
[=](Vec256<scalar_t> grad_val, Vec256<scalar_t> self_val) {
return ((self_val < -lambd_val) | (self_val > lambd_val)) & grad_val;
});
});
}

@@ -211,7 +233,9 @@ REGISTER_DISPATCH(threshold_stub, &threshold_kernel);
REGISTER_DISPATCH(GeluKernel, &GeluKernelImpl);
REGISTER_DISPATCH(GeluBackwardKernel, &GeluBackwardKernelImpl);
REGISTER_DISPATCH(hardshrink_cpu_stub, &hardshrink_cpu_kernel);
REGISTER_DISPATCH(hardshrink_backward_cpu_stub, &hardshrink_backward_cpu_kernel);
REGISTER_DISPATCH(
hardshrink_backward_cpu_stub,
&hardshrink_backward_cpu_kernel);

} // namespace native
} // namespace at