30 changes: 30 additions & 0 deletions aten/src/ATen/native/Activation.cpp
@@ -371,4 +371,34 @@ Tensor hardshrink_backward_cpu(const Tensor & grad, const Tensor & self, Scalar
return out_tensor;
}


Tensor gelu_cpu(const Tensor& self) {
const auto X = self.contiguous();
Tensor Y = at::native::empty_like(X);
GeluKernel(kCPU, X, &Y);
return Y;
}

Tensor gelu_cuda(const Tensor& self) {
Tensor Y = at::native::empty_like(self);
GeluKernel(kCUDA, self, &Y);
return Y;
}

Tensor gelu_backward_cpu(const Tensor& grad, const Tensor& self) {
const auto X = self.contiguous();
Tensor dX = at::native::empty_like(X);
GeluBackwardKernel(kCPU, grad.contiguous(), X, &dX);
return dX;
}

Tensor gelu_backward_cuda(const Tensor& grad, const Tensor& self) {
Tensor dX = at::native::empty_like(self);
GeluBackwardKernel(kCUDA, grad, self, &dX);
return dX;
}

DEFINE_DISPATCH(GeluKernel);
DEFINE_DISPATCH(GeluBackwardKernel);

}} // namespace at::native
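Editor's note: for context, the operator these entry points expose is GELU(x) = x · Φ(x), with Φ the standard normal CDF. A minimal scalar reference for orientation only (a sketch, not part of this diff; gelu_ref is a hypothetical name):

#include <cmath>

// GELU(x) = x * Phi(x); Phi(x) = 0.5 * (1 + erf(x / sqrt(2))).
double gelu_ref(double x) {
  return x * 0.5 * (1.0 + std::erf(x * 0.7071067811865476 /* 1/sqrt(2) */));
}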
19 changes: 14 additions & 5 deletions aten/src/ATen/native/Activation.h
@@ -1,15 +1,24 @@
#pragma once

#include <ATen/ATen.h>
#include <ATen/native/DispatchStub.h>
#include <c10/core/Scalar.h>

namespace at {

struct TensorIterator;

namespace native {

using threshold_fn = void (*)(TensorIterator&, Scalar, Scalar);
using activation_fn = void (*)(const Tensor& /* X */, Tensor* /* Y */);
using activation_backward_fn =
void (*)(const Tensor& /* dY */, const Tensor& /* X */, Tensor* /* dX */);

DECLARE_DISPATCH(threshold_fn, threshold_stub);
DECLARE_DISPATCH(activation_fn, GeluKernel);
DECLARE_DISPATCH(activation_backward_fn, GeluBackwardKernel);

} // namespace native

} // namespace at
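Editor's note: these declarations follow ATen's DispatchStub pattern. A sketch of the wiring, under the assumption that the DECLARE/DEFINE/REGISTER macros behave here as they do in the other files touched by this diff:

// Activation.h       : DECLARE_DISPATCH(activation_fn, GeluKernel);
// Activation.cpp     : DEFINE_DISPATCH(GeluKernel);
// cpu/Activation.cpp : REGISTER_DISPATCH(GeluKernel, &GeluKernelImpl);
// cuda/Activation.cu : REGISTER_DISPATCH(GeluKernel, &GeluCUDAKernelImpl);
//
// Call sites route through the stub with an explicit device type:
Tensor Y = at::native::empty_like(X);
GeluKernel(kCPU, X, &Y);  // runs the kernel registered for CPU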
176 changes: 165 additions & 11 deletions aten/src/ATen/native/cpu/Activation.cpp
@@ -1,31 +1,185 @@
#define _USE_MATH_DEFINES

#include <ATen/native/Activation.h>

#include <math.h>

#include <ATen/ATen.h>
#include <ATen/Config.h>
#include <ATen/cpu/vec256/vec256.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cpu/Loops.h>

#if AT_MKL_ENABLED()
#include <mkl.h>
#endif // AT_MKL_ENABLED()

namespace at {
namespace native {

namespace {

static void threshold_kernel(
    TensorIterator& iter,
    Scalar threshold_scalar,
    Scalar value_scalar) {
AT_DISPATCH_ALL_TYPES(iter.dtype(), "threshold_cpu", [&] {
using Vec = Vec256<scalar_t>;
scalar_t threshold = threshold_scalar.to<scalar_t>();
scalar_t value = value_scalar.to<scalar_t>();
    binary_kernel_vec(
        iter,
        [&](scalar_t x, scalar_t other) -> scalar_t {
          return x <= threshold ? value : other;
        },
        [&](Vec x, Vec other) -> Vec {
          return Vec::blendv(other, Vec(value), x <= Vec(threshold));
        });
});
}

#if AT_MKL_ENABLED()

// TODO(yangxm): Consider using TensorIterator here.
template <typename T>
void GeluKernelMKLImpl(const Tensor& X, Tensor* Y);

#define DELEGATE_GELU_KERNEL_MKL_IMPL(T, CdfNormFunc, MulFunc) \
template <> \
void GeluKernelMKLImpl<T>(const Tensor& X, Tensor* Y) { \
const int64_t N = X.numel(); \
const T* X_data = X.data<T>(); \
T* Y_data = Y->data<T>(); \
CdfNormFunc(N, X_data, Y_data); \
MulFunc(N, X_data, Y_data, Y_data); \
}
DELEGATE_GELU_KERNEL_MKL_IMPL(float, vsCdfNorm, vsMul)
DELEGATE_GELU_KERNEL_MKL_IMPL(double, vdCdfNorm, vdMul)
#undef DELEGATE_GELU_KERNEL_MKL_IMPL

#else // AT_MKL_ENABLED()

template <typename T>
void GeluKernelMKLImpl(const Tensor& X, Tensor* Y) {
AT_ASSERTM(false, "ATen not compiled with MKL");
}

#endif // AT_MKL_ENABLED()

template <typename T>
void GeluKernelImplInternal(const Tensor& X, Tensor* Y) {
const int64_t N = X.numel();
const T* X_data = X.data<T>();
T* Y_data = Y->data<T>();
for (int64_t i = 0; i < N; ++i) {
Y_data[i] = X_data[i] * M_SQRT1_2;
}
Y->erf_();
for (int64_t i = 0; i < N; ++i) {
Y_data[i] = (Y_data[i] + T(1)) * X_data[i] * T(0.5);
}
}

// TODO(yangxm): Add another fast kernel using formula
// y = 0.5x * (1 + tanh(sqrt(2/Pi) * (x + 0.044715x^3)))
// and the fast tanh impl from Eigen.
void GeluKernelImpl(const Tensor& X, Tensor* Y) {
if (at::hasMKL()) {
AT_DISPATCH_FLOATING_TYPES(X.scalar_type(), "GeluKernelImpl", [&]() {
GeluKernelMKLImpl<scalar_t>(X, Y);
});
} else {
AT_DISPATCH_FLOATING_TYPES(X.scalar_type(), "GeluKernelImpl", [&]() {
GeluKernelImplInternal<scalar_t>(X, Y);
});
}
}
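// ---------------------------------------------------------------------------
// Editor's sketch (not part of this PR): the tanh-based approximation named in
// the TODO above, written in the same scalar style as GeluKernelImplInternal.
// The helper name is hypothetical and <cmath> is assumed for std::tanh.
//   y = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
template <typename T>
void GeluTanhApproxSketch(const Tensor& X, Tensor* Y) {
  const int64_t N = X.numel();
  const T* X_data = X.data<T>();
  T* Y_data = Y->data<T>();
  const T kBeta = T(M_SQRT2 * M_2_SQRTPI * 0.5); // sqrt(2/pi)
  for (int64_t i = 0; i < N; ++i) {
    const T x = X_data[i];
    const T inner = kBeta * (x + T(0.044715) * x * x * x);
    Y_data[i] = T(0.5) * x * (T(1) + std::tanh(inner));
  }
}
// ---------------------------------------------------------------------------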

#if AT_MKL_ENABLED()

template <typename T>
void GeluBackwardKernelMKLImpl(const Tensor& dY, const Tensor& X, Tensor* dX);

// TODO(yangxm): Implement this by using template functions.
#define DELEGATE_GELU_BACKWARD_KERNEL_MKL_IMPL(T, CdfNormFunc, ExpFunc) \
template <> \
void GeluBackwardKernelMKLImpl<T>( \
const Tensor& dY, const Tensor& X, Tensor* dX) { \
constexpr T kAlpha = M_2_SQRTPI * M_SQRT1_2 * T(0.5); \
Tensor scratch = at::native::empty_like(X); \
const int64_t N = X.numel(); \
const T* dY_data = dY.data<T>(); \
const T* X_data = X.data<T>(); \
T* dX_data = dX->data<T>(); \
T* scratch_data = scratch.data<T>(); \
CdfNormFunc(N, X_data, scratch_data); \
for (int64_t i = 0; i < N; ++i) { \
dX_data[i] = -T(0.5) * X_data[i] * X_data[i]; \
} \
ExpFunc(N, dX_data, dX_data); \
for (int64_t i = 0; i < N; ++i) { \
dX_data[i] = \
dY_data[i] * (scratch_data[i] + X_data[i] * dX_data[i] * kAlpha); \
} \
}
DELEGATE_GELU_BACKWARD_KERNEL_MKL_IMPL(float, vsCdfNorm, vsExp)
DELEGATE_GELU_BACKWARD_KERNEL_MKL_IMPL(double, vdCdfNorm, vdExp)
#undef DELEGATE_GELU_BACKWARD_KERNEL_MKL_IMPL

#else // AT_MKL_ENABLED()

template <typename T>
void GeluBackwardKernelMKLImpl(const Tensor& dY, const Tensor& X, Tensor* dX) {
AT_ASSERTM(false, "ATen not compiled with MKL");
}

#endif // AT_MKL_ENABLED()

template <typename T>
void GeluBackwardKernelImplInternal(
const Tensor& dY,
const Tensor& X,
Tensor* dX) {
constexpr T kAlpha = M_2_SQRTPI * M_SQRT1_2 * T(0.5);
Tensor scratch = at::native::empty_like(X);
const int64_t N = X.numel();
const T* dY_data = dY.data<T>();
const T* X_data = X.data<T>();
T* dX_data = dX->data<T>();
T* scratch_data = scratch.data<T>();
for (int64_t i = 0; i < N; ++i) {
scratch_data[i] = X_data[i] * M_SQRT1_2;
dX_data[i] = -T(0.5) * X_data[i] * X_data[i];
}
// TODO(yangxm): Consider having the forward pass preserve CdfNorm(X) during
// training to avoid this extra tensor.
scratch.erf_();
dX->exp_();
for (int64_t i = 0; i < N; ++i) {
dX_data[i] = dY_data[i] *
(T(0.5) * (T(1) + scratch_data[i]) + X_data[i] * dX_data[i] * kAlpha);
}
}

void GeluBackwardKernelImpl(const Tensor& dY, const Tensor& X, Tensor* dX) {
if (hasMKL()) {
AT_DISPATCH_FLOATING_TYPES(
X.scalar_type(), "GeluBackwardKernelImpl", [&]() {
GeluBackwardKernelMKLImpl<scalar_t>(dY, X, dX);
});
} else {
AT_DISPATCH_FLOATING_TYPES(
X.scalar_type(), "GeluBackwardKernelImpl", [&]() {
GeluBackwardKernelImplInternal<scalar_t>(dY, X, dX);
});
}
}

} // namespace

REGISTER_DISPATCH(threshold_stub, &threshold_kernel);
REGISTER_DISPATCH(GeluKernel, &GeluKernelImpl);
REGISTER_DISPATCH(GeluBackwardKernel, &GeluBackwardKernelImpl);

} // namespace native
} // namespace at
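Editor's note: the backward kernels above follow from differentiating y = x·Φ(x); in particular, the kAlpha constant is 1/√(2π):

\[
\frac{d}{dx}\bigl(x\,\Phi(x)\bigr) = \Phi(x) + x\,\varphi(x),
\qquad
\varphi(x) = \frac{1}{\sqrt{2\pi}}\,e^{-x^{2}/2},
\]
\[
\texttt{kAlpha} = \texttt{M\_2\_SQRTPI}\cdot\texttt{M\_SQRT1\_2}\cdot 0.5
= \frac{2}{\sqrt{\pi}}\cdot\frac{1}{\sqrt{2}}\cdot\frac{1}{2}
= \frac{1}{\sqrt{2\pi}},
\]

so dX = dY · (Φ(x) + x · e^{−x²/2} · kAlpha), which is what both the MKL path and the fallback loop compute.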
51 changes: 50 additions & 1 deletion aten/src/ATen/native/cuda/Activation.cu
@@ -1,10 +1,16 @@
#define _USE_MATH_DEFINES

#include <ATen/native/Activation.h>

#include <math.h>

#include <ATen/ATen.h>
#include <ATen/NativeFunctions.h>
#include <ATen/Dispatch.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/native/cuda/Loops.cuh>
#include <c10/cuda/CUDAMathCompat.h>


namespace at { namespace native {
@@ -291,6 +297,49 @@ static void threshold_kernel(TensorIterator& iter, Scalar threshold, Scalar valu
});
}

namespace {

template <typename T>
void GeluCUDAKernelImplInternal(const Tensor& X, Tensor* Y) {
at::cuda::CUDA_tensor_apply2<T, T>(X, *Y, [] __device__(const T& x, T& y) {
y = x * c10::cuda::compat::normcdf(x);
});
}

void GeluCUDAKernelImpl(const Tensor& X, Tensor* Y) {
AT_DISPATCH_FLOATING_TYPES(X.scalar_type(), "GeluCUDAKernelImpl", [&]() {
GeluCUDAKernelImplInternal<scalar_t>(X, Y);
});
}

template <typename T>
void GeluBackwardCUDAKernelImplInternal(
const Tensor& dY,
const Tensor& X,
Tensor* dX) {
constexpr T kAlpha = M_2_SQRTPI * M_SQRT1_2 * T(0.5);
at::cuda::CUDA_tensor_apply3<T, T, T>(
dY, X, *dX, [] __device__(const T& dy, const T& x, T& dx) {
dx = dy *
(c10::cuda::compat::normcdf(x) +
x * kAlpha * c10::cuda::compat::exp(-T(0.5) * x * x));
});
}

void GeluBackwardCUDAKernelImpl(
const Tensor& dY,
const Tensor& X,
Tensor* dX) {
AT_DISPATCH_FLOATING_TYPES(
X.scalar_type(), "GeluBackwardCUDAKernelImpl", [&]() {
GeluBackwardCUDAKernelImplInternal<scalar_t>(dY, X, dX);
});
}

} // namespace

REGISTER_DISPATCH(threshold_stub, &threshold_kernel);
REGISTER_DISPATCH(GeluKernel, &GeluCUDAKernelImpl);
REGISTER_DISPATCH(GeluBackwardKernel, &GeluBackwardCUDAKernelImpl);

}} // namespace at::native
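Editor's note: a quick numeric spot check valid for both the CPU and CUDA paths (values rounded to four decimals; Φ and φ are the standard normal CDF and PDF):

\[
\mathrm{gelu}(1) = 1\cdot\Phi(1) \approx 0.8413,
\qquad
\mathrm{gelu}'(1) = \Phi(1) + 1\cdot\varphi(1) \approx 0.8413 + 0.2420 = 1.0833.
\]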
12 changes: 12 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -1587,6 +1587,18 @@
CPU: prelu_backward_cpu
CUDA: prelu_backward_cuda

- func: gelu(Tensor self) -> Tensor
python_module: nn
dispatch:
CPU: gelu_cpu
CUDA: gelu_cuda

- func: gelu_backward(Tensor grad, Tensor self) -> Tensor
python_module: nn
dispatch:
CPU: gelu_backward_cpu
CUDA: gelu_backward_cuda

- func: hardshrink(Tensor self, Scalar lambd=0.5) -> Tensor
variants: function, method
dispatch:
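Editor's note: the schema entries above generate the usual ATen bindings. A minimal call-site sketch, assuming the generated at::gelu / at::gelu_backward match the declared schemas:

#include <ATen/ATen.h>

int main() {
  at::Tensor x = at::randn({8});
  at::Tensor y = at::gelu(x);               // forward: y = x * Phi(x)
  at::Tensor dy = at::ones_like(y);
  at::Tensor dx = at::gelu_backward(dy, x); // gradient with respect to x
  return 0;
}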
7 changes: 7 additions & 0 deletions c10/cuda/CUDAMathCompat.h
@@ -84,6 +84,13 @@ __MATH_FUNCTIONS_DECL__ double tan(double x) {
return ::tan(x);
}

__MATH_FUNCTIONS_DECL__ float normcdf(float x) {
return ::normcdff(x);
}
__MATH_FUNCTIONS_DECL__ double normcdf(double x) {
return ::normcdf(x);
}

} // namespace compat
} // namespace cuda
} // namespace c10
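Editor's note: normcdf is the standard normal CDF. A host-side reference in terms of erf (a sketch, not part of the diff; normcdf_ref is a hypothetical name) that mirrors the CUDA ::normcdf/::normcdff intrinsics and the erf-based CPU path:

#include <cmath>

double normcdf_ref(double x) {
  // Phi(x) = 0.5 * (1 + erf(x / sqrt(2)))
  return 0.5 * (1.0 + std::erf(x * 0.7071067811865476));
}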
5 changes: 5 additions & 0 deletions docs/source/nn.rst
@@ -1046,6 +1046,11 @@ Non-linear activation functions

.. autofunction:: glu

:hidden:`gelu`
~~~~~~~~~~~~~~~

.. autofunction:: gelu

:hidden:`logsigmoid`
~~~~~~~~~~~~~~~~~~~~
