Skip to content

Commit 21ca7e4

Browse files
Vijay Vasudevantensorflower-gardener
authored andcommitted
TensorFlow: implement ExtractImagePatches shape fn, move Validation
of KnownDim strewn across a few files into shape_inference.h with a simple unittest. Change: 129152593
1 parent 5a828e3 commit 21ca7e4

6 files changed

Lines changed: 148 additions & 50 deletions

File tree

tensorflow/core/framework/common_shape_fns.cc

Lines changed: 24 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -173,17 +173,6 @@ Status BiasAddGradShape(shape_inference::InferenceContext* c) {
173173
return Status::OK();
174174
}
175175

176-
namespace {
177-
Status CheckKnownDim(shape_inference::InferenceContext* c, const Dimension* dim,
178-
const char* name) {
179-
if (!c->ValueKnown(dim)) {
180-
return errors::InvalidArgument("Cannot infer shape because dimension ",
181-
name, " is not known.");
182-
}
183-
return Status::OK();
184-
}
185-
} // namespace
186-
187176
Status Conv2DShape(shape_inference::InferenceContext* c) {
188177
const Shape* input_shape;
189178
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape));
@@ -224,10 +213,10 @@ Status Conv2DShape(shape_inference::InferenceContext* c) {
224213
const Dimension* output_depth_dim = c->Dim(filter_shape, 3);
225214

226215
// At the moment we need to know the values of several fields.
227-
TF_RETURN_IF_ERROR(CheckKnownDim(c, in_rows_dim, "in_rows"));
228-
TF_RETURN_IF_ERROR(CheckKnownDim(c, in_cols_dim, "in_cols"));
229-
TF_RETURN_IF_ERROR(CheckKnownDim(c, filter_rows_dim, "filter_rows"));
230-
TF_RETURN_IF_ERROR(CheckKnownDim(c, filter_cols_dim, "filter_cols"));
216+
TF_RETURN_IF_ERROR(c->ValidateKnownDim(in_rows_dim, "in_rows"));
217+
TF_RETURN_IF_ERROR(c->ValidateKnownDim(in_cols_dim, "in_cols"));
218+
TF_RETURN_IF_ERROR(c->ValidateKnownDim(filter_rows_dim, "filter_rows"));
219+
TF_RETURN_IF_ERROR(c->ValidateKnownDim(filter_cols_dim, "filter_cols"));
231220

232221
auto in_rows = c->Value(in_rows_dim);
233222
auto in_cols = c->Value(in_cols_dim);
@@ -292,12 +281,12 @@ Status Conv3DShape(shape_inference::InferenceContext* c) {
292281
const Dimension* output_depth_dim = c->Dim(filter_shape, 4);
293282

294283
// At the moment we need to know the values of several fields.
295-
TF_RETURN_IF_ERROR(CheckKnownDim(c, in_planes_dim, "in_planes"));
296-
TF_RETURN_IF_ERROR(CheckKnownDim(c, in_rows_dim, "in_rows"));
297-
TF_RETURN_IF_ERROR(CheckKnownDim(c, in_cols_dim, "in_cols"));
298-
TF_RETURN_IF_ERROR(CheckKnownDim(c, filter_planes_dim, "filter_planes"));
299-
TF_RETURN_IF_ERROR(CheckKnownDim(c, filter_rows_dim, "filter_rows"));
300-
TF_RETURN_IF_ERROR(CheckKnownDim(c, filter_cols_dim, "filter_cols"));
284+
TF_RETURN_IF_ERROR(c->ValidateKnownDim(in_planes_dim, "in_planes"));
285+
TF_RETURN_IF_ERROR(c->ValidateKnownDim(in_rows_dim, "in_rows"));
286+
TF_RETURN_IF_ERROR(c->ValidateKnownDim(in_cols_dim, "in_cols"));
287+
TF_RETURN_IF_ERROR(c->ValidateKnownDim(filter_planes_dim, "filter_planes"));
288+
TF_RETURN_IF_ERROR(c->ValidateKnownDim(filter_rows_dim, "filter_rows"));
289+
TF_RETURN_IF_ERROR(c->ValidateKnownDim(filter_cols_dim, "filter_cols"));
301290

302291
auto in_planes = c->Value(in_planes_dim);
303292
auto in_rows = c->Value(in_rows_dim);
@@ -357,12 +346,12 @@ Status DepthwiseConv2DNativeShape(shape_inference::InferenceContext* c) {
357346
const Dimension* depth_multiplier = c->Dim(filter_shape, 3);
358347

359348
// At the moment we need to know the values of several fields.
360-
TF_RETURN_IF_ERROR(CheckKnownDim(c, in_rows_dim, "in_rows"));
361-
TF_RETURN_IF_ERROR(CheckKnownDim(c, in_cols_dim, "in_cols"));
362-
TF_RETURN_IF_ERROR(CheckKnownDim(c, filter_rows_dim, "filter_rows"));
363-
TF_RETURN_IF_ERROR(CheckKnownDim(c, filter_cols_dim, "filter_cols"));
364-
TF_RETURN_IF_ERROR(CheckKnownDim(c, input_depth, "depth"));
365-
TF_RETURN_IF_ERROR(CheckKnownDim(c, depth_multiplier, "depth_multiplier"));
349+
TF_RETURN_IF_ERROR(c->ValidateKnownDim(in_rows_dim, "in_rows"));
350+
TF_RETURN_IF_ERROR(c->ValidateKnownDim(in_cols_dim, "in_cols"));
351+
TF_RETURN_IF_ERROR(c->ValidateKnownDim(filter_rows_dim, "filter_rows"));
352+
TF_RETURN_IF_ERROR(c->ValidateKnownDim(filter_cols_dim, "filter_cols"));
353+
TF_RETURN_IF_ERROR(c->ValidateKnownDim(input_depth, "depth"));
354+
TF_RETURN_IF_ERROR(c->ValidateKnownDim(depth_multiplier, "depth_multiplier"));
366355

367356
// Check that the input depths are compatible.
368357
TF_RETURN_IF_ERROR(
@@ -449,8 +438,8 @@ Status AvgPoolShape(shape_inference::InferenceContext* c) {
449438
const Dimension* output_depth_dim = c->Dim(input_shape, 3);
450439

451440
// At the moment we need to know the values of several fields.
452-
TF_RETURN_IF_ERROR(CheckKnownDim(c, in_rows_dim, "in_rows"));
453-
TF_RETURN_IF_ERROR(CheckKnownDim(c, in_cols_dim, "in_cols"));
441+
TF_RETURN_IF_ERROR(c->ValidateKnownDim(in_rows_dim, "in_rows"));
442+
TF_RETURN_IF_ERROR(c->ValidateKnownDim(in_cols_dim, "in_cols"));
454443

455444
Padding padding;
456445
TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
@@ -536,9 +525,9 @@ Status MaxPoolShape(shape_inference::InferenceContext* c) {
536525
const Dimension* in_depth_dim = c->Dim(input_shape, 3);
537526

538527
// At the moment we need to know the values of several fields.
539-
TF_RETURN_IF_ERROR(CheckKnownDim(c, in_rows_dim, "in_rows"));
540-
TF_RETURN_IF_ERROR(CheckKnownDim(c, in_cols_dim, "in_cols"));
541-
TF_RETURN_IF_ERROR(CheckKnownDim(c, in_depth_dim, "in_depth"));
528+
TF_RETURN_IF_ERROR(c->ValidateKnownDim(in_rows_dim, "in_rows"));
529+
TF_RETURN_IF_ERROR(c->ValidateKnownDim(in_cols_dim, "in_cols"));
530+
TF_RETURN_IF_ERROR(c->ValidateKnownDim(in_depth_dim, "in_depth"));
542531

543532
Padding padding;
544533
TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
@@ -614,9 +603,9 @@ Status Pool3DShape(shape_inference::InferenceContext* c) {
614603
const Dimension* output_depth_dim = c->Dim(input_shape, 4);
615604

616605
// At the moment we need to know the values of several fields.
617-
TF_RETURN_IF_ERROR(CheckKnownDim(c, in_planes_dim, "in_planes"));
618-
TF_RETURN_IF_ERROR(CheckKnownDim(c, in_rows_dim, "in_rows"));
619-
TF_RETURN_IF_ERROR(CheckKnownDim(c, in_cols_dim, "in_cols"));
606+
TF_RETURN_IF_ERROR(c->ValidateKnownDim(in_planes_dim, "in_planes"));
607+
TF_RETURN_IF_ERROR(c->ValidateKnownDim(in_rows_dim, "in_rows"));
608+
TF_RETURN_IF_ERROR(c->ValidateKnownDim(in_cols_dim, "in_cols"));
620609

621610
Padding padding;
622611
TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));

tensorflow/core/framework/shape_inference.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,16 @@ class InferenceContext {
261261

262262
Status construction_status() const { return construction_status_; }
263263

264+
// Validates that 'dim' has a known value, and prints an error
265+
// message containing 'name' if validation fails.
266+
Status ValidateKnownDim(const Dimension* dim, const char* name) {
267+
if (!ValueKnown(dim)) {
268+
return errors::InvalidArgument("Cannot infer shape because dimension ",
269+
name, " is not known.");
270+
}
271+
return Status::OK();
272+
}
273+
264274
private:
265275
const Dimension* GetDimension(const DimensionOrConstant& d);
266276

tensorflow/core/framework/shape_inference_test.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -887,5 +887,13 @@ TEST(ShapeInferenceTest, FullyDefined) {
887887
EXPECT_TRUE(c.FullyDefined(c.Scalar()));
888888
}
889889

890+
TEST(ShapeInferenceTest, ValidateKnownDim) {
891+
NodeDef def;
892+
InferenceContext c(&def, MakeOpDef(0, 2), {}, {});
893+
894+
EXPECT_FALSE(c.ValidateKnownDim(c.UnknownDim(), "unknown").ok());
895+
EXPECT_TRUE(c.ValidateKnownDim(c.Dim(c.Matrix(1, 2), 0), "known").ok());
896+
}
897+
890898
} // namespace shape_inference
891899
} // namespace tensorflow

tensorflow/core/ops/array_ops.cc

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2771,6 +2771,76 @@ REGISTER_OP("ExtractImagePatches")
27712771
.Attr("rates: list(int) >= 4")
27722772
.Attr("T: realnumbertype")
27732773
.Attr(GetPaddingAttrString())
2774+
.SetShapeFn([](InferenceContext* c) {
2775+
const Shape* input_shape;
2776+
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape));
2777+
2778+
std::vector<int32> ksizes;
2779+
TF_RETURN_IF_ERROR(c->GetAttr("ksizes", &ksizes));
2780+
if (ksizes.size() != 4) {
2781+
return errors::InvalidArgument(
2782+
"ExtractImagePatches requires the ksizes attribute to contain 4 "
2783+
"values, but got: ",
2784+
ksizes.size());
2785+
}
2786+
2787+
std::vector<int32> strides;
2788+
TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
2789+
if (strides.size() != 4) {
2790+
return errors::InvalidArgument(
2791+
"ExtractImagePatches requires the stride attribute to contain 4 "
2792+
"values, but got: ",
2793+
strides.size());
2794+
}
2795+
2796+
std::vector<int32> rates;
2797+
TF_RETURN_IF_ERROR(c->GetAttr("rates", &rates));
2798+
if (rates.size() != 4) {
2799+
return errors::InvalidArgument(
2800+
"ExtractImagePatches requires the rates attribute to contain 4 "
2801+
"values, but got: ",
2802+
rates.size());
2803+
}
2804+
2805+
int32 ksize_rows = ksizes[1];
2806+
int32 ksize_cols = ksizes[2];
2807+
2808+
int32 stride_rows = strides[1];
2809+
int32 stride_cols = strides[2];
2810+
2811+
int32 rate_rows = rates[1];
2812+
int32 rate_cols = rates[2];
2813+
2814+
int32 ksize_rows_eff = ksize_rows + (ksize_rows - 1) * (rate_rows - 1);
2815+
int32 ksize_cols_eff = ksize_cols + (ksize_cols - 1) * (rate_cols - 1);
2816+
2817+
const Dimension* batch_size_dim = c->Dim(input_shape, 0);
2818+
const Dimension* in_rows_dim = c->Dim(input_shape, 1);
2819+
const Dimension* in_cols_dim = c->Dim(input_shape, 2);
2820+
const Dimension* output_depth_dim = c->Dim(input_shape, 3);
2821+
2822+
// At the moment we need to know the values of several fields.
2823+
TF_RETURN_IF_ERROR(c->ValidateKnownDim(in_rows_dim, "in_rows"));
2824+
TF_RETURN_IF_ERROR(c->ValidateKnownDim(in_cols_dim, "in_cols"));
2825+
auto in_rows = c->Value(in_rows_dim);
2826+
auto in_cols = c->Value(in_cols_dim);
2827+
2828+
Padding padding;
2829+
TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
2830+
2831+
int64 output_rows, output_cols;
2832+
int64 padding_before, padding_after;
2833+
TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose(
2834+
in_rows, ksize_rows_eff, stride_rows, padding, &output_rows,
2835+
&padding_before, &padding_after));
2836+
TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose(
2837+
in_cols, ksize_cols_eff, stride_cols, padding, &output_cols,
2838+
&padding_before, &padding_after));
2839+
const Shape* output_shape = c->MakeShape(
2840+
{batch_size_dim, output_rows, output_cols, output_depth_dim});
2841+
c->set_output(0, output_shape);
2842+
return Status::OK();
2843+
})
27742844
.Doc(R"doc(
27752845
Extract `patches` from `images` and put them in the "depth" output dimension.
27762846

tensorflow/core/ops/array_ops_test.cc

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -877,4 +877,36 @@ TEST(ArrayOpsTest, OneHot_ShapeFn) {
877877
INFER_OK(op, "[1,3,4];[];?;?", "[d0_0,d0_1,d0_2,2]");
878878
}
879879

880+
TEST(NNOpsTest, ExtractImagePatchesShapeTest) {
881+
ShapeInferenceTestOp op("ExtractImagePatches");
882+
auto set_op = [&op](const std::vector<int32>& ksizes,
883+
const std::vector<int32>& strides,
884+
const std::vector<int32>& rates, const string& padding) {
885+
TF_CHECK_OK(NodeDefBuilder("test", "ExtractImagePatches")
886+
.Input("input", 0, DT_FLOAT)
887+
.Attr("ksizes", ksizes)
888+
.Attr("strides", strides)
889+
.Attr("rates", rates)
890+
.Attr("padding", padding)
891+
.Finalize(&op.node_def));
892+
};
893+
894+
// Just tests that the ksize calculation with rates works. Most of
895+
// the other code is boilerplate that is tested by a variety of
896+
// other ops.
897+
//
898+
// ksizes is 2x2. rate rows and cols is 2, so ksize_rows and
899+
// cols are changed to be 2 + (2 - 1) = 3. 7x7 input with 3x3
900+
// filter and 1x1 stride gives a 5x5 output.
901+
set_op({1, 2, 2, 1}, {1, 1, 1, 1}, {1, 2, 2, 1}, "VALID");
902+
INFER_OK(op, "[1,7,7,2]", "[d0_0,5,5,d0_3]");
903+
904+
// Bad ksize rank
905+
set_op({1, 2, 2, 1, 1}, {1, 1, 1, 1}, {1, 2, 2, 1}, "VALID");
906+
INFER_ERROR(
907+
"ExtractImagePatches requires the ksizes attribute to contain 4 values, "
908+
"but got: 5",
909+
op, "[1,7,7,2]");
910+
}
911+
880912
} // end namespace tensorflow

