Commit 15d19f7

add quantized layer norm implementation

Summary: Adds a quantized implementation of LayerNorm for server.
Relevant PRs:
* #20345 (floating point LN)
* #33080 (quantized BN)
A future PR will add the Python wrapper.

Test Plan: numerics match the floating point implementation. TODO: benchmarks.

Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned]
1 parent eff68bc commit 15d19f7

File tree

7 files changed: +290 −6 lines changed


aten/src/ATen/native/layer_norm.cpp

Lines changed: 83 additions & 4 deletions
@@ -12,6 +12,7 @@
 #include <ATen/Config.h>
 #include <ATen/NativeFunctions.h>
 #include <ATen/Parallel.h>
+#include <ATen/core/op_registration/op_registration.h>

 namespace at {
 namespace native {
@@ -60,13 +61,12 @@ std::tuple<Tensor, Tensor, Tensor> layer_norm_backward_cpu(
   return std::make_tuple(std::move(dX), std::move(dgamma), std::move(dbeta));
 }

-Tensor layer_norm(
+std::tuple<Tensor, Tensor, Tensor, int64_t, int64_t> _prepare_layer_norm_inputs(
     const Tensor& input,
     IntArrayRef normalized_shape,
     const Tensor& weight /* optional */,
-    const Tensor& bias /* optional */,
-    double eps,
-    bool /* cudnn_enable, deprecated */) {
+    const Tensor& bias /* optional */) {
+
   const int normalized_ndim = normalized_shape.size();
   TORCH_CHECK(
       normalized_ndim >= 1,
@@ -119,11 +119,90 @@ Tensor layer_norm(
   const auto& X = input.is_contiguous() ? input : input.contiguous();
   const auto& gamma = weight.is_contiguous() ? weight : weight.contiguous();
   const auto& beta = bias.is_contiguous() ? bias : bias.contiguous();
+
+  return std::make_tuple(X, gamma, beta, M, N);
+}
+
+Tensor layer_norm(
+    const Tensor& input,
+    IntArrayRef normalized_shape,
+    const Tensor& weight /* optional */,
+    const Tensor& bias /* optional */,
+    double eps,
+    bool /* cudnn_enable, deprecated */) {
+
+  auto inputs = _prepare_layer_norm_inputs(input, normalized_shape, weight, bias);
+  auto X = std::get<0>(inputs);
+  auto gamma = std::get<1>(inputs);
+  auto beta = std::get<2>(inputs);
+  auto M = std::get<3>(inputs);
+  auto N = std::get<4>(inputs);
+
   return std::get<0>(at::native_layer_norm(X, gamma, beta, M, N, eps));
 }

+Tensor quantized_layer_norm_impl(
+    const Tensor& input,
+    IntArrayRef normalized_shape,
+    const Tensor& weight /* optional */,
+    const Tensor& bias /* optional */,
+    double eps,
+    double output_scale,
+    int64_t output_zero_point) {
+
+  auto inputs = _prepare_layer_norm_inputs(input, normalized_shape, weight, bias);
+  auto X = std::get<0>(inputs);
+  auto gamma = std::get<1>(inputs);
+  auto beta = std::get<2>(inputs);
+  auto M = std::get<3>(inputs);
+  auto N = std::get<4>(inputs);
+
+  Tensor Y = at::_empty_affine_quantized(
+      X.sizes(),
+      X.scalar_type(),
+      output_scale,
+      output_zero_point,
+      X.suggest_memory_format());
+
+  if (M > 0) {
+    LayerNormKernelQuantized(kCPU, X, gamma, beta, M, N, eps, &Y);
+  }
+  return Y;
+}
+
+// Keep the registry in the anonymous namespace.
+namespace {
+class QLayerNorm2d final : public torch::OperatorKernel {
+ public:
+  Tensor operator()(
+      Tensor input,
+      std::vector<int64_t> normalized_shape,
+      Tensor weight /* optional */,
+      Tensor bias /* optional */,
+      double eps,
+      double output_scale,
+      int64_t output_zero_point) {
+    return quantized_layer_norm_impl(
+        input, normalized_shape, weight, bias, eps, output_scale, output_zero_point);
+  }
+};
+
+static auto registry = torch::RegisterOperators().op(
+    "quantized::layer_norm(Tensor input, "
+    "int[] normalized_shape, "
+    "Tensor weight, "
+    "Tensor bias, "
+    "float eps, "
+    "float output_scale, "
+    "int output_zero_point) -> Tensor",
+    torch::RegisterOperators::options().kernel<QLayerNorm2d>(
+        DispatchKey::QuantizedCPUTensorId));
+
+} // namespace
+
 DEFINE_DISPATCH(LayerNormKernel);
 DEFINE_DISPATCH(LayerNormBackwardKernel);
+DEFINE_DISPATCH(LayerNormKernelQuantized);

 } // namespace native
 } // namespace at
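
For reference, the schema registered above exposes the new kernel to Python as torch.ops.quantized.layer_norm; the test added in test/test_quantized.py below exercises exactly that path. A minimal usage sketch (input shapes and quantization parameters here are illustrative, not taken from the commit):

    import torch
    import torch.nn.functional as F

    x = torch.randn(4, 8, 16)
    # Quantize the input; scale / zero_point are illustrative values.
    qx = torch.quantize_per_tensor(x, scale=0.05, zero_point=64, dtype=torch.quint8)
    weight = torch.ones(qx.size()[1:])
    bias = torch.zeros(qx.size()[1:])

    # New op: normalizes over the trailing dims of a quantized CPU tensor and
    # returns a quantized output with the requested scale / zero point.
    qy = torch.ops.quantized.layer_norm(
        qx, qx.size()[1:], weight=weight, bias=bias, eps=1e-5,
        output_scale=0.1, output_zero_point=0)

    # Floating-point reference path, per the test plan.
    y_ref = F.layer_norm(qx.dequantize(), qx.size()[1:],
                         weight=weight, bias=bias, eps=1e-5)
    qy_ref = torch.quantize_per_tensor(y_ref, scale=0.1, zero_point=0,
                                       dtype=torch.quint8)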

aten/src/ATen/native/layer_norm.h

Lines changed: 10 additions & 0 deletions
@@ -29,8 +29,18 @@ using backward_fn = void (*)(
     Tensor* /* dgamma */,
     Tensor* /* dbeta */);

+using forward_quantized_fn = void (*)(
+    const Tensor& /* X */,
+    const Tensor& /* gamma */,
+    const Tensor& /* beta */,
+    int64_t /* M */,
+    int64_t /* N */,
+    double /* eps */,
+    Tensor* /* Y */);
+
 DECLARE_DISPATCH(forward_fn, LayerNormKernel);
 DECLARE_DISPATCH(backward_fn, LayerNormBackwardKernel);
+DECLARE_DISPATCH(forward_quantized_fn, LayerNormKernelQuantized);

 } // namespace native
 } // namespace at

aten/src/ATen/native/native_functions.yaml

Lines changed: 5 additions & 0 deletions
@@ -1598,6 +1598,11 @@
     CPU: layer_norm_backward_cpu
     CUDA: layer_norm_backward_cuda

+- func: quantized_layer_norm(Tensor input, int[] normalized_shape, Tensor? weight, Tensor? bias, float eps, float output_scale, int output_zero_point) -> Tensor
+  requires_tensor: True
+  dispatch:
+    QuantizedCPU: quantized_layer_norm_impl
+
 - func: linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor
   python_module: nn

aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp

Lines changed: 152 additions & 0 deletions
@@ -6,6 +6,7 @@
 #include <ATen/native/quantized/cpu/quantized_ops.h>
 #include <ATen/quantized/Quantizer.h>
 #include <ATen/native/SortingUtils.h>
+#include <ATen/cpu/vec256/functional.h>

 #include <cmath>
 #ifdef USE_FBGEMM
@@ -1497,6 +1498,156 @@ void fake_quant_grad_per_channel_cpu(TensorIterator &iter, int64_t quant_min, in
   });
 }

+template <typename T>
+void LayerNormKernelQuantizedImplInternal(
+    const Tensor& X,
+    const Tensor& gamma,
+    const Tensor& beta,
+    int64_t M,
+    int64_t N,
+    float eps,
+    Tensor* Y) {
+
+  using qVec = vec256::Vec256<T>;
+  using fVec = vec256::Vec256<float>;
+
+  DCHECK_EQ(X.numel(), M * N);
+  DCHECK(!gamma.defined() || gamma.numel() == N);
+  DCHECK(!beta.defined() || beta.numel() == N);
+  T* X_data = X.data_ptr<T>();
+  const float* gamma_data = gamma.defined() ? gamma.data_ptr<float>() : nullptr;
+  const float* beta_data = beta.defined() ? beta.data_ptr<float>() : nullptr;
+  T* Y_data = Y->data_ptr<T>();
+  const float c = 1.0f / static_cast<float>(N);
+  const bool gamma_null = gamma_data == nullptr;
+  const bool beta_null = beta_data == nullptr;
+
+  int64_t x_zp = X.q_zero_point();
+  float x_scale = X.q_scale();
+
+  fVec x_zp_vec = fVec((float)x_zp);
+  fVec one_vec = fVec(1.0f);
+  fVec zero_vec = fVec(0.0f);
+
+  float x_fake_scale = 1.0f;
+  fVec x_fake_scale_vec = fVec(x_fake_scale);
+  fVec x_fake_scale_zp_neg_premul_vec = x_fake_scale_vec * x_zp_vec.neg();
+
+  int64_t y_zp = Y->q_zero_point();
+  float y_scale = Y->q_scale();
+  float y_inv_scale = 1.0f / y_scale;
+
+  // 8 floats in a 256 bit Vec256
+  constexpr int kFloatVLen = 8;
+  // N ints in a qVec
+  int64_t kIntVLen = kFloatVLen * qVec::float_num_vecs();
+  // portion of layer that can be vectorized
+  int64_t kNumIntVecInLayer = N / kIntVLen;
+  // remainder of layer that cannot be vectorized
+  int64_t kNonVecRemInLayer = N % kIntVLen;
+
+  at::parallel_for(0, M, 1, [&](int64_t start, int64_t end) {
+    for (int64_t i = start; i < end; ++i) {
+
+      T* X_ptr = X_data + i * N;
+      T* Y_ptr = Y_data + i * N;
+
+      // First pass: calculate mean and variance.
+      // Note: Fake dequant using scale=1.0f because scale_x cancels out
+      // during normalization, with the exception of epsilon
+
+      // TODO replace with TensorIterator implementation once #33166 is fixed.
+      float layerSum = 0.0f;
+      float layerSumSquares = 0.0f;
+      for (int64_t vecIdx = 0; vecIdx < kNumIntVecInLayer; vecIdx++) {
+        auto qXVec = qVec::loadu(X_ptr + vecIdx * kIntVLen);
+        auto dqXVec = qXVec.dequantize(x_fake_scale_vec, x_zp_vec,
+            x_fake_scale_zp_neg_premul_vec);
+        // sum of vals
+        float thisLayerSum = vec256::reduce_all<float>(
+            [](fVec& x, fVec& y) { return x + y; },
+            (float*)dqXVec.data(),
+            kFloatVLen * dqXVec.size()
+        );
+        layerSum += thisLayerSum;
+        // sum of squares
+        float thisLayerSumSquares = vec256::map_reduce_all<float>(
+            [](fVec x) { return x * x; },
+            [](fVec x, fVec y) { return x + y; },
+            (float*)dqXVec.data(),
+            kFloatVLen * dqXVec.size()
+        );
+        layerSumSquares += thisLayerSumSquares;
+      }
+      for (int64_t remIdx = N - kNonVecRemInLayer; remIdx < N; remIdx++) {
+        auto qXVal = X_ptr[remIdx];
+        float dqXVal = at::dequantize_val(x_fake_scale, x_zp, qXVal);
+        layerSum += dqXVal;
+        layerSumSquares += dqXVal * dqXVal;
+      }
+
+      // mean(dqX) / scale_x
+      float layerMeanDivScaleX = layerSum / N;
+      // var(dqX) / scale_x^2
+      float layerVarDivScaleXSq =
+          std::max(layerSumSquares / N - layerMeanDivScaleX * layerMeanDivScaleX, 0.0f);
+      // scale_x / std(dqX), scale epsilon properly
+      float scaleXDivLayerStd = 1.0f /
+          std::sqrt(layerVarDivScaleXSq + (eps * x_scale * x_scale));
+      fVec layerMeanDivScaleXVec(layerMeanDivScaleX);
+      fVec scaleXDivLayerStdVec(scaleXDivLayerStd);

+      // Second pass: normalize
+
+      // TODO replace with TensorIterator implementation once #33166 is fixed.
+      for (int64_t vecIdx = 0; vecIdx < kNumIntVecInLayer; vecIdx++) {
+        int64_t vecStartIdx = vecIdx * kIntVLen;
+        auto qXVec = qVec::loadu(X_ptr + vecStartIdx);
+        auto dqXVec = qXVec.dequantize(x_fake_scale_vec, x_zp_vec,
+            x_fake_scale_zp_neg_premul_vec);
+        for (int dqXVecIdx = 0; dqXVecIdx < dqXVec.size(); dqXVecIdx++) {
+          int64_t vecVecStartIdx = vecStartIdx + dqXVecIdx * kFloatVLen;
+          auto gammaVec = gamma_null
+              ? one_vec
+              : fVec::loadu(gamma_data + vecVecStartIdx);
+          auto betaVec = beta_null
+              ? zero_vec
+              : fVec::loadu(beta_data + vecVecStartIdx);
+          dqXVec[dqXVecIdx] =
+              (dqXVec[dqXVecIdx] - layerMeanDivScaleXVec) *
+              scaleXDivLayerStdVec * gammaVec + betaVec;
+          qVec::quantize(dqXVec, y_scale, y_zp, y_inv_scale)
+              .store(Y_ptr + vecStartIdx);
+        }
+      }
+      for (int64_t remIdx = N - kNonVecRemInLayer; remIdx < N; remIdx++) {
+        const float gamma_v = gamma_null ? 1.0f : gamma_data[remIdx];
+        const float beta_v = beta_null ? 0.0f : beta_data[remIdx];
+        auto qXVal = X_ptr[remIdx];
+        float dqXVal = at::dequantize_val(x_fake_scale, x_zp, qXVal);
+        float dqY =
+            ((dqXVal - layerMeanDivScaleX) * scaleXDivLayerStd) * gamma_v + beta_v;
+        Y_ptr[remIdx] = at::quantize_val<T>(y_scale, y_zp, dqY);
+      }
+
+    }
+  }); // parallel_for
+}
+
+void LayerNormKernelQuantizedImpl(
+    const Tensor& X,
+    const Tensor& gamma,
+    const Tensor& beta,
+    int64_t M,
+    int64_t N,
+    double eps,
+    Tensor* Y) {
+  AT_DISPATCH_QINT_TYPES(X.scalar_type(), "LayerNormKernelImpl", [&]() {
+    LayerNormKernelQuantizedImplInternal<scalar_t>(
+        X, gamma, beta, M, N, static_cast<float>(eps), Y);
+  });
+}
+
 } // namespace

 REGISTER_DISPATCH(qrelu_stub, &qrelu_kernel);
@@ -1531,6 +1682,7 @@ REGISTER_DISPATCH(fake_quant_tensor_stub, &fake_quantize_tensor_kernel);
 REGISTER_DISPATCH(fake_quant_grad_tensor_stub, &fake_quantize_grad_tensor_kernel);
 REGISTER_DISPATCH(fake_quant_per_channel_stub, &fake_quant_per_channel_cpu);
 REGISTER_DISPATCH(fake_quant_grad_per_channel_stub, &fake_quant_grad_per_channel_cpu);
+REGISTER_DISPATCH(LayerNormKernelQuantized, &LayerNormKernelQuantizedImpl);

 } // namespace native
 } // namespace at
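
As the in-kernel note above says, the first pass accumulates sums over "fake-dequantized" values: the zero point is subtracted but the scale is treated as 1.0, because a common factor of x_scale cancels between the numerator and the denominator of (x - mean) / std; only the epsilon term has to be combined with x_scale before the square root. A rough NumPy sketch of that cancellation for one row, leaving epsilon aside (function and variable names here are illustrative, not part of the commit):

    import numpy as np

    def row_norm_reference(q_row, x_scale, x_zp):
        # Straightforward path: dequantize with the real scale, then normalize.
        dq = x_scale * (q_row.astype(np.float32) - x_zp)
        ref = (dq - dq.mean()) / dq.std()

        # Kernel's path: operate on the raw codes ("fake" scale of 1.0).
        # The x_scale factor cancels, so the results agree (epsilon aside).
        u = q_row.astype(np.float32) - x_zp
        fast = (u - u.mean()) / u.std()

        assert np.allclose(ref, fast, atol=1e-5)
        return fast

    row_norm_reference(np.random.randint(0, 256, size=64), x_scale=0.05, x_zp=64)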

test/test_quantized.py

Lines changed: 38 additions & 0 deletions
@@ -271,6 +271,44 @@ def test_qhardsigmoid(self, X):
                              message="Hardsigmoid failed: {} vs. {}".format(qY, qY_hat))


+    """Tests the correctness of the quantized::qlayer_norm op."""
+    @given(X=hu.tensor(shapes=hu.array_shapes(3, 5, 1, 32),
+                       elements=hu.floats(-1e3, 1e3, allow_nan=False, allow_infinity=False),
+                       qparams=hu.qparams()),
+           Y_scale=st.floats(0.2, 2.6),
+           Y_zero_point=st.integers(0, 5),
+           qengine=st.sampled_from(("qnnpack", "fbgemm")))
+    def test_qlayer_norm(self, X, Y_scale, Y_zero_point, qengine):
+        if qengine not in torch.backends.quantized.supported_engines:
+            return
+
+        with override_quantized_engine(qengine):
+            X, (scale, zero_point, torch_type) = X
+            X = torch.from_numpy(X)
+            qX = torch.quantize_per_tensor(X, scale=scale,
+                                           zero_point=zero_point,
+                                           dtype=torch_type)
+            dqX = qX.dequantize()
+
+            weight = torch.rand(*qX.size()[1:], dtype=torch.float)
+            bias = torch.rand(*qX.size()[1:], dtype=torch.float)
+            epsilon = 1e-5
+
+            qY = torch.ops.quantized.layer_norm(
+                qX, qX.size()[1:], weight=weight, bias=bias, eps=epsilon,
+                output_scale=Y_scale, output_zero_point=Y_zero_point)
+
+            Y_hat = F.layer_norm(
+                dqX, dqX.size()[1:], weight=weight, bias=bias, eps=epsilon)
+            qY_hat = torch.quantize_per_tensor(
+                Y_hat, scale=Y_scale, zero_point=Y_zero_point, dtype=torch_type)
+
+            self.assertEqual(
+                qY,
+                qY_hat,
+                message="LayerNorm failed:\n {} input vs\n {} actual vs \n{} expected".format(X, qY, qY_hat))
+
+
     """Tests the correctness of the quantized::qnnpack_tanh op."""
     @given(X=hu.tensor(shapes=hu.array_shapes(1, 5, 1, 5),
                        qparams=hu.qparams()))

third_party/protobuf

Submodule protobuf updated 1548 files
