Commit d456401

redo of add quantized layer norm implementation
Summary: This is a redo of #35329 with a better test. Adds a quantized implementation of LayerNorm for server. A future PR will add the Python wrapper.

Test Plan: numerics match the floating point implementation.

Benchmarks by input size:
v1 (mean+var non-vectorized): https://gist.github.com/vkuzo/f6d72c04742608112f4c2e612c74bd13
v2 (mean+var vectorized in float): https://gist.github.com/vkuzo/4dd95657c5b5f3654e0965db00eff8d2
v3 (mean+var vectorized in int, current): https://gist.github.com/vkuzo/57a75f75629da9f23b64b38ca0e3d34b

ghstack-source-id: 9bb87ea
Pull Request resolved: #36593
1 parent fb70b4f commit d456401

File tree

10 files changed: +410 additions, −6 deletions

aten/src/ATen/native/layer_norm.cpp

Lines changed: 83 additions & 4 deletions

@@ -12,6 +12,7 @@
 #include <ATen/Config.h>
 #include <ATen/NativeFunctions.h>
 #include <ATen/Parallel.h>
+#include <ATen/core/op_registration/op_registration.h>

 namespace at {
 namespace native {
@@ -60,13 +61,12 @@ std::tuple<Tensor, Tensor, Tensor> layer_norm_backward_cpu(
   return std::make_tuple(std::move(dX), std::move(dgamma), std::move(dbeta));
 }

-Tensor layer_norm(
+std::tuple<Tensor, Tensor, Tensor, int64_t, int64_t> _prepare_layer_norm_inputs(
     const Tensor& input,
     IntArrayRef normalized_shape,
     const Tensor& weight /* optional */,
-    const Tensor& bias /* optional */,
-    double eps,
-    bool /* cudnn_enable, deprecated */) {
+    const Tensor& bias /* optional */) {
+
   const int normalized_ndim = normalized_shape.size();
   TORCH_CHECK(
       normalized_ndim >= 1,
@@ -119,11 +119,90 @@ Tensor layer_norm(
   const auto& X = input.is_contiguous() ? input : input.contiguous();
   const auto& gamma = weight.is_contiguous() ? weight : weight.contiguous();
   const auto& beta = bias.is_contiguous() ? bias : bias.contiguous();
+
+  return std::make_tuple(X, gamma, beta, M, N);
+}
+
+Tensor layer_norm(
+    const Tensor& input,
+    IntArrayRef normalized_shape,
+    const Tensor& weight /* optional */,
+    const Tensor& bias /* optional */,
+    double eps,
+    bool /* cudnn_enable, deprecated */) {
+
+  auto inputs = _prepare_layer_norm_inputs(input, normalized_shape, weight, bias);
+  auto X = std::get<0>(inputs);
+  auto gamma = std::get<1>(inputs);
+  auto beta = std::get<2>(inputs);
+  auto M = std::get<3>(inputs);
+  auto N = std::get<4>(inputs);
+
   return std::get<0>(at::native_layer_norm(X, gamma, beta, M, N, eps));
 }

+Tensor quantized_layer_norm_impl(
+    const Tensor& input,
+    IntArrayRef normalized_shape,
+    const Tensor& weight /* optional */,
+    const Tensor& bias /* optional */,
+    double eps,
+    double output_scale,
+    int64_t output_zero_point) {
+
+  auto inputs = _prepare_layer_norm_inputs(input, normalized_shape, weight, bias);
+  auto X = std::get<0>(inputs);
+  auto gamma = std::get<1>(inputs);
+  auto beta = std::get<2>(inputs);
+  auto M = std::get<3>(inputs);
+  auto N = std::get<4>(inputs);
+
+  Tensor Y = at::_empty_affine_quantized(
+      X.sizes(),
+      X.scalar_type(),
+      output_scale,
+      output_zero_point,
+      X.suggest_memory_format());
+
+  if (M > 0) {
+    quantized_layer_norm_stub(kCPU, X, gamma, beta, M, N, eps, &Y);
+  }
+  return Y;
+}
+
+// Keep the registry in the anonymous namespace.
+namespace {
+class QLayerNorm2d final : public torch::OperatorKernel {
+ public:
+  Tensor operator()(
+      Tensor input,
+      std::vector<int64_t> normalized_shape,
+      Tensor weight /* optional */,
+      Tensor bias /* optional */,
+      double eps,
+      double output_scale,
+      int64_t output_zero_point) {
+    return quantized_layer_norm_impl(
+        input, normalized_shape, weight, bias, eps, output_scale, output_zero_point);
+  }
+};
+
+static auto registry = torch::RegisterOperators().op(
+    "quantized::layer_norm(Tensor input, "
+    "int[] normalized_shape, "
+    "Tensor weight, "
+    "Tensor bias, "
+    "float eps, "
+    "float output_scale, "
+    "int output_zero_point) -> Tensor",
+    torch::RegisterOperators::options().kernel<QLayerNorm2d>(
+        DispatchKey::QuantizedCPU));
+
+} // namespace
+
 DEFINE_DISPATCH(LayerNormKernel);
 DEFINE_DISPATCH(LayerNormBackwardKernel);
+DEFINE_DISPATCH(quantized_layer_norm_stub);

 } // namespace native
 } // namespace at
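
The Python wrapper only lands in a follow-up PR, so for now the newly registered schema can be exercised directly through torch.ops. A minimal sketch of a call site, mirroring the qlayernorm benchmark added below (the shapes and the scale/zero-point values are arbitrary placeholders):

    import torch

    # Quantize a random float tensor per-tensor to qint8.
    X = torch.rand(8, 8, 16)
    qX = torch.quantize_per_tensor(X, scale=1.0, zero_point=0, dtype=torch.qint8)

    # gamma/beta stay in float; normalize over the last two dimensions.
    weight = torch.rand(8, 16)
    bias = torch.rand(8, 16)

    qY = torch.ops.quantized.layer_norm(
        qX, [8, 16], weight=weight, bias=bias,
        eps=1e-5, output_scale=0.1, output_zero_point=0)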

aten/src/ATen/native/layer_norm.h

Lines changed: 10 additions & 0 deletions

@@ -29,8 +29,18 @@ using backward_fn = void (*)(
     Tensor* /* dgamma */,
     Tensor* /* dbeta */);

+using forward_quantized_fn = void (*)(
+    const Tensor& /* X */,
+    const Tensor& /* gamma */,
+    const Tensor& /* beta */,
+    int64_t /* M */,
+    int64_t /* N */,
+    double /* eps */,
+    Tensor* /* Y */);
+
 DECLARE_DISPATCH(forward_fn, LayerNormKernel);
 DECLARE_DISPATCH(backward_fn, LayerNormBackwardKernel);
+DECLARE_DISPATCH(forward_quantized_fn, quantized_layer_norm_stub);

 } // namespace native
 } // namespace at

aten/src/ATen/native/native_functions.yaml

Lines changed: 5 additions & 0 deletions

@@ -1645,6 +1645,11 @@
     CPU: layer_norm_backward_cpu
     CUDA: layer_norm_backward_cuda

+- func: quantized_layer_norm(Tensor input, int[] normalized_shape, Tensor? weight, Tensor? bias, float eps, float output_scale, int output_zero_point) -> Tensor
+  requires_tensor: True
+  dispatch:
+    QuantizedCPU: quantized_layer_norm_impl
+
 - func: linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor
   python_module: nn

aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp

Lines changed: 115 additions & 0 deletions

@@ -1889,6 +1889,120 @@ void fake_quant_grad_per_channel_cpu(TensorIterator &iter, int64_t quant_min, in
   });
 }

+template <typename T>
+void quantized_layer_norm_kernel_impl(
+    const Tensor& X,
+    const Tensor& gamma,
+    const Tensor& beta,
+    int64_t M,
+    int64_t N,
+    float eps,
+    Tensor* Y) {
+
+}
+
+void quantized_layer_norm_kernel(
+    const Tensor& X,
+    const Tensor& gamma,
+    const Tensor& beta,
+    int64_t M,
+    int64_t N,
+    double eps,
+    Tensor* Y) {
+  AT_DISPATCH_QINT_TYPES(X.scalar_type(), "quantized_layer_norm_kernel_impl_cpu", [&]() {
+    using qVec = vec256::Vec256<scalar_t>;
+    using fVec = vec256::Vec256<float>;
+
+    TORCH_INTERNAL_ASSERT(X.numel() == M * N, "Unexpected num elements in X");
+    TORCH_INTERNAL_ASSERT(!gamma.defined() || gamma.numel() == N,
+        "Unexpected size of gamma");
+    TORCH_INTERNAL_ASSERT(!beta.defined() || beta.numel() == N,
+        "Unexpected size of beta");
+    scalar_t* X_data = X.data_ptr<scalar_t>();
+    const float* gamma_data = gamma.defined() ? gamma.data_ptr<float>() : nullptr;
+    const float* beta_data = beta.defined() ? beta.data_ptr<float>() : nullptr;
+    scalar_t* Y_data = Y->data_ptr<scalar_t>();
+    const bool gamma_null = gamma_data == nullptr;
+    const bool beta_null = beta_data == nullptr;
+    int64_t x_zp = X.q_zero_point();
+    float x_scale = X.q_scale();
+    fVec x_zp_vec((float)x_zp);
+    fVec one_vec(1.0f);
+    fVec zero_vec(0.0f);
+    float x_fake_scale = 1.0f;
+    fVec x_fake_scale_vec(x_fake_scale);
+    fVec x_fake_scale_zp_neg_premul_vec = x_fake_scale_vec * x_zp_vec.neg();
+    int64_t y_zp = Y->q_zero_point();
+    float y_scale = Y->q_scale();
+    float y_inv_scale = 1.0f / y_scale;
+
+    constexpr int kFloatVLen = 8;
+    int64_t kIntVLen = kFloatVLen * qVec::float_num_vecs();
+    int64_t kNumIntVecInLayer = N / kIntVLen;
+    int64_t kNonVecRemInLayer = N % kIntVLen;
+
+    at::parallel_for(0, M, 1, [&](int64_t start, int64_t end) {
+      for (int64_t i = start; i < end; ++i) {
+
+        scalar_t* X_ptr = X_data + i * N;
+        scalar_t* Y_ptr = Y_data + i * N;
+
+        // First pass: calculate mean and variance.
+
+        scalar_t::underlying* X_ptr_underlying = reinterpret_cast<scalar_t::underlying*>(X_ptr);
+        auto l_sum_shifted = hsum(X_ptr_underlying, N);
+        auto l_sum_sq_shifted = hsum_sq(X_ptr_underlying, N);
+        float l_mean_shifted_div_scale_x = static_cast<float>(l_sum_shifted) / N;
+        // mean(dqX) / scale_x
+        float layer_mean_div_scale_x = l_mean_shifted_div_scale_x - x_zp;
+        // var(dqX) / scale_x^2
+        float layer_var_div_scale_x_sq =
+            std::max(static_cast<float>(l_sum_sq_shifted) / N -
+                l_mean_shifted_div_scale_x * l_mean_shifted_div_scale_x, 0.0f);
+        // scale_x / sqrt(var(dqX) + eps)
+        float scale_x_div_layer_std = x_scale /
+            std::sqrt(layer_var_div_scale_x_sq * x_scale * x_scale + eps);
+        fVec layer_mean_div_scale_xVec(layer_mean_div_scale_x);
+        fVec scale_x_div_layer_stdVec(scale_x_div_layer_std);

+        // Second pass: normalize

+        // TODO replace with TensorIterator implementation once #33166 is fixed.
+        for (int64_t vecIdx = 0; vecIdx < kNumIntVecInLayer; vecIdx++) {
+          int64_t vecStartIdx = vecIdx * kIntVLen;
+          auto qXVec = qVec::loadu(X_ptr + vecStartIdx);
+          auto dqXVec = qXVec.dequantize(x_fake_scale_vec, x_zp_vec,
+              x_fake_scale_zp_neg_premul_vec);
+          for (int dqXVecIdx = 0; dqXVecIdx < dqXVec.size(); dqXVecIdx++) {
+            int64_t vecVecStartIdx = vecStartIdx + dqXVecIdx * kFloatVLen;
+            auto gammaVec = gamma_null
+                ? one_vec
+                : fVec::loadu(gamma_data + vecVecStartIdx);
+            auto betaVec = beta_null
+                ? zero_vec
+                : fVec::loadu(beta_data + vecVecStartIdx);
+            dqXVec[dqXVecIdx] =
+                (dqXVec[dqXVecIdx] - layer_mean_div_scale_xVec) *
+                scale_x_div_layer_stdVec * gammaVec + betaVec;
+            qVec::quantize(dqXVec, y_scale, y_zp, y_inv_scale)
+                .store(Y_ptr + vecStartIdx);
+          }
+        }
+        for (int64_t remIdx = N - kNonVecRemInLayer; remIdx < N; remIdx++) {
+          const float gamma_v = gamma_null ? 1.0f : gamma_data[remIdx];
+          const float beta_v = beta_null ? 0.0f : beta_data[remIdx];
+          auto qXVal = X_ptr[remIdx];
+          float dqXVal = at::dequantize_val(x_fake_scale, x_zp, qXVal);
+          float dqY =
+              ((dqXVal - layer_mean_div_scale_x) * scale_x_div_layer_std) * gamma_v + beta_v;
+          Y_ptr[remIdx] = at::quantize_val<scalar_t>(y_scale, y_zp, dqY);
+        }
+      }
+    }); // parallel_for

+  });
+}
+
 } // namespace

 REGISTER_DISPATCH(qrelu_stub, &qrelu_kernel);
@@ -1924,6 +2038,7 @@ REGISTER_DISPATCH(fake_quant_tensor_stub, &fake_quantize_tensor_kernel);
 REGISTER_DISPATCH(fake_quant_grad_tensor_stub, &fake_quantize_grad_tensor_kernel);
 REGISTER_DISPATCH(fake_quant_per_channel_stub, &fake_quant_per_channel_cpu);
 REGISTER_DISPATCH(fake_quant_grad_per_channel_stub, &fake_quant_grad_per_channel_cpu);
+REGISTER_DISPATCH(quantized_layer_norm_stub, &quantized_layer_norm_kernel);

 } // namespace native
 } // namespace at
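
The kernel's first pass computes hsum/hsum_sq over the raw integer values and only afterwards factors out the zero point and scale; this is the "mean+var vectorized in int" v3 variant from the commit message. Below is a minimal NumPy sketch of the same per-row algebra, assuming qint8 input and output (a reference for the numerics only, not the shipped kernel):

    import numpy as np

    def qlayernorm_row(q, x_scale, x_zp, gamma, beta, eps, y_scale, y_zp):
        # q: one row of N raw int8 values (scalar_t::underlying).
        N = q.size
        l_sum_shifted = q.astype(np.int64).sum()            # hsum
        l_sum_sq_shifted = (q.astype(np.int64) ** 2).sum()  # hsum_sq
        l_mean_shifted = l_sum_shifted / N
        # mean(dqX) / x_scale: the zero point is subtracted after averaging.
        mean_div_scale = l_mean_shifted - x_zp
        # var(dqX) / x_scale^2: the zero point cancels in the variance.
        var_div_scale_sq = max(l_sum_sq_shifted / N - l_mean_shifted ** 2, 0.0)
        # x_scale / sqrt(var(dqX) + eps)
        scale_div_std = x_scale / np.sqrt(var_div_scale_sq * x_scale ** 2 + eps)
        # Dequantize with a "fake" scale of 1.0, exactly as the kernel does.
        dq = q.astype(np.float32) - x_zp
        y = (dq - mean_div_scale) * scale_div_std * gamma + beta
        # Requantize into the output scale/zero_point.
        return np.clip(np.round(y / y_scale) + y_zp, -128, 127).astype(np.int8)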

benchmarks/operator_benchmark/benchmark_all_other_test.py

Lines changed: 1 addition & 1 deletion

@@ -8,7 +8,7 @@
     add_test, as_strided_test, batchnorm_test, binary_test, cat_test, # noqa
     chunk_test, conv_test, diag_test, embeddingbag_test, fill_test, # noqa
     gather_test, linear_test, matmul_test, pool_test, # noqa
-    softmax_test, hardsigmoid_test, hardswish_test # noqa
+    softmax_test, hardsigmoid_test, hardswish_test, layernorm_test # noqa
 )

 if __name__ == "__main__":

benchmarks/operator_benchmark/benchmark_all_quantized_test.py

Lines changed: 1 addition & 0 deletions

@@ -12,6 +12,7 @@
     qcomparators_test,
     qconv_test,
     qinterpolate_test,
+    qlayernorm_test,
     qlinear_test,
     qobserver_test,
     qpool_test,

benchmarks/operator_benchmark/pt/layernorm_test.py

Lines changed: 41 additions & 0 deletions

@@ -0,0 +1,41 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+
+import operator_benchmark as op_bench
+import torch
+import torch.nn.functional as F
+
+
+"""Microbenchmarks for layernorm operator."""
+
+layernorm_configs_short = op_bench.cross_product_configs(
+    dims=(
+        (1, 8, 16),
+        (8, 8, 16),
+        (32, 8, 16),
+        (64, 128, 56, 56),
+    ),
+    tags=["short"],
+)
+
+
+class LayerNormBenchmark(op_bench.TorchBenchmarkBase):
+    def init(self, dims):
+        self.X = (torch.rand(*dims) - 0.5) * 256
+        self.weight = torch.rand(*self.X.size()[1:], dtype=torch.float)
+        self.bias = torch.rand(*self.X.size()[1:], dtype=torch.float)
+        self.eps = 1e-5
+
+    def forward(self):
+        return F.layer_norm(
+            self.X, self.X.size()[1:], weight=self.weight, bias=self.bias, eps=self.eps)
+
+
+op_bench.generate_pt_test(layernorm_configs_short, LayerNormBenchmark)
+
+
+if __name__ == "__main__":
+    op_bench.benchmark_runner.main()
benchmarks/operator_benchmark/pt/qlayernorm_test.py

Lines changed: 50 additions & 0 deletions

@@ -0,0 +1,50 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+
+import operator_benchmark as op_bench
+import torch
+
+
+"""Microbenchmarks for quantized layernorm operator."""
+
+layernorm_configs_short = op_bench.cross_product_configs(
+    dims=(
+        (1, 8, 16),
+        (8, 8, 16),
+        (32, 8, 16),
+        (64, 128, 56, 56),
+    ),
+    dtype=(torch.qint8,),
+    tags=["short"],
+)
+
+
+class QLayerNormBenchmark(op_bench.TorchBenchmarkBase):
+
+    def init(self, dims, dtype):
+        X = (torch.rand(*dims) - 0.5) * 256
+        scale = 1.0
+        zero_point = 0
+        self.qX = torch.quantize_per_tensor(
+            X, scale=scale, zero_point=zero_point, dtype=dtype)
+        self.weight = torch.rand(*self.qX.size()[1:], dtype=torch.float)
+        self.bias = torch.rand(*self.qX.size()[1:], dtype=torch.float)
+        self.eps = 1e-5
+        self.Y_scale = 0.1
+        self.Y_zero_point = 0
+
+    def forward(self):
+        return torch.ops.quantized.layer_norm(
+            self.qX, self.qX.size()[1:], weight=self.weight, bias=self.bias,
+            eps=self.eps, output_scale=self.Y_scale,
+            output_zero_point=self.Y_zero_point)
+
+
+op_bench.generate_pt_test(layernorm_configs_short, QLayerNormBenchmark)
+
+
+if __name__ == "__main__":
+    op_bench.benchmark_runner.main()
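
The test plan's "numerics match" claim can be sanity-checked by hand: dequantize the op's output and compare against the float reference on the same dequantized input. A rough sketch (the output scale/zero_point and the one-quantization-step tolerance are placeholders; the real check lives in the commit's test suite):

    import torch
    import torch.nn.functional as F

    X = (torch.rand(4, 8, 16) - 0.5) * 256
    qX = torch.quantize_per_tensor(X, scale=1.0, zero_point=0, dtype=torch.qint8)
    w, b = torch.rand(8, 16), torch.rand(8, 16)

    qY = torch.ops.quantized.layer_norm(
        qX, [8, 16], weight=w, bias=b,
        eps=1e-5, output_scale=0.1, output_zero_point=0)
    Y_ref = F.layer_norm(qX.dequantize(), [8, 16], weight=w, bias=b, eps=1e-5)

    # Expect agreement within one output quantization step.
    assert (qY.dequantize() - Y_ref).abs().max() <= 0.1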
