
Commit c9da011

xiaomengy authored and facebook-github-bot committed
Optimize pytorch layer_norm forward (#20345)
Summary: Pull Request resolved: #20345

Separate from D15194600. Optimize the PyTorch layer_norm op, part 1: optimize layer_norm_forward_cpu and use Eigen Maps to improve the performance of the reductions.

Reviewed By: zheng-xq

Differential Revision: D15290608

fbshipit-source-id: cf2c208dfd6fbcbc4c69db3ed60278d9bee156b5
1 parent 9cec8ae commit c9da011


4 files changed: +227 −58 lines changed
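For orientation, the operator touched by this commit keeps the signature layer_norm(input, normalized_shape, weight, bias, eps, cudnn_enabled) (see the new implementation in the last file below). A minimal, hypothetical call from C++ might look like the following; the shapes are made up for illustration, and it assumes the native function is exposed as at::layer_norm, as is usual for ATen native ops:

#include <ATen/ATen.h>

int main() {
  // Hypothetical shapes, for illustration only.
  at::Tensor input = at::randn({20, 5, 10});
  at::Tensor weight = at::ones({10});
  at::Tensor bias = at::zeros({10});
  // layer_norm(input, normalized_shape, weight, bias, eps, cudnn_enabled)
  at::Tensor out =
      at::layer_norm(input, {10}, weight, bias, 1e-5, /*cudnn_enabled=*/false);
  return 0;
}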

aten/src/ATen/native/Normalization.cpp

Lines changed: 0 additions & 58 deletions
@@ -462,64 +462,6 @@ Tensor instance_norm(
   return out.view(input.sizes());
 }
 
-Tensor layer_norm(const Tensor& input, IntArrayRef normalized_shape,
-    const Tensor& weight /* optional */, const Tensor& bias /* optional */,
-    double eps, bool cudnn_enabled) {
-
-  int64_t normalized_ndim = normalized_shape.size();
-
-  TORCH_CHECK(normalized_ndim >= 1,
-      "Expected normalized_shape to be at least 1-dimensional, i.e., ",
-      "containing at least one element, but got normalized_shape=",
-      normalized_shape);
-
-  TORCH_CHECK(!weight.defined() || weight.sizes().equals(normalized_shape),
-      "Expected weight to be of same shape as normalized_shape, but got ",
-      "weight of shape ", weight.sizes(), " and normalized_shape=",
-      normalized_shape);
-  TORCH_CHECK(!bias.defined() || bias.sizes().equals(normalized_shape),
-      "Expected bias to be of same shape as normalized_shape, but got ",
-      "bias of shape ", bias.sizes(), " and normalized_shape=",
-      normalized_shape);
-
-  auto input_shape = input.sizes();
-  auto input_ndim = input.dim();
-
-  if (input_ndim < normalized_ndim ||
-      !input_shape.slice(input_ndim - normalized_ndim).equals(normalized_shape)) {
-    std::stringstream ss;
-    ss << "Given normalized_shape=" << normalized_shape
-       << ", expected input with shape [*";
-    for (auto size : normalized_shape) {
-      ss << ", " << size;
-    }
-    ss << "], but got input of size" << input_shape;
-    AT_ERROR(ss.str());
-  }
-
-  int64_t n = 1;
-  for (int64_t i = 0; i < input_ndim - normalized_ndim; i++) {
-    n *= input_shape[i];
-  }
-
-  // Apply layer norm
-  auto input_reshaped = input.contiguous().view({1, n, -1});
-
-  auto out = at::batch_norm(input_reshaped, {}, {}, {}, {}, true, 0, eps,
-      cudnn_enabled);
-  out = out.view(input_shape);
-
-  if (weight.defined() && bias.defined()) {
-    return bias.addcmul(out, weight, 1);
-  } else if (weight.defined()) {
-    return out.mul(weight);
-  } else if (bias.defined()) {
-    return out.add(bias);
-  } else {
-    return out;
-  }
-}
-
 Tensor group_norm(const Tensor& input, int64_t num_groups,
     const Tensor& weight /* optional */, const Tensor& bias /* optional */,
     double eps, bool cudnn_enabled) {
Lines changed: 79 additions & 0 deletions
@@ -0,0 +1,79 @@
+#include <ATen/native/cpu/layer_norm_kernel.h>
+
+#include <ATen/ATen.h>
+#include <ATen/CPUApplyUtils.h>
+#include <ATen/Dispatch.h>
+
+namespace at {
+namespace native {
+
+namespace {
+
+template <typename T>
+void LayerNormKernelImplInternal(
+    const Tensor& X,
+    const Tensor& gamma,
+    const Tensor& beta,
+    int64_t M,
+    int64_t N,
+    T eps,
+    Tensor* Y,
+    Tensor* mean,
+    Tensor* rstd) {
+  DCHECK_EQ(X.numel(), M * N);
+  DCHECK(!gamma.defined() || gamma.numel() == N);
+  DCHECK(!beta.defined() || beta.numel() == N);
+  const T* X_data = X.data<T>();
+  const T* gamma_data = gamma.defined() ? gamma.data<T>() : nullptr;
+  const T* beta_data = beta.defined() ? beta.data<T>() : nullptr;
+  T* Y_data = Y->data<T>();
+  T* mean_data = mean->data<T>();
+  T* rstd_data = rstd->data<T>();
+  const T c = T(1) / static_cast<T>(N);
+  const bool gamma_null = gamma_data == nullptr;
+  const bool beta_null = beta_data == nullptr;
+  for (int64_t i = 0; i < M; ++i) {
+    const T* X_ptr = X_data + i * N;
+    T* Y_ptr = Y_data + i * N;
+    T mean_val = T(0);
+    T rstd_val = T(0);
+    for (int64_t j = 0; j < N; ++j) {
+      mean_val += X_ptr[j];
+      rstd_val += X_ptr[j] * X_ptr[j];
+    }
+    mean_val *= c;
+    rstd_val = T(1) / std::sqrt(rstd_val * c - mean_val * mean_val + eps);
+    const T scale = rstd_val;
+    const T bias = -rstd_val * mean_val;
+    for (int64_t j = 0; j < N; ++j) {
+      const T gamma_v = gamma_null ? T(1) : gamma_data[j];
+      const T beta_v = beta_null ? T(0) : beta_data[j];
+      Y_ptr[j] = (X_ptr[j] * scale + bias) * gamma_v + beta_v;
+    }
+    mean_data[i] = mean_val;
+    rstd_data[i] = rstd_val;
+  }
+}
+
+void LayerNormKernelImpl(
+    const Tensor& X,
+    const Tensor& gamma,
+    const Tensor& beta,
+    int64_t M,
+    int64_t N,
+    double eps,
+    Tensor* Y,
+    Tensor* mean,
+    Tensor* rstd) {
+  AT_DISPATCH_FLOATING_TYPES(X.scalar_type(), "LayerNormKernelImpl", [&]() {
+    LayerNormKernelImplInternal<scalar_t>(
+        X, gamma, beta, M, N, static_cast<scalar_t>(eps), Y, mean, rstd);
+  });
+}
+
+} // namespace
+
+REGISTER_DISPATCH(LayerNormKernel, &LayerNormKernelImpl);
+
+} // namespace native
+} // namespace at
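The commit summary mentions Eigen Maps for the reduction, but the kernel added above computes the per-row mean and inverse standard deviation with plain loops, using Var(x) = E[x^2] - E[x]^2 and folding the normalization into a per-row scale and bias. As a rough, hypothetical sketch only (not part of this diff, and assuming Eigen is available on the include path), the per-row reduction expressed with Eigen::Map might look like this:

// Hypothetical illustration; not part of this commit. Assumes Eigen is available.
#include <cmath>
#include <cstdint>
#include <Eigen/Core>

template <typename T>
void RowLayerNormSketch(const T* X_ptr, int64_t N, T eps, T* Y_ptr) {
  using Arr = Eigen::Array<T, Eigen::Dynamic, 1>;
  // Wrap the existing row buffers without copying.
  Eigen::Map<const Arr> x(X_ptr, N);
  Eigen::Map<Arr> y(Y_ptr, N);
  // Vectorized reductions: Var(x) = E[x^2] - E[x]^2.
  const T mean = x.mean();
  const T var = x.square().mean() - mean * mean;
  const T rstd = T(1) / std::sqrt(var + eps);
  // Same folding as the kernel above: (x - mean) * rstd == x * rstd - rstd * mean.
  y = x * rstd - rstd * mean;
}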
Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
+#ifndef ATEN_SRC_NATIVE_CPU_LAYER_NORM_KERNEL_H_
+#define ATEN_SRC_NATIVE_CPU_LAYER_NORM_KERNEL_H_
+
+#include <ATen/ATen.h>
+#include <ATen/native/DispatchStub.h>
+
+namespace at {
+namespace native {
+
+using forward_fn = void (*)(
+    const Tensor& /* X */,
+    const Tensor& /* gamma */,
+    const Tensor& /* beta */,
+    int64_t /* M */,
+    int64_t /* N */,
+    double /* eps */,
+    Tensor* /* Y */,
+    Tensor* /* mean */,
+    Tensor* /* rstd */);
+
+DECLARE_DISPATCH(forward_fn, LayerNormKernel);
+
+} // namespace native
+} // namespace at
+
+#endif // ATEN_SRC_NATIVE_CPU_LAYER_NORM_KERNEL_H_
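Taken together, the new files follow ATen's DispatchStub pattern: this header declares the stub with the forward_fn signature, the CPU kernel file registers an implementation for it, and the operator file defines the stub and invokes it with an explicit device type. A condensed view of how the three pieces in this commit fit together (names as they appear in the diff; not a standalone program):

// Header: declare the stub with the forward_fn signature.
DECLARE_DISPATCH(forward_fn, LayerNormKernel);

// CPU kernel translation unit: register the CPU implementation.
REGISTER_DISPATCH(LayerNormKernel, &LayerNormKernelImpl);

// Operator translation unit: define the stub once, then call it,
// selecting the implementation by device type.
DEFINE_DISPATCH(LayerNormKernel);
LayerNormKernel(kCPU, X, gamma, beta, M, N, eps, &Y, &mean, &rstd);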
Lines changed: 122 additions & 0 deletions
@@ -0,0 +1,122 @@
+#include <ATen/NativeFunctions.h>
+
+#include <functional>
+#include <numeric>
+#include <tuple>
+#include <vector>
+
+#include <ATen/ATen.h>
+#include <ATen/AccumulateType.h>
+#include <ATen/CPUApplyUtils.h>
+#include <ATen/Config.h>
+#include <ATen/Parallel.h>
+#include <ATen/native/cpu/layer_norm_kernel.h>
+
+namespace at {
+namespace native {
+
+namespace {
+
+std::tuple<Tensor, Tensor, Tensor> layer_norm_forward_cpu(
+    const Tensor& X,
+    const Tensor& gamma /* optional */,
+    const Tensor& beta /* optional */,
+    int64_t M,
+    int64_t N,
+    double eps) {
+  Tensor Y = at::native::empty_like(X);
+  Tensor mean = at::empty({M}, X.options());
+  Tensor rstd = at::empty({M}, X.options());
+  LayerNormKernel(kCPU, X, gamma, beta, M, N, eps, &Y, &mean, &rstd);
+  return std::make_tuple(Y, mean, rstd);
+}
+
+} // namespace
+
+Tensor layer_norm(
+    const Tensor& input,
+    IntArrayRef normalized_shape,
+    const Tensor& weight /* optional */,
+    const Tensor& bias /* optional */,
+    double eps,
+    bool cudnn_enabled) {
+  const int normalized_ndim = normalized_shape.size();
+  TORCH_CHECK(
+      normalized_ndim >= 1,
+      "Expected normalized_shape to be at least 1-dimensional, i.e., ",
+      "containing at least one element, but got normalized_shape = ",
+      normalized_shape);
+  TORCH_CHECK(
+      !weight.defined() || weight.sizes().equals(normalized_shape),
+      "Expected weight to be of same shape as normalized_shape, but got ",
+      "weight of shape ",
+      weight.sizes(),
+      " and normalized_shape = ",
+      normalized_shape);
+  TORCH_CHECK(
+      !bias.defined() || bias.sizes().equals(normalized_shape),
+      "Expected bias to be of same shape as normalized_shape, but got ",
+      "bias of shape ",
+      bias.sizes(),
+      " and normalized_shape = ",
+      normalized_shape);
+
+  const auto input_shape = input.sizes();
+  const auto input_ndim = input.dim();
+
+  if (input_ndim < normalized_ndim ||
+      !input_shape.slice(input_ndim - normalized_ndim)
+           .equals(normalized_shape)) {
+    std::stringstream ss;
+    ss << "Given normalized_shape=" << normalized_shape
+       << ", expected input with shape [*";
+    for (auto size : normalized_shape) {
+      ss << ", " << size;
+    }
+    ss << "], but got input of size" << input_shape;
+    AT_ERROR(ss.str());
+  }
+
+  const int axis = input_ndim - normalized_ndim;
+  const int64_t M = std::accumulate(
+      input_shape.cbegin(),
+      input_shape.cbegin() + axis,
+      1LL,
+      std::multiplies<int64_t>());
+  const int64_t N = std::accumulate(
+      input_shape.cbegin() + axis,
+      input_shape.cend(),
+      1LL,
+      std::multiplies<int64_t>());
+
+  // TODO(yangxm): Remove this check after backward pass landed.
+  const auto is_forward = [](const Tensor& tensor) {
+    return tensor.is_variable() && !tensor.requires_grad();
+  };
+  if (input.device().is_cpu() && is_forward(input) && is_forward(weight) &&
+      is_forward(bias)) {
+    return std::get<0>(layer_norm_forward_cpu(
+        input.contiguous(), weight.contiguous(), bias.contiguous(), M, N, eps));
+  }
+
+  // Apply layer norm
+  auto input_reshaped = input.contiguous().view({1, M, -1});
+  auto out = at::batch_norm(
+      input_reshaped, {}, {}, {}, {}, true, 0, eps, cudnn_enabled);
+  out = out.view(input_shape);
+
+  if (weight.defined() && bias.defined()) {
+    return bias.addcmul(out, weight, 1);
+  } else if (weight.defined()) {
+    return out.mul(weight);
+  } else if (bias.defined()) {
+    return out.add(bias);
+  } else {
+    return out;
+  }
+}
+
+DEFINE_DISPATCH(LayerNormKernel);
+
+} // namespace native
+} // namespace at
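As a worked example of the M/N split in layer_norm above (the shape is hypothetical): for an input of shape [20, 5, 10] with normalized_shape = [10], axis = 3 - 1 = 2, so M = 20 * 5 = 100 rows are each normalized over N = 10 elements. The same arithmetic with std::accumulate, as used in the diff:

#include <cassert>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

int main() {
  // Hypothetical shape: input [20, 5, 10], normalized_shape [10] -> axis = 2.
  const std::vector<int64_t> input_shape = {20, 5, 10};
  const int axis = 2;
  const int64_t M = std::accumulate(
      input_shape.cbegin(), input_shape.cbegin() + axis, 1LL,
      std::multiplies<int64_t>());
  const int64_t N = std::accumulate(
      input_shape.cbegin() + axis, input_shape.cend(), 1LL,
      std::multiplies<int64_t>());
  assert(M == 100 && N == 10);  // 20 * 5 rows, each of length 10.
  return 0;
}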
