Skip to content

Commit c8f17b1

Browse files
committed
add quantized layer norm implementation
Summary: Adds a quantized implementation of LayerNorm for server. A future PR will add the Python wrapper. Test Plan: numerics match the floating-point implementation. TODO: benchmarks. Reviewers: Subscribers: Tasks: Tags: ghstack-source-id: 7f70433. Pull Request resolved: #35329
1 parent b66513b commit c8f17b1

File tree

10 files changed

+383
-6
lines changed

10 files changed

+383
-6
lines changed

aten/src/ATen/native/layer_norm.cpp

Lines changed: 83 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include <ATen/Config.h>
1313
#include <ATen/NativeFunctions.h>
1414
#include <ATen/Parallel.h>
15+
#include <ATen/core/op_registration/op_registration.h>
1516

1617
namespace at {
1718
namespace native {
@@ -60,13 +61,12 @@ std::tuple<Tensor, Tensor, Tensor> layer_norm_backward_cpu(
6061
return std::make_tuple(std::move(dX), std::move(dgamma), std::move(dbeta));
6162
}
6263

63-
Tensor layer_norm(
64+
std::tuple<Tensor, Tensor, Tensor, int64_t, int64_t> _prepare_layer_norm_inputs(
6465
const Tensor& input,
6566
IntArrayRef normalized_shape,
6667
const Tensor& weight /* optional */,
67-
const Tensor& bias /* optional */,
68-
double eps,
69-
bool /* cudnn_enable, deprecated */) {
68+
const Tensor& bias /* optional */) {
69+
7070
const int normalized_ndim = normalized_shape.size();
7171
TORCH_CHECK(
7272
normalized_ndim >= 1,
@@ -119,11 +119,90 @@ Tensor layer_norm(
119119
const auto& X = input.is_contiguous() ? input : input.contiguous();
120120
const auto& gamma = weight.is_contiguous() ? weight : weight.contiguous();
121121
const auto& beta = bias.is_contiguous() ? bias : bias.contiguous();
122+
123+
return std::make_tuple(X, gamma, beta, M, N);
124+
}
125+
126+
// Composite LayerNorm forward entry point.
//
// Delegates argument preparation to _prepare_layer_norm_inputs (which hands
// back the contiguous input / weight / bias together with the two extents
// M and N), then forwards everything to at::native_layer_norm and returns
// only the first element of its result tuple.
//
// The trailing bool parameter (cudnn_enable) is deprecated and ignored.
Tensor layer_norm(
    const Tensor& input,
    IntArrayRef normalized_shape,
    const Tensor& weight /* optional */,
    const Tensor& bias /* optional */,
    double eps,
    bool /* cudnn_enable, deprecated */) {
  Tensor X;
  Tensor gamma;
  Tensor beta;
  int64_t M = 0;
  int64_t N = 0;
  std::tie(X, gamma, beta, M, N) =
      _prepare_layer_norm_inputs(input, normalized_shape, weight, bias);

  return std::get<0>(at::native_layer_norm(X, gamma, beta, M, N, eps));
}
124143

144+
// Quantized LayerNorm forward.
//
// Prepares the arguments with the same helper the floating point layer_norm
// uses, allocates an affine-quantized output carrying the requested
// output_scale / output_zero_point, and lets quantized_layer_norm_stub fill
// it on kCPU.
Tensor quantized_layer_norm_impl(
    const Tensor& input,
    IntArrayRef normalized_shape,
    const Tensor& weight /* optional */,
    const Tensor& bias /* optional */,
    double eps,
    double output_scale,
    int64_t output_zero_point) {
  Tensor X;
  Tensor gamma;
  Tensor beta;
  int64_t M = 0;
  int64_t N = 0;
  std::tie(X, gamma, beta, M, N) =
      _prepare_layer_norm_inputs(input, normalized_shape, weight, bias);

  // The output takes its sizes, scalar type and suggested memory format
  // from the prepared input; only the quantization parameters differ.
  Tensor Y = at::_empty_affine_quantized(
      X.sizes(),
      X.scalar_type(),
      output_scale,
      output_zero_point,
      X.suggest_memory_format());

  // The kernel is not invoked unless M is positive; Y is then returned
  // exactly as allocated.
  if (M <= 0) {
    return Y;
  }

  quantized_layer_norm_stub(kCPU, X, gamma, beta, M, N, eps, &Y);
  return Y;
}
172+
173+
// Keep the registry in the anonymous namespace.
174+
namespace {
175+
// Functor kernel for the quantized::layer_norm operator registered below.
// operator() forwards every argument unchanged to quantized_layer_norm_impl
// and returns its result.
// NOTE(review): weight and bias are commented as optional here, but the
// schema string below declares them as plain `Tensor` rather than `Tensor?`
// -- confirm which is intended.
class QLayerNorm2d final : public torch::OperatorKernel {
176+
public:
177+
Tensor operator()(
178+
Tensor input,
179+
std::vector<int64_t> normalized_shape,
180+
Tensor weight /* optional */,
181+
Tensor bias /* optional */,
182+
double eps,
183+
double output_scale,
184+
int64_t output_zero_point) {
185+
return quantized_layer_norm_impl(
186+
input, normalized_shape, weight, bias, eps, output_scale, output_zero_point);
187+
}
188+
};
189+
190+
// Registers the quantized::layer_norm schema and binds QLayerNorm2d as its
// kernel for DispatchKey::QuantizedCPUTensorId. The parameter order in the
// schema string matches the parameter order of QLayerNorm2d::operator().
static auto registry = torch::RegisterOperators().op(
191+
"quantized::layer_norm(Tensor input, "
192+
"int[] normalized_shape, "
193+
"Tensor weight, "
194+
"Tensor bias, "
195+
"float eps, "
196+
"float output_scale, "
197+
"int output_zero_point) -> Tensor",
198+
torch::RegisterOperators::options().kernel<QLayerNorm2d>(
199+
DispatchKey::QuantizedCPUTensorId));
200+
201+
} // namespace
202+
125203
DEFINE_DISPATCH(LayerNormKernel);
126204
DEFINE_DISPATCH(LayerNormBackwardKernel);
205+
DEFINE_DISPATCH(quantized_layer_norm_stub);
127206

128207
} // namespace native
129208
} // namespace at

aten/src/ATen/native/layer_norm.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,18 @@ using backward_fn = void (*)(
2929
Tensor* /* dgamma */,
3030
Tensor* /* dbeta */);
3131

32+
using forward_quantized_fn = void (*)(
33+
const Tensor& /* X */,
34+
const Tensor& /* gamma */,
35+
const Tensor& /* beta */,
36+
int64_t /* M */,
37+
int64_t /* N */,
38+
double /* eps */,
39+
Tensor* /* Y */);
40+
3241
DECLARE_DISPATCH(forward_fn, LayerNormKernel);
3342
DECLARE_DISPATCH(backward_fn, LayerNormBackwardKernel);
43+
DECLARE_DISPATCH(forward_quantized_fn, quantized_layer_norm_stub);
3444

3545
} // namespace native
3646
} // namespace at

aten/src/ATen/native/native_functions.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1614,6 +1614,11 @@
16141614
CPU: layer_norm_backward_cpu
16151615
CUDA: layer_norm_backward_cuda
16161616

1617+
- func: quantized_layer_norm(Tensor input, int[] normalized_shape, Tensor? weight, Tensor? bias, float eps, float output_scale, int output_zero_point) -> Tensor
1618+
requires_tensor: True
1619+
dispatch:
1620+
QuantizedCPU: quantized_layer_norm_impl
1621+
16171622
- func: linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor
16181623
python_module: nn
16191624

aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1834,6 +1834,120 @@ void fake_quant_grad_per_channel_cpu(TensorIterator &iter, int64_t quant_min, in
18341834
});
18351835
}
18361836

