Skip to content

Commit 9e62869

Browse files
Add more validation to RequantizationRangePerChannel.
PiperOrigin-RevId: 387693946 Change-Id: Ife8dcbdb021bec4787eef6a4361dd08f17c14bd6
1 parent e2c9d55 commit 9e62869

File tree

1 file changed

+14
-0
lines changed

1 file changed

+14
-0
lines changed

tensorflow/core/kernels/mkl/mkl_requantization_range_per_channel_op.cc

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,20 @@ class MklRequantizationRangePerChannelOp : public OpKernel {
5757
ctx, input_max.dim_size(0) == depth,
5858
errors::InvalidArgument("input_max has incorrect size, expected ",
5959
depth, " was ", input_max.dim_size(0)));
60+
OP_REQUIRES(
61+
ctx, input_min.NumElements() == depth,
62+
errors::InvalidArgument("input_min must have the same number of "
63+
"elements as input_max, got ",
64+
input_min.NumElements(), " and ", depth));
65+
OP_REQUIRES(ctx, input.NumElements() > 0,
66+
errors::InvalidArgument("input must not be empty"));
67+
OP_REQUIRES(ctx, input.dims() == 4,
68+
errors::InvalidArgument("input must be in NHWC format"));
69+
OP_REQUIRES(
70+
ctx, input.dim_size(3) == depth,
71+
errors::InvalidArgument(
72+
"input must have same number of channels as length of input_min: ",
73+
input.dim_size(3), " vs ", depth));
6074

6175
const float* input_min_data = input_min.flat<float>().data();
6276
const float* input_max_data = input_max.flat<float>().data();

0 commit comments

Comments
 (0)