5 changes: 5 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -476,6 +476,11 @@

- func: batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor

- func: quantized_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor mean, Tensor var, float eps, float output_scale, int output_zero_point) -> Tensor
requires_tensor: True
dispatch:
QuantizedCPU: quantized_batch_norm

- func: _batch_norm_impl_index(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> (Tensor, Tensor, Tensor, Tensor, int)

- func: _batch_norm_impl_index_backward(int impl_index, Tensor input, Tensor grad_output, Tensor? weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var_transform, bool train, float eps, bool[3] output_mask, Tensor reservedSpace) -> (Tensor, Tensor, Tensor)
75 changes: 75 additions & 0 deletions aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp
@@ -8,6 +8,12 @@
#include <ATen/native/SortingUtils.h>

#include <cmath>
#ifdef USE_FBGEMM
#include "fbgemm/QuantUtils.h"
#endif
#ifdef _OPENMP
#include <omp.h>
#endif

namespace at {
namespace native {
@@ -1006,6 +1012,74 @@ void qtopk_kernel(Tensor& values,
});
}

template <bool ReluFused>
void q_batch_norm_kernel(
int64_t N,
int64_t C,
int64_t HxW,
const int64_t in_zero_point,
const int64_t out_zero_point,
const Tensor& input,
const Tensor& a,
const Tensor& b,
Tensor& output) {

AT_DISPATCH_QINT_TYPES(input.scalar_type(), "qbatch_norm", [&]() {
float* alpha = a.data_ptr<float>();
float* beta = b.data_ptr<float>();
scalar_t::underlying* X =
reinterpret_cast<scalar_t::underlying*>(input.data_ptr());
scalar_t::underlying* Y = reinterpret_cast<scalar_t::underlying*>(output.data_ptr());

constexpr int kVLen = 8;
const int64_t outer_size = N * HxW;
using Vec = Vec256<scalar_t>;
// Hoisted variables
auto in_zp_vec = Vec256<float>(static_cast<float>(in_zero_point));
auto fake_scale = Vec256<float>(1.0f);
auto scale_neg_zp_premul = fake_scale * in_zp_vec.neg();
auto out_zero_point_v = Vec(scalar_t(out_zero_point));

// TODO replace with TensorIterator implementation once #33166 is fixed.
for (int64_t i = 0; i < outer_size; ++i) {
int64_t n = C / (Vec::float_num_vecs() * kVLen) * (Vec::float_num_vecs() * kVLen);
int64_t r = C % (Vec::float_num_vecs() * kVLen);
auto* X_ptr = reinterpret_cast<typename scalar_t::underlying*>(X + i * C);
auto* Y_ptr = reinterpret_cast<typename scalar_t::underlying*>(Y + i * C);

// Vectorized main loop: processes Vec::float_num_vecs() * kVLen channels per iteration.
for (int64_t j = 0; j < n; j += Vec::float_num_vecs() * kVLen) {
auto vals_q = Vec::loadu(X_ptr + j);
// Fake scale of 1.0 here, should not affect performance (FMA in place of sub)
auto vals_dq = vals_q.dequantize(fake_scale, in_zp_vec, scale_neg_zp_premul);
for (size_t idx = 0; idx < vals_dq.size(); ++idx) {
auto alpha_v = Vec256<float>::loadu(alpha + j + idx * kVLen);
auto beta_v = Vec256<float>::loadu(beta + j + idx * kVLen);
vals_dq[idx] = vec256::fmadd(alpha_v, vals_dq[idx], beta_v);
}
// Fake scale again
auto outputs_q = Vec::quantize(vals_dq, /*output_scale=*/1.0f, out_zero_point, /*inv_output_scale=*/1.0f);
if (ReluFused) {
outputs_q = outputs_q.relu(out_zero_point_v);
}
outputs_q.store(Y_ptr + j);
}

// Scalar tail for the remaining channels that do not fill a full vector.
for (int64_t j = 0; j < r; ++j) {
long quantized_down = out_zero_point +
lrintf(alpha[n + j] * (X_ptr[n + j] - in_zero_point) +
beta[n + j]);
if (ReluFused) { // static if
quantized_down = std::max<long>(quantized_down, out_zero_point);
}
Y_ptr[n + j] = std::min<long>(
std::max<long>(quantized_down, std::numeric_limits<scalar_t::underlying>::min()),
std::numeric_limits<scalar_t::underlying>::max());
}
}
});

}

} // namespace

REGISTER_DISPATCH(qrelu_stub, &qrelu_kernel);
@@ -1027,6 +1101,7 @@ REGISTER_DISPATCH(
REGISTER_DISPATCH(qcat_nhwc_stub, &qcat_nhwc_kernel<false>);
REGISTER_DISPATCH(qcat_relu_nhwc_stub, &qcat_nhwc_kernel<true>);
REGISTER_DISPATCH(qtopk_stub, &qtopk_kernel);
REGISTER_DISPATCH(qbatch_norm_stub, &q_batch_norm_kernel<false>);

} // namespace native
} // namespace at
164 changes: 164 additions & 0 deletions aten/src/ATen/native/quantized/cpu/qbatch_norm.cpp
@@ -0,0 +1,164 @@
#include <ATen/ATen.h>
#include <ATen/NativeFunctions.h>
#include <ATen/Parallel.h>
#include <ATen/core/op_registration/op_registration.h>
#include <ATen/native/quantized/cpu/quantized_ops.h>

#include <algorithm>
#include <vector>

namespace at {
namespace native {

DEFINE_DISPATCH(qbatch_norm_stub);

namespace {
void compute_fused_params(
const int64_t channels,
const float* weight_data,
const float* bias_data,
const float* mean_data,
const float* var_data,
double eps,
float input_scale,
float output_scale,
float* alpha_data,
float* beta_data) {
// Batch Normalization
// output(n, c, h, w)
// = (input(n, c, h, w) - mean(c)) / sqrt(var(c) + eps) * weight(c)
// + bias(c)
// We factor out inv_sigma(c) = 1 / sqrt(var(c) + eps).
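// The quantization parameters are folded in as well, so the kernel can work
// directly on the integer representations:
//   alpha(c) = inv_sigma(c) * weight(c) * (input_scale / output_scale)
//   beta(c)  = (bias(c) - mean(c) * inv_sigma(c) * weight(c)) / output_scale
// giving q_y = out_zero_point + alpha(c) * (q_x - in_zero_point) + beta(c).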
for (int64_t c = 0; c < channels; c++) {
float inv_sigma = 1.0 / std::sqrt(var_data[c] + static_cast<float>(eps));
float weight_v = weight_data ? weight_data[c] : 1;
float bias_v = bias_data ? bias_data[c] : 0;
alpha_data[c] = inv_sigma * weight_v * (input_scale / output_scale);
beta_data[c] = (bias_v - mean_data[c] * inv_sigma * weight_v) / output_scale;
}
}

template <bool ReluFused>
Tensor q_batch_norm_impl(
Tensor qx,
Tensor weight,
Tensor bias,
Tensor mean,
Tensor var,
double eps,
float output_scale,
int64_t output_zero_point) {

if (qx.numel() == 0) {
auto out = qx.clone();
return out;
}
int64_t ndim = qx.dim();
TORCH_CHECK(ndim == 4, "Expected the input tensor to have 4 dimensions.");
const int64_t N = qx.size(0);
const int64_t C = qx.size(1);
const int64_t H = qx.size(2);
const int64_t W = qx.size(3);

TORCH_CHECK(weight.numel() == C, "Expect weight size to match C");
TORCH_CHECK(bias.numel() == C, "Expect bias size to match C");

const float* weight_data = weight.template data<float>();
const float* bias_data = bias.template data<float>();

TORCH_CHECK(mean.numel() == C, "Mean size must match channel dimension");
TORCH_CHECK(var.numel() == C, "Variance size must match channel dimension");

Tensor alpha = at::empty_like(mean, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
Tensor beta = at::empty_like(mean, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
float* alpha_data = alpha.data_ptr<float>();
float* beta_data = beta.data_ptr<float>();

const float* mean_data = mean.template data<float>();
const float* var_data = var.template data<float>();

auto oSizes = qx.sizes();
auto qx_nhwc = qx.contiguous(MemoryFormat::ChannelsLast);
Tensor qy = at::_empty_affine_quantized(
oSizes,
at::device(kCPU).dtype(qx_nhwc.scalar_type()),
output_scale,
output_zero_point,
MemoryFormat::ChannelsLast);

compute_fused_params(
C,
weight_data,
bias_data,
mean_data,
var_data,
eps,
qx.q_scale(),
output_scale,
alpha_data,
beta_data);

qbatch_norm_stub(
qx.device().type(),
N,
C,
H * W,
qx.q_zero_point(),
output_zero_point,
qx_nhwc,
alpha,
beta,
qy);
return qy;
}

} // namespace

Tensor quantized_batch_norm(
const Tensor& qx,
const Tensor& weight /* optional */,
const Tensor& bias /* optional */,
const Tensor& mean,
const Tensor& var,
double eps,
double output_scale,
int64_t output_zero_point) {
return q_batch_norm_impl<false>(
    qx, weight, bias, mean, var, eps, output_scale, output_zero_point);
}

// Keep the registry in the anonymous namespace.
namespace {
class QBatchNorm2d final : public torch::OperatorKernel {
public:
Tensor operator()(
Tensor qx,
Tensor weight,
Tensor bias,
Tensor mean,
Tensor var,
double eps,
double output_scale,
int64_t output_zero_point) {
return q_batch_norm_impl<false>(
qx, weight, bias, mean, var, eps, output_scale, output_zero_point);
}
};

static auto registry = torch::RegisterOperators().op(
"quantized::batch_norm(Tensor qx, "
"Tensor weight, "
"Tensor bias, "
"Tensor mean, "
"Tensor var, "
"float eps, "
"float output_scale, "
"int output_zero_point) -> Tensor",
torch::RegisterOperators::options().kernel<QBatchNorm2d>(
DispatchKey::QuantizedCPUTensorId));

} // namespace
} // namespace native
} // namespace at
3 changes: 3 additions & 0 deletions aten/src/ATen/native/quantized/cpu/quantized_ops.h
@@ -85,6 +85,8 @@ using qcat_nhwc_fn = Tensor (*)(
int64_t zero_point);
using qtopk_fn = void(*)(Tensor&, Tensor&, const Tensor&, int64_t, int64_t, bool, bool);

// Arguments: N, C, HxW, input zero point, output zero point, input, alpha, beta, output.
using qbatch_norm_fn = void(*)(int64_t, int64_t, int64_t, const int64_t, const int64_t, const Tensor&, const Tensor&, const Tensor&, Tensor&);

// using qavg_pool2d_fn
DECLARE_DISPATCH(qrelu_fn, qrelu_stub);
DECLARE_DISPATCH(qrelu_fn, qrelu6_stub);
@@ -101,6 +103,7 @@ DECLARE_DISPATCH(qupsample_bilinear2d_fn, qupsample_bilinear2d_nhwc_stub);
DECLARE_DISPATCH(qcat_nhwc_fn, qcat_nhwc_stub);
DECLARE_DISPATCH(qcat_nhwc_fn, qcat_relu_nhwc_stub);
DECLARE_DISPATCH(qtopk_fn, qtopk_stub);
DECLARE_DISPATCH(qbatch_norm_fn, qbatch_norm_stub);

} // namespace native
} // namespace at
29 changes: 29 additions & 0 deletions test/test_quantized.py
@@ -1093,6 +1093,35 @@ def equal_ref(qX, qX2):
self.assertEqual(qX.equal(qX2), equal_ref(qX, qX2))


@given(X=hu.tensor(shapes=hu.array_shapes(min_dims=4, max_dims=4,
min_side=1, max_side=32),
qparams=hu.qparams()),
Y_scale=st.floats(0.2, 2.6),
Y_zero_point=st.integers(0, 5),
qengine=st.sampled_from(("qnnpack", "fbgemm")))
def test_batch_norm(self, X, Y_scale, Y_zero_point, qengine):
if qengine not in torch.backends.quantized.supported_engines:
return

with override_quantized_engine(qengine):
X, (scale_x, zero_point_x, dtype_x) = X

X = torch.from_numpy(X)
c = X.shape[1]

mean = torch.rand(c).float()
var = torch.rand(c).float()
weight = torch.rand(c).float()
bias = torch.rand(c).float()
eps = 0.001
qx = torch.quantize_per_tensor(X, scale_x, zero_point_x, dtype_x)
qy = torch.ops.quantized.batch_norm(qx, weight, bias, mean, var, eps, Y_scale, Y_zero_point)

float_ref = F.batch_norm(qx.dequantize(), weight=weight, bias=bias,
running_mean=mean, running_var=var, training=False, momentum=0, eps=eps)
quantize_ref = torch.quantize_per_tensor(float_ref, Y_scale, Y_zero_point, dtype_x)
self.assertEqual(qy.int_repr().numpy(), quantize_ref.int_repr().numpy())

@unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines,
" Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
" with instruction set support avx2 or newer.")