1837+
// NOTE(review): this template has an empty body and no call to it is visible
// in this part of the file; quantized_layer_norm_kernel does its work inline
// instead. It also takes eps as float, whereas quantized_layer_norm_kernel
// takes double. Presumably a leftover placeholder -- confirm before relying
// on it or removing it.
template <typename T>
1838+
void quantized_layer_norm_kernel_impl(
1839+
const Tensor& X,
1840+
const Tensor& gamma,
1841+
const Tensor& beta,
1842+
int64_t M,
1843+
int64_t N,
1844+
float eps,
1845+
Tensor* Y) {
1846+
1847+
}
1848+
1849+
void quantized_layer_norm_kernel(
1850+
const Tensor& X,
1851+
const Tensor& gamma,
1852+
const Tensor& beta,
1853+
int64_t M,
1854+
int64_t N,
1855+
double eps,
1856+
Tensor* Y) {
1857+
AT_DISPATCH_QINT_TYPES(X.scalar_type(), "quantized_layer_norm_kernel_impl_cpu", [&]() {
1858+
using qVec = vec256::Vec256<scalar_t>;
1859+
using fVec = vec256::Vec256<float>;
1860+
1861+
TORCH_INTERNAL_ASSERT(X.numel() == M * N, "Unexpected num elements in X");
1862+
TORCH_INTERNAL_ASSERT(!gamma.defined() || gamma.numel() == N,
1863+
"Unexpected size of gamma");
1864+
TORCH_INTERNAL_ASSERT(!beta.defined() || beta.numel() == N,
1865+
"Unexpected size of beta");
1866+
scalar_t* X_data = X.data_ptr<scalar_t>();
1867+
const float* gamma_data = gamma.defined() ? gamma.data_ptr<float>() : nullptr;
1868+
const float* beta_data = beta.defined() ? beta.data_ptr<float>() : nullptr;
1869+
scalar_t* Y_data = Y->data_ptr<scalar_t>();
1870+
const bool gamma_null = gamma_data == nullptr;
1871+
const bool beta_null = beta_data == nullptr;
1872+
int64_t x_zp = X.q_zero_point();
1873+
float x_scale = X.q_scale();
1874+
fVec x_zp_vec((float)x_zp);
1875+
fVec one_vec(1.0f);
1876+
fVec zero_vec(0.0f);
1877+
float x_fake_scale = 1.0f;
1878+
fVec x_fake_scale_vec(x_fake_scale);
1879+
fVec x_fake_scale_zp_neg_premul_vec = x_fake_scale_vec * x_zp_vec.neg();
1880+
int64_t y_zp = Y->q_zero_point();
1881+
float y_scale = Y->q_scale();
1882+
float y_inv_scale = 1.0f / y_scale;
1883+
1884+
constexpr int kFloatVLen = 8;
1885+
int64_t kIntVLen = kFloatVLen * qVec::float_num_vecs();
1886+
int64_t kNumIntVecInLayer = N / kIntVLen;
1887+
int64_t kNonVecRemInLayer = N % kIntVLen;
1888+
1889+
at::parallel_for(0, M, 1, [&](int64_t start, int64_t end) {
1890+
for (int64_t i = start; i < end; ++i) {
1891+
1892+
scalar_t* X_ptr = X_data + i * N;
1893+
scalar_t* Y_ptr = Y_data + i * N;
1894+
1895+
// First pass: calculate mean and variance.
1896+
1897+
scalar_t::underlying* X_ptr_underlying = reinterpret_cast<scalar_t::underlying*>(X_ptr);
1898+
auto l_sum_shifted = hsum(X_ptr_underlying, N);
1899+
auto l_sum_sq_shifted = hsum_sq(X_ptr_underlying, N);
1900+
float l_mean_shifted_div_scale_x = static_cast<float>(l_sum_shifted) / N;
1901+
// mean(dqX) / scale_x
1902+
float layer_mean_div_scale_x = l_mean_shifted_div_scale_x - x_zp;
1903+
// var(dqX) / scale_x^2
1904+
float layer_var_div_scale_x_sq =
1905+
std::max(static_cast<float>(l_sum_sq_shifted) / N -
1906+
l_mean_shifted_div_scale_x * l_mean_shifted_div_scale_x, 0.0f);
1907+
// scale_x / sqrt(var(dqX) + eps)
1908+
float scale_x_div_layer_std = x_scale /
1909+
std::sqrt(layer_var_div_scale_x_sq * x_scale * x_scale + eps);
1910+
fVec layer_mean_div_scale_xVec(layer_mean_div_scale_x);
1911+
fVec scale_x_div_layer_stdVec(scale_x_div_layer_std);
1912+
1913+
// Second pass: normalize
1914+
1915+
// TODO replace with TensorIterator implementation once #33166 is fixed.
1916+
for (int64_t vecIdx = 0; vecIdx < kNumIntVecInLayer; vecIdx++) {
1917+
int64_t vecStartIdx = vecIdx * kIntVLen;
1918+
auto qXVec = qVec::loadu(X_ptr + vecStartIdx);
1919+
auto dqXVec = qXVec.dequantize(x_fake_scale_vec, x_zp_vec,
1920+
x_fake_scale_zp_neg_premul_vec);
1921+
for (int dqXVecIdx = 0; dqXVecIdx < dqXVec.size(); dqXVecIdx++) {
1922+
int64_t vecVecStartIdx = vecStartIdx + dqXVecIdx * kFloatVLen;
1923+
auto gammaVec = gamma_null
1924+
? one_vec
1925+
: fVec::loadu(gamma_data + vecVecStartIdx);
1926+
auto betaVec = beta_null
1927+
? zero_vec
1928+
: fVec::loadu(beta_data + vecVecStartIdx);
1929+
dqXVec[dqXVecIdx] =
1930+
(dqXVec[dqXVecIdx] - layer_mean_div_scale_xVec) *
1931+
scale_x_div_layer_stdVec * gammaVec + betaVec;
1932+
qVec::quantize(dqXVec, y_scale, y_zp, y_inv_scale)
1933+
.store(Y_ptr + vecStartIdx);
1934+
}
1935+
}
1936+
for (int64_t remIdx = N - kNonVecRemInLayer; remIdx < N; remIdx++) {
1937+
const float gamma_v = gamma_null ? 1.0f : gamma_data[remIdx];
1938+
const float beta_v = beta_null ? 0.0f : beta_data[remIdx];
1939+
auto qXVal = X_ptr[remIdx];
1940+
float dqXVal = at::dequantize_val(x_fake_scale, x_zp, qXVal);
1941+
float dqY =
1942+
((dqXVal - layer_mean_div_scale_x) * scale_x_div_layer_std) * gamma_v + beta_v;
1943+
Y_ptr[remIdx] = at::quantize_val<scalar_t>(y_scale, y_zp, dqY);
1944+
}
1945+
}
1946+
}); // parallel_for
1947+
1948+
});
1949+
}
1950+
18371951
} // namespace
18381952

