@@ -173,17 +173,6 @@ Status BiasAddGradShape(shape_inference::InferenceContext* c) {
173173 return Status::OK ();
174174}
175175
176- namespace {
177- Status CheckKnownDim (shape_inference::InferenceContext* c, const Dimension* dim,
178- const char * name) {
179- if (!c->ValueKnown (dim)) {
180- return errors::InvalidArgument (" Cannot infer shape because dimension " ,
181- name, " is not known." );
182- }
183- return Status::OK ();
184- }
185- } // namespace
186-
187176Status Conv2DShape (shape_inference::InferenceContext* c) {
188177 const Shape* input_shape;
189178 TF_RETURN_IF_ERROR (c->WithRank (c->input (0 ), 4 , &input_shape));
@@ -224,10 +213,10 @@ Status Conv2DShape(shape_inference::InferenceContext* c) {
224213 const Dimension* output_depth_dim = c->Dim (filter_shape, 3 );
225214
226215 // At the moment we need to know the values of several fields.
227- TF_RETURN_IF_ERROR (CheckKnownDim (c, in_rows_dim, " in_rows" ));
228- TF_RETURN_IF_ERROR (CheckKnownDim (c, in_cols_dim, " in_cols" ));
229- TF_RETURN_IF_ERROR (CheckKnownDim (c, filter_rows_dim, " filter_rows" ));
230- TF_RETURN_IF_ERROR (CheckKnownDim (c, filter_cols_dim, " filter_cols" ));
216+ TF_RETURN_IF_ERROR (c-> ValidateKnownDim ( in_rows_dim, " in_rows" ));
217+ TF_RETURN_IF_ERROR (c-> ValidateKnownDim ( in_cols_dim, " in_cols" ));
218+ TF_RETURN_IF_ERROR (c-> ValidateKnownDim ( filter_rows_dim, " filter_rows" ));
219+ TF_RETURN_IF_ERROR (c-> ValidateKnownDim ( filter_cols_dim, " filter_cols" ));
231220
232221 auto in_rows = c->Value (in_rows_dim);
233222 auto in_cols = c->Value (in_cols_dim);
@@ -292,12 +281,12 @@ Status Conv3DShape(shape_inference::InferenceContext* c) {
292281 const Dimension* output_depth_dim = c->Dim (filter_shape, 4 );
293282
294283 // At the moment we need to know the values of several fields.
295- TF_RETURN_IF_ERROR (CheckKnownDim (c, in_planes_dim, " in_planes" ));
296- TF_RETURN_IF_ERROR (CheckKnownDim (c, in_rows_dim, " in_rows" ));
297- TF_RETURN_IF_ERROR (CheckKnownDim (c, in_cols_dim, " in_cols" ));
298- TF_RETURN_IF_ERROR (CheckKnownDim (c, filter_planes_dim, " filter_planes" ));
299- TF_RETURN_IF_ERROR (CheckKnownDim (c, filter_rows_dim, " filter_rows" ));
300- TF_RETURN_IF_ERROR (CheckKnownDim (c, filter_cols_dim, " filter_cols" ));
284+ TF_RETURN_IF_ERROR (c-> ValidateKnownDim ( in_planes_dim, " in_planes" ));
285+ TF_RETURN_IF_ERROR (c-> ValidateKnownDim ( in_rows_dim, " in_rows" ));
286+ TF_RETURN_IF_ERROR (c-> ValidateKnownDim ( in_cols_dim, " in_cols" ));
287+ TF_RETURN_IF_ERROR (c-> ValidateKnownDim ( filter_planes_dim, " filter_planes" ));
288+ TF_RETURN_IF_ERROR (c-> ValidateKnownDim ( filter_rows_dim, " filter_rows" ));
289+ TF_RETURN_IF_ERROR (c-> ValidateKnownDim ( filter_cols_dim, " filter_cols" ));
301290
302291 auto in_planes = c->Value (in_planes_dim);
303292 auto in_rows = c->Value (in_rows_dim);
@@ -357,12 +346,12 @@ Status DepthwiseConv2DNativeShape(shape_inference::InferenceContext* c) {
357346 const Dimension* depth_multiplier = c->Dim (filter_shape, 3 );
358347
359348 // At the moment we need to know the values of several fields.
360- TF_RETURN_IF_ERROR (CheckKnownDim (c, in_rows_dim, " in_rows" ));
361- TF_RETURN_IF_ERROR (CheckKnownDim (c, in_cols_dim, " in_cols" ));
362- TF_RETURN_IF_ERROR (CheckKnownDim (c, filter_rows_dim, " filter_rows" ));
363- TF_RETURN_IF_ERROR (CheckKnownDim (c, filter_cols_dim, " filter_cols" ));
364- TF_RETURN_IF_ERROR (CheckKnownDim (c, input_depth, " depth" ));
365- TF_RETURN_IF_ERROR (CheckKnownDim (c, depth_multiplier, " depth_multiplier" ));
349+ TF_RETURN_IF_ERROR (c-> ValidateKnownDim ( in_rows_dim, " in_rows" ));
350+ TF_RETURN_IF_ERROR (c-> ValidateKnownDim ( in_cols_dim, " in_cols" ));
351+ TF_RETURN_IF_ERROR (c-> ValidateKnownDim ( filter_rows_dim, " filter_rows" ));
352+ TF_RETURN_IF_ERROR (c-> ValidateKnownDim ( filter_cols_dim, " filter_cols" ));
353+ TF_RETURN_IF_ERROR (c-> ValidateKnownDim ( input_depth, " depth" ));
354+ TF_RETURN_IF_ERROR (c-> ValidateKnownDim ( depth_multiplier, " depth_multiplier" ));
366355
367356 // Check that the input depths are compatible.
368357 TF_RETURN_IF_ERROR (
@@ -449,8 +438,8 @@ Status AvgPoolShape(shape_inference::InferenceContext* c) {
449438 const Dimension* output_depth_dim = c->Dim (input_shape, 3 );
450439
451440 // At the moment we need to know the values of several fields.
452- TF_RETURN_IF_ERROR (CheckKnownDim (c, in_rows_dim, " in_rows" ));
453- TF_RETURN_IF_ERROR (CheckKnownDim (c, in_cols_dim, " in_cols" ));
441+ TF_RETURN_IF_ERROR (c-> ValidateKnownDim ( in_rows_dim, " in_rows" ));
442+ TF_RETURN_IF_ERROR (c-> ValidateKnownDim ( in_cols_dim, " in_cols" ));
454443
455444 Padding padding;
456445 TF_RETURN_IF_ERROR (c->GetAttr (" padding" , &padding));
@@ -536,9 +525,9 @@ Status MaxPoolShape(shape_inference::InferenceContext* c) {
536525 const Dimension* in_depth_dim = c->Dim (input_shape, 3 );
537526
538527 // At the moment we need to know the values of several fields.
539- TF_RETURN_IF_ERROR (CheckKnownDim (c, in_rows_dim, " in_rows" ));
540- TF_RETURN_IF_ERROR (CheckKnownDim (c, in_cols_dim, " in_cols" ));
541- TF_RETURN_IF_ERROR (CheckKnownDim (c, in_depth_dim, " in_depth" ));
528+ TF_RETURN_IF_ERROR (c-> ValidateKnownDim ( in_rows_dim, " in_rows" ));
529+ TF_RETURN_IF_ERROR (c-> ValidateKnownDim ( in_cols_dim, " in_cols" ));
530+ TF_RETURN_IF_ERROR (c-> ValidateKnownDim ( in_depth_dim, " in_depth" ));
542531
543532 Padding padding;
544533 TF_RETURN_IF_ERROR (c->GetAttr (" padding" , &padding));
@@ -614,9 +603,9 @@ Status Pool3DShape(shape_inference::InferenceContext* c) {
614603 const Dimension* output_depth_dim = c->Dim (input_shape, 4 );
615604
616605 // At the moment we need to know the values of several fields.
617- TF_RETURN_IF_ERROR (CheckKnownDim (c, in_planes_dim, " in_planes" ));
618- TF_RETURN_IF_ERROR (CheckKnownDim (c, in_rows_dim, " in_rows" ));
619- TF_RETURN_IF_ERROR (CheckKnownDim (c, in_cols_dim, " in_cols" ));
606+ TF_RETURN_IF_ERROR (c-> ValidateKnownDim ( in_planes_dim, " in_planes" ));
607+ TF_RETURN_IF_ERROR (c-> ValidateKnownDim ( in_rows_dim, " in_rows" ));
608+ TF_RETURN_IF_ERROR (c-> ValidateKnownDim ( in_cols_dim, " in_cols" ));
620609
621610 Padding padding;
622611 TF_RETURN_IF_ERROR (c->GetAttr (" padding" , &padding));
0 commit comments