
Commit d043560

supriyar authored and facebook-github-bot committed
[quant] Add a quantized batch_norm operator (#33080)
Summary:
Pull Request resolved: #33080

Adds a quantized batch norm operator for cases where batch norm cannot be fused with conv. The AVX2 implementation is taken from Caffe2.

Test Plan: python test/test_quantized.py TestQuantizedOps.test_batch_norm

Imported from OSS

Differential Revision: D19861927

fbshipit-source-id: bd8cd101fc063cb6358132ab7c651a160999293c
1 parent b28a834 commit d043560
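For orientation, the new operator is intended to match a dequantize → float batch norm → requantize reference, which is exactly how the added test checks it. A minimal sketch of that check, with made-up shapes and quantization parameters (the op is invoked under the quantized::batch_norm name registered in this commit; later PyTorch releases may expose it under a different name):

import torch
import torch.nn.functional as F

# Hypothetical example: a per-tensor affine quantized NCHW input and per-channel stats.
x = torch.randn(1, 3, 4, 4)
qx = torch.quantize_per_tensor(x, scale=0.1, zero_point=2, dtype=torch.quint8)
weight, bias = torch.rand(3), torch.rand(3)
mean, var = torch.rand(3), torch.rand(3)
eps, out_scale, out_zp = 1e-3, 0.2, 1

# Reference path: dequantize, run float batch norm in eval mode, requantize.
float_ref = F.batch_norm(qx.dequantize(), running_mean=mean, running_var=var,
                         weight=weight, bias=bias, training=False, momentum=0, eps=eps)
ref_q = torch.quantize_per_tensor(float_ref, out_scale, out_zp, torch.quint8)

# Fused quantized op added by this commit; results should agree up to rounding
# of at most one quantization step.
qy = torch.ops.quantized.batch_norm(qx, weight, bias, mean, var, eps, out_scale, out_zp)
assert (qy.int_repr().int() - ref_q.int_repr().int()).abs().max() <= 1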

File tree

5 files changed: +276 −0 lines changed

aten/src/ATen/native/native_functions.yaml

Lines changed: 5 additions & 0 deletions
@@ -476,6 +476,11 @@
 
 - func: batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor
 
+- func: quantized_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor mean, Tensor var, float eps, float output_scale, int output_zero_point) -> Tensor
+  requires_tensor: True
+  dispatch:
+    QuantizedCPU: quantized_batch_norm
+
 - func: _batch_norm_impl_index(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> (Tensor, Tensor, Tensor, Tensor, int)
 
 - func: _batch_norm_impl_index_backward(int impl_index, Tensor input, Tensor grad_output, Tensor? weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var_transform, bool train, float eps, bool[3] output_mask, Tensor reservedSpace) -> (Tensor, Tensor, Tensor)
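In the schema above, weight and bias are optional (Tensor?) while mean and var are required, and the output quantization parameters are passed in explicitly rather than inferred. Assuming the usual native_functions.yaml codegen exposes the function variant as torch.quantized_batch_norm (an assumed binding name, not shown in this diff), a call would look roughly like:

import torch

# Hypothetical call matching the declared schema; the torch-level name is assumed
# from the standard codegen, and the numeric values are made up.
qx = torch.quantize_per_tensor(torch.randn(2, 3, 8, 8), 0.1, 2, torch.quint8)
c = qx.shape[1]
qy = torch.quantized_batch_norm(
    qx,
    torch.rand(c),        # weight (Tensor?, may be None)
    torch.rand(c),        # bias   (Tensor?, may be None)
    torch.rand(c),        # mean
    torch.rand(c) + 0.5,  # var
    1e-3,                 # eps
    0.2,                  # output_scale
    1,                    # output_zero_point
)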

aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp

Lines changed: 75 additions & 0 deletions
@@ -8,6 +8,12 @@
 #include <ATen/native/SortingUtils.h>
 
 #include <cmath>
+#ifdef USE_FBGEMM
+#include "fbgemm/QuantUtils.h"
+#endif
+#ifdef _OPENMP
+#include <omp.h>
+#endif
 
 namespace at {
 namespace native {
@@ -1006,6 +1012,74 @@ void qtopk_kernel(Tensor& values,
 });
 }
 
+template <bool ReluFused>
+void q_batch_norm_kernel(
+    int64_t N,
+    int64_t C,
+    int64_t HxW,
+    const int64_t in_zero_point,
+    const int64_t out_zero_point,
+    const Tensor& input,
+    const Tensor& a,
+    const Tensor& b,
+    Tensor& output) {
+
+  AT_DISPATCH_QINT_TYPES(input.scalar_type(), "qbatch_norm", [&]() {
+    float* alpha = a.data_ptr<float>();
+    float* beta = b.data_ptr<float>();
+    scalar_t::underlying* X =
+        reinterpret_cast<scalar_t::underlying*>(input.data_ptr());
+    scalar_t::underlying* Y = reinterpret_cast<scalar_t::underlying*>(output.data_ptr());
+
+    constexpr int kVLen = 8;
+    const int64_t outer_size = N * HxW;
+    using Vec = Vec256<scalar_t>;
+    // Hoisted variables
+    auto in_zp_vec = Vec256<float>(static_cast<float>(in_zero_point));
+    auto fake_scale = Vec256<float>(1.0f);
+    auto scale_neg_zp_premul = fake_scale * in_zp_vec.neg();
+    auto out_zero_point_v = Vec(scalar_t(out_zero_point));
+
+    // TODO replace with TensorIterator implementation once #33166 is fixed.
+    for (int64_t i = 0; i < outer_size; ++i) {
+      int64_t n = C / (Vec::float_num_vecs() * kVLen) * (Vec::float_num_vecs() * kVLen);
+      int64_t r = C % (Vec::float_num_vecs() * kVLen);
+      auto* X_ptr = reinterpret_cast<typename scalar_t::underlying*>(X + i * C);
+      auto* Y_ptr = reinterpret_cast<typename scalar_t::underlying*>(Y + i * C);
+
+      for (int64_t j = 0; j < n; j += Vec::float_num_vecs() * kVLen) {
+        auto vals_q = Vec::loadu(X_ptr + j);
+        // Fake scale of 1.0 here, should not affect performance (FMA in place of sub)
+        auto vals_dq = vals_q.dequantize(fake_scale, in_zp_vec, scale_neg_zp_premul);
+        for (size_t idx = 0; idx < vals_dq.size(); ++idx) {
+          auto alpha_v = Vec256<float>::loadu(alpha + j + idx * kVLen);
+          auto beta_v = Vec256<float>::loadu(beta + j + idx * kVLen);
+          vals_dq[idx] = vec256::fmadd(alpha_v, vals_dq[idx], beta_v);
+        }
+        // Fake scale again
+        auto outputs_q = Vec::quantize(vals_dq, /*output_scale=*/1.0f, out_zero_point, /*inv_output_scale=*/1.0f);
+        if (ReluFused) {
+          outputs_q = outputs_q.relu(out_zero_point_v);
+        }
+        outputs_q.store(Y_ptr + j);
+      }
+
+      for (int64_t j = 0; j < r; ++j) {
+        long quantized_down = out_zero_point +
+            lrintf(alpha[n + j] * (X_ptr[n + j] - in_zero_point) +
+                   beta[n + j]);
+        if (ReluFused) { // static if
+          quantized_down = std::max<long>(quantized_down, out_zero_point);
+        }
+        Y_ptr[n + j] = std::min<long>(
+            std::max<long>(quantized_down, std::numeric_limits<scalar_t::underlying>::min()),
+            std::numeric_limits<scalar_t::underlying>::max());
+      }
+    }
+  });
+
+}
+
 } // namespace
 
 REGISTER_DISPATCH(qrelu_stub, &qrelu_kernel);
@@ -1027,6 +1101,7 @@ REGISTER_DISPATCH(
 REGISTER_DISPATCH(qcat_nhwc_stub, &qcat_nhwc_kernel<false>);
 REGISTER_DISPATCH(qcat_relu_nhwc_stub, &qcat_nhwc_kernel<true>);
 REGISTER_DISPATCH(qtopk_stub, &qtopk_kernel);
+REGISTER_DISPATCH(qbatch_norm_stub, &q_batch_norm_kernel<false>);
 
 } // namespace native
 } // namespace at
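The vectorized loop in the kernel above dequantizes with a fake scale of 1.0, applies the per-channel affine transform (alpha, beta) with an FMA, requantizes, and optionally clamps at the output zero point when ReLU is fused; the scalar tail loop spells out the same arithmetic one element at a time. A small NumPy sketch of that per-element transform, using made-up values and assuming a uint8 value range:

import numpy as np

def q_batch_norm_scalar(x_q, alpha, beta, in_zero_point, out_zero_point,
                        qmin=0, qmax=255, relu_fused=False):
    # Scalar reference for the tail loop: per-channel affine on the integer values,
    # round to nearest, optional ReLU at the output zero point, then clamp.
    out = out_zero_point + np.rint(alpha * (x_q.astype(np.int64) - in_zero_point) + beta)
    if relu_fused:
        out = np.maximum(out, out_zero_point)
    return np.clip(out, qmin, qmax).astype(np.uint8)

# One "pixel" with C = 4 channels and hypothetical per-channel alpha/beta.
x_q = np.array([12, 200, 37, 90], dtype=np.uint8)
alpha = np.array([0.9, 1.1, 0.5, 2.0], dtype=np.float32)
beta = np.array([3.0, -10.0, 0.25, 1.5], dtype=np.float32)
print(q_batch_norm_scalar(x_q, alpha, beta, in_zero_point=2, out_zero_point=1))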
Lines changed: 164 additions & 0 deletions
@@ -0,0 +1,164 @@
+#include <ATen/ATen.h>
+#include <ATen/NativeFunctions.h>
+#include <ATen/Parallel.h>
+#include <ATen/core/op_registration/op_registration.h>
+#include <ATen/native/quantized/cpu/quantized_ops.h>
+
+#include <algorithm>
+#include <vector>
+
+namespace at {
+namespace native {
+
+DEFINE_DISPATCH(qbatch_norm_stub);
+
+namespace {
+void compute_fused_params(
+    const int64_t channels,
+    const float* weight_data,
+    const float* bias_data,
+    const float* mean_data,
+    const float* var_data,
+    double eps,
+    float input_scale,
+    float output_scale,
+    float* alpha_data,
+    float* beta_data) {
+  // Batch Normalization
+  // output(n, c, h, w)
+  //     = (input(n, c, h, w) - mean(c)) / sqrt(var(c) + eps) * weight(c)
+  //         + bias(c)
+  // We factor out inv_sigma(c) = 1 / sqrt(var(c) + eps).
+  for (int64_t c = 0; c < channels; c++) {
+    float inv_sigma = 1.0 / std::sqrt(var_data[c] + static_cast<float>(eps));
+    float weight_v = weight_data ? weight_data[c] : 1;
+    float bias_v = bias_data ? bias_data[c] : 0;
+    alpha_data[c] = inv_sigma * weight_v * (input_scale / output_scale);
+    beta_data[c] = (bias_v - mean_data[c] * inv_sigma * weight_v) / output_scale;
+  }
+}
+
+template <bool ReluFused>
+Tensor q_batch_norm_impl(
+    Tensor qx,
+    Tensor weight,
+    Tensor bias,
+    Tensor mean,
+    Tensor var,
+    double eps,
+    float output_scale,
+    int64_t output_zero_point) {
+
+  if (qx.numel() == 0) {
+    auto out = qx.clone();
+    return out;
+  }
+  int64_t ndim = qx.dim();
+  TORCH_CHECK(ndim == 4, "Expecting the input tensor of rank 4.");
+  const int64_t N = qx.size(0);
+  const int64_t C = qx.size(1);
+  const int64_t H = qx.size(2);
+  const int64_t W = qx.size(3);
+
+  TORCH_CHECK(weight.numel() == C, "Expect weight size to match C");
+  TORCH_CHECK(bias.numel() == C, "Expect bias size to match C");
+
+  const float* weight_data = weight.template data<float>();
+  const float* bias_data = bias.template data<float>();
+
+  TORCH_CHECK(mean.numel() == C, "Mean size must match channel dimension");
+  TORCH_CHECK(var.numel() == C, "Variance size must match channel dimension");
+
+  Tensor alpha = at::empty_like(mean, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+  Tensor beta = at::empty_like(mean, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+  float* alpha_data = alpha.data_ptr<float>();
+  float* beta_data = beta.data_ptr<float>();
+
+  const float* mean_data = mean.template data<float>();
+  const float* var_data = var.template data<float>();
+
+  auto oSizes = qx.sizes();
+  auto qx_nhwc = qx.contiguous(MemoryFormat::ChannelsLast);
+  Tensor qy = at::_empty_affine_quantized(
+      oSizes,
+      at::device(kCPU).dtype(qx_nhwc.scalar_type()),
+      output_scale,
+      output_zero_point,
+      MemoryFormat::ChannelsLast);
+
+  compute_fused_params(
+      C,
+      weight_data,
+      bias_data,
+      mean_data,
+      var_data,
+      eps,
+      qx.q_scale(),
+      output_scale,
+      alpha_data,
+      beta_data);
+
+  qbatch_norm_stub(
+      qx.device().type(),
+      N,
+      C,
+      H * W,
+      qx.q_zero_point(),
+      output_zero_point,
+      qx_nhwc,
+      alpha,
+      beta,
+      qy);
+  return qy;
+}
+
+} // namespace
+
+Tensor quantized_batch_norm(
+    const Tensor& qx,
+    const Tensor& weight /* optional */,
+    const Tensor& bias /* optional */,
+    const Tensor& mean /* optional */,
+    const Tensor& var /* optional */,
+    double eps,
+    double output_scale,
+    int64_t output_zero_point) {
+  Tensor qy;
+  qy = q_batch_norm_impl<false>(
+      qx, weight, bias, mean, var, eps, output_scale, output_zero_point);
+  return qy;
+}
+
+// Keep the registry in the anonymous namespace.
+namespace {
+class QBatchNorm2d final : public torch::OperatorKernel {
+ public:
+  Tensor operator()(
+      Tensor qx,
+      Tensor weight,
+      Tensor bias,
+      Tensor mean,
+      Tensor var,
+      double eps,
+      double output_scale,
+      int64_t output_zero_point) {
+    return q_batch_norm_impl<false>(
+        qx, weight, bias, mean, var, eps, output_scale, output_zero_point);
+  }
+};
+
+static auto registry = torch::RegisterOperators().op(
+    "quantized::batch_norm(Tensor qx, "
+    "Tensor weight, "
+    "Tensor bias, "
+    "Tensor mean, "
+    "Tensor var, "
+    "float eps, "
+    "float output_scale, "
+    "int output_zero_point) -> Tensor",
+    torch::RegisterOperators::options().kernel<QBatchNorm2d>(
+        DispatchKey::QuantizedCPUTensorId));
+
+} // namespace
+} // namespace native
+} // namespace at
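As the comment in compute_fused_params explains, inference-time batch norm collapses into a per-channel affine transform; folding the input and output scales into that transform yields the alpha and beta the kernel applies directly to (x_q - in_zero_point). A quick NumPy check of that algebra with made-up parameters:

import numpy as np

rng = np.random.default_rng(0)
C = 3
weight, bias = rng.random(C), rng.random(C)
mean, var = rng.random(C), rng.random(C) + 0.5
eps, in_scale, out_scale = 1e-3, 0.1, 0.2

# Fused per-channel parameters, mirroring compute_fused_params.
inv_sigma = 1.0 / np.sqrt(var + eps)
alpha = inv_sigma * weight * (in_scale / out_scale)
beta = (bias - mean * inv_sigma * weight) / out_scale

# Compare against the unfused formulation for one channel-wise input vector.
in_zp = 2.0
x_q = rng.integers(0, 256, size=C).astype(np.float64)
x = in_scale * (x_q - in_zp)                # dequantize
y = (x - mean) * inv_sigma * weight + bias  # float batch norm (eval mode)
fused = alpha * (x_q - in_zp) + beta        # what the kernel computes before requantizing
assert np.allclose(y / out_scale, fused)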

aten/src/ATen/native/quantized/cpu/quantized_ops.h

Lines changed: 3 additions & 0 deletions
@@ -85,6 +85,8 @@ using qcat_nhwc_fn = Tensor (*)(
     int64_t zero_point);
 using qtopk_fn = void(*)(Tensor&, Tensor&, const Tensor&, int64_t, int64_t, bool, bool);
 
+using qbatch_norm_fn = void(*)(int64_t, int64_t, int64_t, const int64_t, const int64_t, const Tensor&, const Tensor&, const Tensor&, Tensor&);
+
 // using qavg_pool2d_fn
 DECLARE_DISPATCH(qrelu_fn, qrelu_stub);
 DECLARE_DISPATCH(qrelu_fn, qrelu6_stub);
@@ -101,6 +103,7 @@ DECLARE_DISPATCH(qupsample_bilinear2d_fn, qupsample_bilinear2d_nhwc_stub);
 DECLARE_DISPATCH(qcat_nhwc_fn, qcat_nhwc_stub);
 DECLARE_DISPATCH(qcat_nhwc_fn, qcat_relu_nhwc_stub);
 DECLARE_DISPATCH(qtopk_fn, qtopk_stub);
+DECLARE_DISPATCH(qbatch_norm_fn, qbatch_norm_stub);
 
 } // namespace native
 } // namespace at

test/test_quantized.py

Lines changed: 29 additions & 0 deletions
@@ -1093,6 +1093,35 @@ def equal_ref(qX, qX2):
         self.assertEqual(qX.equal(qX2), equal_ref(qX, qX2))
 
 
+    @given(X=hu.tensor(shapes=hu.array_shapes(min_dims=4, max_dims=4,
+                                              min_side=1, max_side=32),
+                       qparams=hu.qparams()),
+           Y_scale=st.floats(0.2, 2.6),
+           Y_zero_point=st.integers(0, 5),
+           qengine=st.sampled_from(("qnnpack", "fbgemm")))
+    def test_batch_norm(self, X, Y_scale, Y_zero_point, qengine):
+        if qengine not in torch.backends.quantized.supported_engines:
+            return
+
+        with override_quantized_engine(qengine):
+            X, (scale_x, zero_point_x, dtype_x) = X
+
+            X = torch.from_numpy(X)
+            c = X.shape[1]
+
+            mean = torch.rand(c).float()
+            var = torch.rand(c).float()
+            weight = torch.rand(c).float()
+            bias = torch.rand(c).float()
+            eps = 0.001
+            qx = torch.quantize_per_tensor(X, scale_x, zero_point_x, dtype_x)
+            qy = torch.ops.quantized.batch_norm(qx, weight, bias, mean, var, eps, Y_scale, Y_zero_point)
+
+            float_ref = F.batch_norm(qx.dequantize(), weight=weight, bias=bias,
+                                     running_mean=mean, running_var=var, training=False, momentum=0, eps=eps)
+            quantize_ref = torch.quantize_per_tensor(float_ref, Y_scale, Y_zero_point, dtype_x)
+            self.assertEqual(qy.int_repr().numpy(), quantize_ref.int_repr().numpy())
+
     @unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines,
                          " Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
                          " with instruction set support avx2 or newer.")