18391953
REGISTER_DISPATCH(qrelu_stub, &qrelu_kernel);
@@ -1869,6 +1983,7 @@ REGISTER_DISPATCH(fake_quant_tensor_stub, &fake_quantize_tensor_kernel);
18691983
REGISTER_DISPATCH(fake_quant_grad_tensor_stub, &fake_quantize_grad_tensor_kernel);
18701984
REGISTER_DISPATCH(fake_quant_per_channel_stub, &fake_quant_per_channel_cpu);
18711985
REGISTER_DISPATCH(fake_quant_grad_per_channel_stub, &fake_quant_grad_per_channel_cpu);
1986+
REGISTER_DISPATCH(quantized_layer_norm_stub, &quantized_layer_norm_kernel);
18721987

18731988
} // namespace native
18741989
} // namespace at

benchmarks/operator_benchmark/benchmark_all_other_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
add_test, as_strided_test, batchnorm_test, binary_test, cat_test, # noqa
99
chunk_test, conv_test, diag_test, embeddingbag_test, fill_test, # noqa
1010
gather_test, linear_test, matmul_test, pool_test, # noqa
11-
softmax_test, hardsigmoid_test, hardswish_test # noqa
11+
softmax_test, hardsigmoid_test, hardswish_test, layernorm_test # noqa
1212
)
1313