tensorflow/core/ops/nn_ops.cc

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -968,17 +968,6 @@ output: Gradients w.r.t. the input of `max_pool`.
968968

969969
// --------------------------------------------------------------------------
970970

971-
namespace {
972-
Status CheckKnownDim(shape_inference::InferenceContext* c, const Dimension* dim,
973-
const char* name) {
974-
if (!c->ValueKnown(dim)) {
975-
return errors::InvalidArgument("Cannot infer shape because dimension ",
976-
name, " is not known.");
977-
}
978-
return Status::OK();
979-
}
980-
} // namespace
981-
982971
REGISTER_OP("Dilation2D")
983972
.Input("input: T")
984973
.Input("filter: T")
@@ -1029,10 +1018,10 @@ REGISTER_OP("Dilation2D")
10291018
c->Merge(c->Dim(input_shape, 3), output_depth_dim, &unused));
10301019

10311020
// At the moment we need to know the values of several fields.
1032-
TF_RETURN_IF_ERROR(CheckKnownDim(c, in_rows_dim, "in_rows"));
1033-
TF_RETURN_IF_ERROR(CheckKnownDim(c, in_cols_dim, "in_cols"));
1034-
TF_RETURN_IF_ERROR(CheckKnownDim(c, filter_rows_dim, "filter_rows"));
1035-
TF_RETURN_IF_ERROR(CheckKnownDim(c, filter_cols_dim, "filter_cols"));
1021+
TF_RETURN_IF_ERROR(c->ValidateKnownDim(in_rows_dim, "in_rows"));
1022+
TF_RETURN_IF_ERROR(c->ValidateKnownDim(in_cols_dim, "in_cols"));
1023+
TF_RETURN_IF_ERROR(c->ValidateKnownDim(filter_rows_dim, "filter_rows"));
1024+
TF_RETURN_IF_ERROR(c->ValidateKnownDim(filter_cols_dim, "filter_cols"));
10361025

10371026
auto in_rows = c->Value(in_rows_dim);
10381027
auto in_cols = c->Value(in_cols_dim);

0 commit comments

Comments
 (0)