Skip to content

Commit 2032145

Browse files
Reorganize and add more validation to MKL requantization
PiperOrigin-RevId: 387901341 Change-Id: I2515b9034c64e113db0bcec8337d30643ab0a0f1
1 parent aff0d5b commit 2032145

File tree

1 file changed

+25
-15
lines changed

1 file changed

+25
-15
lines changed

tensorflow/core/kernels/mkl/mkl_requantize_per_channel_op.cc

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -49,35 +49,45 @@ class MklRequantizePerChannelOp : public OpKernel {
4949
void Compute(OpKernelContext* ctx) override {
5050
try {
5151
const Tensor& input = ctx->input(kInputTensorIndex);
52+
OP_REQUIRES(
53+
ctx, input.dims() == 4,
54+
errors::InvalidArgument("Current RequantizePerChannel operator"
55+
"supports 4D tensors only."));
56+
5257
const Tensor& input_min_vec = ctx->input(kInputMinVecIndex);
58+
size_t depth = input_min_vec.NumElements();
5359
float* input_min_vec_data = (float*)const_cast<void*>(
5460
static_cast<const void*>(input_min_vec.flat<float>().data()));
61+
5562
const Tensor& input_max_vec = ctx->input(kInputMaxVecIndex);
63+
OP_REQUIRES(
64+
ctx, input_max_vec.NumElements() == depth,
65+
errors::InvalidArgument("input_max has incorrect size, expected ",
66+
depth, " was ", input_max_vec.NumElements()));
5667
float* input_max_vec_data = (float*)const_cast<void*>(
5768
static_cast<const void*>(input_max_vec.flat<float>().data()));
5869

5970
const Tensor& input_requested_min = ctx->input(this->kRequestMinIndex);
71+
OP_REQUIRES(
72+
ctx, input_requested_min.NumElements() == 1,
73+
errors::InvalidArgument("requested_output_min must be a scalar"));
6074
const float input_requested_min_float =
6175
input_requested_min.flat<float>()(0);
76+
6277
const Tensor& input_requested_max = ctx->input(this->kRequestMaxIndex);
78+
OP_REQUIRES(
79+
ctx, input_requested_min.NumElements() == 1,
80+
errors::InvalidArgument("requested_output_max must be a scalar"));
6381
const float input_requested_max_float =
6482
input_requested_max.flat<float>()(0);
6583

66-
size_t depth = input_min_vec.NumElements();
67-
OP_REQUIRES(
68-
ctx, input.dims() == 4,
69-
errors::InvalidArgument("Current RequantizePerChannel operator"
70-
"supports 4D tensors only."));
71-
OP_REQUIRES(
72-
ctx, input_min_vec.dim_size(0) == depth,
73-
errors::InvalidArgument("input_min has incorrect size, expected ",
74-
depth, " was ", input_min_vec.dim_size(0)));
75-
OP_REQUIRES(
76-
ctx, input_max_vec.dim_size(0) == depth,
77-
errors::InvalidArgument("input_max has incorrect size, expected ",
78-
depth, " was ", input_max_vec.dim_size(0)));
79-
80-
if (out_type_ == DT_QINT8) DCHECK(input_requested_min_float < 0.0f);
84+
if (out_type_ == DT_QINT8) {
85+
OP_REQUIRES(ctx, input_requested_min_float < 0.0f,
86+
errors::InvalidArgument(
87+
"If out_type is QINT8, requested_output_max must be "
88+
"non negative, got ",
89+
input_requested_min_float));
90+
}
8191

8292
const float factor = (out_type_ == DT_QINT8) ? 127.0f : 255.0f;
8393
const float requested_min_max =

0 commit comments

Comments
 (0)