1414
if __name__ == "__main__":

benchmarks/operator_benchmark/benchmark_all_quantized_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
qcomparators_test,
1313
qconv_test,
1414
qinterpolate_test,
15+
qlayernorm_test,
1516
qlinear_test,
1617
qobserver_test,
1718
qpool_test,
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
from __future__ import absolute_import
2+
from __future__ import division
3+
from __future__ import print_function
4+
from __future__ import unicode_literals
5+
6+
7+
import operator_benchmark as op_bench
8+
import torch
9+
import torch.nn.functional as F
10+
11+
12+
"""Microbenchmarks for layernorm operator."""
13+
14+
layernorm_configs_short = op_bench.cross_product_configs(
15+
dims=(
16+
(1, 8, 16),
17+
(8, 8, 16),
18+
(32, 8, 16),
19+
(64, 128, 56, 56),
20+
),
21+
tags=["short"],
22+
)
23+
24+
25+
class LayerNormBenchmark(op_bench.TorchBenchmarkBase):
    """Benchmark case for the floating point ``F.layer_norm``.

    The input is normalized over every dimension except the first, with a
    float weight and bias of that same trailing shape.
    """

    def init(self, dims):
        # torch.rand is uniform on [0, 1), so the input spans [-128, 128).
        self.X = 256 * (torch.rand(*dims) - 0.5)
        trailing_shape = self.X.size()[1:]
        self.weight = torch.rand(*trailing_shape, dtype=torch.float)
        self.bias = torch.rand(*trailing_shape, dtype=torch.float)
        self.eps = 1e-5

    def forward(self):
        return F.layer_norm(
            self.X,
            self.X.size()[1:],
            weight=self.weight,
            bias=self.bias,
            eps=self.eps)
