Optimize pytorch layer_norm forward #20345
Status: Closed
File: aten/src/ATen/native/cpu/layer_norm_kernel.cpp (new file, +79 lines)

```cpp
#include <ATen/native/cpu/layer_norm_kernel.h>

#include <cmath>

#include <ATen/ATen.h>
#include <ATen/CPUApplyUtils.h>
#include <ATen/Dispatch.h>

namespace at {
namespace native {

namespace {

template <typename T>
void LayerNormKernelImplInternal(
    const Tensor& X,
    const Tensor& gamma,
    const Tensor& beta,
    int64_t M,
    int64_t N,
    T eps,
    Tensor* Y,
    Tensor* mean,
    Tensor* rstd) {
  DCHECK_EQ(X.numel(), M * N);
  DCHECK(!gamma.defined() || gamma.numel() == N);
  DCHECK(!beta.defined() || beta.numel() == N);
  const T* X_data = X.data<T>();
  const T* gamma_data = gamma.defined() ? gamma.data<T>() : nullptr;
  const T* beta_data = beta.defined() ? beta.data<T>() : nullptr;
  T* Y_data = Y->data<T>();
  T* mean_data = mean->data<T>();
  T* rstd_data = rstd->data<T>();
  const T c = T(1) / static_cast<T>(N);
  const bool gamma_null = gamma_data == nullptr;
  const bool beta_null = beta_data == nullptr;
  for (int64_t i = 0; i < M; ++i) {
    const T* X_ptr = X_data + i * N;
    T* Y_ptr = Y_data + i * N;
    // Single pass over the row: accumulate sum(x) and sum(x^2) together.
    T mean_val = T(0);
    T rstd_val = T(0);
    for (int64_t j = 0; j < N; ++j) {
      mean_val += X_ptr[j];
      rstd_val += X_ptr[j] * X_ptr[j];
    }
    // Var(x) = E[x^2] - E[x]^2; rstd = 1 / sqrt(Var(x) + eps).
    mean_val *= c;
    rstd_val = T(1) / std::sqrt(rstd_val * c - mean_val * mean_val + eps);
    // Fold mean and rstd into a single scale/bias pair so the
    // normalization loop is one multiply-add per element.
    const T scale = rstd_val;
    const T bias = -rstd_val * mean_val;
    for (int64_t j = 0; j < N; ++j) {
      const T gamma_v = gamma_null ? T(1) : gamma_data[j];
      const T beta_v = beta_null ? T(0) : beta_data[j];
      Y_ptr[j] = (X_ptr[j] * scale + bias) * gamma_v + beta_v;
    }
    mean_data[i] = mean_val;
    rstd_data[i] = rstd_val;
  }
}

void LayerNormKernelImpl(
    const Tensor& X,
    const Tensor& gamma,
    const Tensor& beta,
    int64_t M,
    int64_t N,
    double eps,
    Tensor* Y,
    Tensor* mean,
    Tensor* rstd) {
  AT_DISPATCH_FLOATING_TYPES(X.scalar_type(), "LayerNormKernelImpl", [&]() {
    LayerNormKernelImplInternal<scalar_t>(
        X, gamma, beta, M, N, static_cast<scalar_t>(eps), Y, mean, rstd);
  });
}

} // namespace

REGISTER_DISPATCH(LayerNormKernel, &LayerNormKernelImpl);

} // namespace native
} // namespace at
```
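The inner loop above computes each row's mean and variance in a single pass using the identity Var(x) = E[x²] − E[x]², rather than a second centered pass over the row; this halves the reads of X, at the cost of some precision when the mean is large relative to the standard deviation. A minimal standalone sketch comparing the two formulations (plain C++, made-up values, no ATen dependency):

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  // A made-up row of N = 4 values.
  const std::vector<float> x = {1.0f, 2.0f, 3.0f, 4.0f};
  const float eps = 1e-5f;
  const float c = 1.0f / x.size();

  // One pass: accumulate sum(x) and sum(x^2) together, as the kernel does.
  float sum = 0.0f, sum_sq = 0.0f;
  for (float v : x) {
    sum += v;
    sum_sq += v * v;
  }
  const float mean = sum * c;
  const float rstd = 1.0f / std::sqrt(sum_sq * c - mean * mean + eps);

  // Two passes: compute the mean first, then the centered variance.
  float var = 0.0f;
  for (float v : x) {
    var += (v - mean) * (v - mean);
  }
  const float rstd_two_pass = 1.0f / std::sqrt(var * c + eps);

  // The results should agree up to floating-point rounding.
  std::printf("one-pass rstd = %f, two-pass rstd = %f\n", rstd, rstd_two_pass);
}
```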
File: aten/src/ATen/native/cpu/layer_norm_kernel.h (new file, +26 lines)

```cpp
#ifndef ATEN_SRC_NATIVE_CPU_LAYER_NORM_KERNEL_H_
#define ATEN_SRC_NATIVE_CPU_LAYER_NORM_KERNEL_H_

#include <ATen/ATen.h>
#include <ATen/native/DispatchStub.h>

namespace at {
namespace native {

using forward_fn = void (*)(
    const Tensor& /* X */,
    const Tensor& /* gamma */,
    const Tensor& /* beta */,
    int64_t /* M */,
    int64_t /* N */,
    double /* eps */,
    Tensor* /* Y */,
    Tensor* /* mean */,
    Tensor* /* rstd */);

DECLARE_DISPATCH(forward_fn, LayerNormKernel);

} // namespace native
} // namespace at

#endif // ATEN_SRC_NATIVE_CPU_LAYER_NORM_KERNEL_H_
```
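This header is the declaration half of ATen's DispatchStub pattern: DECLARE_DISPATCH declares a per-device stub named LayerNormKernel, REGISTER_DISPATCH in layer_norm_kernel.cpp fills in the CPU slot, DEFINE_DISPATCH in the calling file defines the stub object, and the call site selects a device explicitly (LayerNormKernel(kCPU, ...)). A toy illustration of the idea, with hypothetical names and none of ATen's actual machinery:

```cpp
#include <cstdint>
#include <cstdio>

// Hypothetical miniature of the dispatch-stub pattern: a stub holds one
// function pointer per device type, filled in by whichever kernel files
// happen to be compiled into the build.
enum class DeviceType { CPU, CUDA };

using norm_fn = void (*)(int64_t M, int64_t N);

struct NormStub {
  norm_fn cpu = nullptr;
  norm_fn cuda = nullptr;
  void operator()(DeviceType d, int64_t M, int64_t N) {
    // A real stub would check for a missing registration; a sketch assumes
    // the requested slot was filled.
    (d == DeviceType::CPU ? cpu : cuda)(M, N);
  }
};

NormStub layer_norm_stub;  // analogue of DEFINE_DISPATCH

void cpu_kernel(int64_t M, int64_t N) {
  std::printf("cpu kernel: M=%lld N=%lld\n", (long long)M, (long long)N);
}

// Analogue of REGISTER_DISPATCH: runs at static-initialization time in the
// kernel's translation unit.
struct Registerer {
  Registerer() { layer_norm_stub.cpu = &cpu_kernel; }
} register_cpu;

int main() {
  layer_norm_stub(DeviceType::CPU, 6, 20);  // call site picks the device
}
```

Placing the kernel in the /cpu subfolder matters because ATen compiles files there multiple times with different vectorization flags and picks the best variant at runtime, which is what the reviewer's AVX2 remark below refers to.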
File: ATen native layer_norm implementation (new file, +122 lines)

```cpp
#include <ATen/NativeFunctions.h>

#include <functional>
#include <numeric>
#include <sstream>
#include <tuple>
#include <vector>

#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/CPUApplyUtils.h>
#include <ATen/Config.h>
#include <ATen/Parallel.h>
#include <ATen/native/cpu/layer_norm_kernel.h>

namespace at {
namespace native {

namespace {

std::tuple<Tensor, Tensor, Tensor> layer_norm_forward_cpu(
    const Tensor& X,
    const Tensor& gamma /* optional */,
    const Tensor& beta /* optional */,
    int64_t M,
    int64_t N,
    double eps) {
  Tensor Y = at::native::empty_like(X);
  Tensor mean = at::empty({M}, X.options());
  Tensor rstd = at::empty({M}, X.options());
  LayerNormKernel(kCPU, X, gamma, beta, M, N, eps, &Y, &mean, &rstd);
  return std::make_tuple(Y, mean, rstd);
}

} // namespace

Tensor layer_norm(
    const Tensor& input,
    IntArrayRef normalized_shape,
    const Tensor& weight /* optional */,
    const Tensor& bias /* optional */,
    double eps,
    bool cudnn_enabled) {
  const int normalized_ndim = normalized_shape.size();
  TORCH_CHECK(
      normalized_ndim >= 1,
      "Expected normalized_shape to be at least 1-dimensional, i.e., ",
      "containing at least one element, but got normalized_shape = ",
      normalized_shape);
  TORCH_CHECK(
      !weight.defined() || weight.sizes().equals(normalized_shape),
      "Expected weight to be of same shape as normalized_shape, but got ",
      "weight of shape ",
      weight.sizes(),
      " and normalized_shape = ",
      normalized_shape);
  TORCH_CHECK(
      !bias.defined() || bias.sizes().equals(normalized_shape),
      "Expected bias to be of same shape as normalized_shape, but got ",
      "bias of shape ",
      bias.sizes(),
      " and normalized_shape = ",
      normalized_shape);

  const auto input_shape = input.sizes();
  const auto input_ndim = input.dim();

  // normalized_shape must match the trailing dimensions of the input.
  if (input_ndim < normalized_ndim ||
      !input_shape.slice(input_ndim - normalized_ndim)
           .equals(normalized_shape)) {
    std::stringstream ss;
    ss << "Given normalized_shape=" << normalized_shape
       << ", expected input with shape [*";
    for (auto size : normalized_shape) {
      ss << ", " << size;
    }
    ss << "], but got input of size " << input_shape;
    AT_ERROR(ss.str());
  }

  // Flatten the input to an M x N matrix: M rows formed by the leading
  // dimensions, each normalized independently over its N trailing elements.
  const int axis = input_ndim - normalized_ndim;
  const int64_t M = std::accumulate(
      input_shape.cbegin(),
      input_shape.cbegin() + axis,
      1LL,
      std::multiplies<int64_t>());
  const int64_t N = std::accumulate(
      input_shape.cbegin() + axis,
      input_shape.cend(),
      1LL,
      std::multiplies<int64_t>());

  // TODO(yangxm): Remove this check after backward pass landed.
  const auto is_forward = [](const Tensor& tensor) {
    return tensor.is_variable() && !tensor.requires_grad();
  };
  // Fast path: inference-only CPU tensors go through the fused kernel.
  if (input.device().is_cpu() && is_forward(input) && is_forward(weight) &&
      is_forward(bias)) {
    return std::get<0>(layer_norm_forward_cpu(
        input.contiguous(), weight.contiguous(), bias.contiguous(), M, N, eps));
  }

  // Fallback: express layer norm as a batch norm over the reshaped input,
  // treating each of the M rows as a channel with its own statistics.
  auto input_reshaped = input.contiguous().view({1, M, -1});
  auto out = at::batch_norm(
      input_reshaped, {}, {}, {}, {}, true, 0, eps, cudnn_enabled);
  out = out.view(input_shape);

  if (weight.defined() && bias.defined()) {
    return bias.addcmul(out, weight, 1);
  } else if (weight.defined()) {
    return out.mul(weight);
  } else if (bias.defined()) {
    return out.add(bias);
  } else {
    return out;
  }
}

DEFINE_DISPATCH(LayerNormKernel);

} // namespace native
} // namespace at
```
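To make the M/N flattening concrete: with an input of shape [2, 3, 4, 5] and normalized_shape = [4, 5], axis = 2, so M = 2·3 = 6 rows are each normalized over N = 4·5 = 20 elements. A usage sketch against the signature above (assumes a libtorch/ATen build; the tensor values are arbitrary):

```cpp
#include <ATen/ATen.h>

int main() {
  // Input of shape [2, 3, 4, 5], normalized over the trailing [4, 5]:
  // M = 2 * 3 = 6 rows, each normalized over N = 4 * 5 = 20 elements.
  at::Tensor x = at::randn({2, 3, 4, 5});
  at::Tensor weight = at::ones({4, 5});
  at::Tensor bias = at::zeros({4, 5});

  // Matches the layer_norm signature above; cudnn_enabled only affects
  // the batch_norm fallback path.
  at::Tensor y = at::layer_norm(x, {4, 5}, weight, bias, /*eps=*/1e-5,
                                /*cudnn_enabled=*/false);

  // Each [4, 5] slice of y should now have roughly zero mean and
  // unit variance.
}
```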
Reviewer:
Please move LayerNormForwardCPUImpl and all CPU logic into the /cpu subfolder using the DECLARE_DISPATCH logic (see CopyKernel for reference). It will allow the OSS build to utilize AVX2 instructions and other optimizations.

Author:
Done