@@ -49,35 +49,45 @@ class MklRequantizePerChannelOp : public OpKernel {
   void Compute(OpKernelContext* ctx) override {
     try {
       const Tensor& input = ctx->input(kInputTensorIndex);
+      OP_REQUIRES(
+          ctx, input.dims() == 4,
+          errors::InvalidArgument("Current RequantizePerChannel operator"
+                                  " supports 4D tensors only."));
+
       const Tensor& input_min_vec = ctx->input(kInputMinVecIndex);
+      size_t depth = input_min_vec.NumElements();
       float* input_min_vec_data = (float*)const_cast<void*>(
           static_cast<const void*>(input_min_vec.flat<float>().data()));
+
       const Tensor& input_max_vec = ctx->input(kInputMaxVecIndex);
+      OP_REQUIRES(
+          ctx, input_max_vec.NumElements() == depth,
+          errors::InvalidArgument("input_max has incorrect size, expected ",
+                                  depth, " was ", input_max_vec.NumElements()));
       float* input_max_vec_data = (float*)const_cast<void*>(
           static_cast<const void*>(input_max_vec.flat<float>().data()));
 
       const Tensor& input_requested_min = ctx->input(this->kRequestMinIndex);
+      OP_REQUIRES(
+          ctx, input_requested_min.NumElements() == 1,
+          errors::InvalidArgument("requested_output_min must be a scalar"));
       const float input_requested_min_float =
           input_requested_min.flat<float>()(0);
+
       const Tensor& input_requested_max = ctx->input(this->kRequestMaxIndex);
+      OP_REQUIRES(
+          ctx, input_requested_max.NumElements() == 1,
+          errors::InvalidArgument("requested_output_max must be a scalar"));
       const float input_requested_max_float =
           input_requested_max.flat<float>()(0);
 
-      size_t depth = input_min_vec.NumElements();
-      OP_REQUIRES(
-          ctx, input.dims() == 4,
-          errors::InvalidArgument("Current RequantizePerChannel operator"
-                                  " supports 4D tensors only."));
-      OP_REQUIRES(
-          ctx, input_min_vec.dim_size(0) == depth,
-          errors::InvalidArgument("input_min has incorrect size, expected ",
-                                  depth, " was ", input_min_vec.dim_size(0)));
-      OP_REQUIRES(
-          ctx, input_max_vec.dim_size(0) == depth,
-          errors::InvalidArgument("input_max has incorrect size, expected ",
-                                  depth, " was ", input_max_vec.dim_size(0)));
-
-      if (out_type_ == DT_QINT8) DCHECK(input_requested_min_float < 0.0f);
+      if (out_type_ == DT_QINT8) {
+        OP_REQUIRES(ctx, input_requested_min_float < 0.0f,
+                    errors::InvalidArgument(
+                        "If out_type is QINT8, requested_output_min must be "
+                        "negative, got ",
+                        input_requested_min_float));
+      }
 
       const float factor = (out_type_ == DT_QINT8) ? 127.0f : 255.0f;
       const float requested_min_max =
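The change above swaps a debug-only DCHECK for OP_REQUIRES, which runs in every build mode and fails the op with an InvalidArgument status instead of aborting the process. Below is a minimal sketch of that validation pattern, assuming the standard TensorFlow op_kernel headers; the ValidateVectorSizes helper and the input indices are illustrative only and are not part of the actual MklRequantizePerChannelOp change.

// Hypothetical helper illustrating the OP_REQUIRES pattern used in the diff.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/platform/errors.h"

namespace tensorflow {

void ValidateVectorSizes(OpKernelContext* ctx) {
  const Tensor& min_vec = ctx->input(1);  // illustrative input indices
  const Tensor& max_vec = ctx->input(2);
  // Unlike DCHECK, this check also executes in release builds; on failure it
  // records an InvalidArgument status on the context and returns early.
  OP_REQUIRES(ctx, min_vec.NumElements() == max_vec.NumElements(),
              errors::InvalidArgument("input_min and input_max must have the "
                                      "same number of elements, got ",
                                      min_vec.NumElements(), " and ",
                                      max_vec.NumElements()));
}

}  // namespace tensorflow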