35+
36+
37+
op_bench.generate_pt_test(layernorm_configs_short, LayerNormBenchmark)
38+
39+
40+
if __name__ == "__main__":
41+
op_bench.benchmark_runner.main()
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
from __future__ import absolute_import
2+
from __future__ import division
3+
from __future__ import print_function
4+
from __future__ import unicode_literals
5+
6+
7+
import operator_benchmark as op_bench
8+
import torch
9+
10+
11+
"""Microbenchmarks for quantized layernorm operator."""
12+
13+
layernorm_configs_short = op_bench.cross_product_configs(
14+
dims=(
15+
(1, 8, 16),
16+
(8, 8, 16),
17+
(32, 8, 16),
18+
(64, 128, 56, 56),
19+
),
20+
dtype=(torch.qint8,),
21+
tags=["short"],
22+
)
23+
24+
25+
class QLayerNormBenchmark(op_bench.TorchBenchmarkBase):
    """Benchmark case for ``torch.ops.quantized.layer_norm``.

    A random float tensor is quantized per-tensor (scale 1.0, zero point 0)
    to the configured dtype and normalized over every dimension except the
    first, using a float weight and bias of that trailing shape.
    """

    def init(self, dims, dtype):
        # torch.rand is uniform on [0, 1), so the values span [-128, 128)
        # before quantization.
        float_input = 256 * (torch.rand(*dims) - 0.5)
        self.qX = torch.quantize_per_tensor(
            float_input, scale=1.0, zero_point=0, dtype=dtype)
        trailing_shape = self.qX.size()[1:]
        self.weight = torch.rand(*trailing_shape, dtype=torch.float)
        self.bias = torch.rand(*trailing_shape, dtype=torch.float)
        self.eps = 1e-5
        # Quantization parameters requested for the output tensor.
        self.Y_scale = 0.1
        self.Y_zero_point = 0

    def forward(self):
        return torch.ops.quantized.layer_norm(
            self.qX,
            self.qX.size()[1:],
            weight=self.weight,
            bias=self.bias,
            eps=self.eps,
            output_scale=self.Y_scale,
            output_zero_point=self.Y_zero_point)
44+
45+
46+
op_bench.generate_pt_test(layernorm_configs_short, QLayerNormBenchmark)
47+
48+
49+
if __name__ == "__main__":
50+
op_bench.benchmark_runner.main()

0 commit comments

Comments
 (0)