Skip to content

Commit 3218043

Browse files
isharktensorflower-gardener
authored andcommitted
Internal change
PiperOrigin-RevId: 411896058 Change-Id: Ia031058247e3cf382957a6662d3f9e1cbb481ca2
1 parent 05e7d51 commit 3218043

File tree

4 files changed

+83
-17
lines changed

4 files changed

+83
-17
lines changed

tensorflow/core/grappler/costs/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,7 @@ tf_cc_test(
355355
"//tensorflow/core:protos_all_cc",
356356
"//tensorflow/core:test",
357357
"//tensorflow/core:test_main",
358+
"//tensorflow/core/platform:status_matchers",
358359
],
359360
)
360361

tensorflow/core/grappler/costs/op_level_cost_estimator.cc

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2153,7 +2153,7 @@ OpInfo::TensorProperties OpLevelCostEstimator::DescribeTensor(
21532153
}
21542154

21552155
/* static */
2156-
OpLevelCostEstimator::ConvolutionDimensions
2156+
StatusOr<OpLevelCostEstimator::ConvolutionDimensions>
21572157
OpLevelCostEstimator::OpDimensionsFromInputs(
21582158
const TensorShapeProto& original_image_shape, const OpInfo& op_info,
21592159
bool* found_unknown_shapes) {
@@ -2190,6 +2190,11 @@ OpLevelCostEstimator::OpDimensionsFromInputs(
21902190
std::vector<int64_t> strides = GetStrides(op_info);
21912191
int64_t sx = strides[x_index];
21922192
int64_t sy = strides[y_index];
2193+
if (sx == 0 || sy == 0) {
2194+
return errors::InvalidArgument(
2195+
"Stride must be > 0 for Height and Width, but got (", sy, ", ", sx,
2196+
")");
2197+
}
21932198
const auto padding = GetPadding(op_info);
21942199

21952200
int64_t ox = GetOutputSize(ix, kx, sx, padding);
@@ -2206,8 +2211,9 @@ Status OpLevelCostEstimator::PredictMaxPool(const OpContext& op_context,
22062211
bool found_unknown_shapes = false;
22072212
const auto& op_info = op_context.op_info;
22082213
// x: op_info.inputs(0)
2209-
ConvolutionDimensions dims = OpDimensionsFromInputs(
2210-
op_info.inputs(0).shape(), op_info, &found_unknown_shapes);
2214+
TF_ASSIGN_OR_RETURN(ConvolutionDimensions dims,
2215+
OpDimensionsFromInputs(op_info.inputs(0).shape(), op_info,
2216+
&found_unknown_shapes));
22112217
// kx * ky - 1 comparisons per output (kx * xy > 1)
22122218
// or 1 copy per output (kx * k1 = 1).
22132219
int per_output_ops = dims.kx * dims.ky == 1 ? 1 : dims.kx * dims.ky - 1;
@@ -2248,8 +2254,9 @@ Status OpLevelCostEstimator::PredictMaxPoolGrad(const OpContext& op_context,
22482254
op_info.ShortDebugString());
22492255
}
22502256

2251-
ConvolutionDimensions dims = OpDimensionsFromInputs(
2252-
op_info.inputs(0).shape(), op_info, &found_unknown_shapes);
2257+
TF_ASSIGN_OR_RETURN(ConvolutionDimensions dims,
2258+
OpDimensionsFromInputs(op_info.inputs(0).shape(), op_info,
2259+
&found_unknown_shapes));
22532260

22542261
int64_t ops = 0;
22552262
if (dims.kx == 1 && dims.ky == 1) {
@@ -2324,8 +2331,9 @@ Status OpLevelCostEstimator::PredictAvgPool(const OpContext& op_context,
23242331
bool found_unknown_shapes = false;
23252332
const auto& op_info = op_context.op_info;
23262333
// x: op_info.inputs(0)
2327-
ConvolutionDimensions dims = OpDimensionsFromInputs(
2328-
op_info.inputs(0).shape(), op_info, &found_unknown_shapes);
2334+
TF_ASSIGN_OR_RETURN(ConvolutionDimensions dims,
2335+
OpDimensionsFromInputs(op_info.inputs(0).shape(), op_info,
2336+
&found_unknown_shapes));
23292337

23302338
// kx * ky - 1 additions and 1 multiplication per output.
23312339
int64_t ops = dims.batch * dims.ox * dims.oy * dims.oz * dims.kx * dims.ky;
@@ -2382,8 +2390,9 @@ Status OpLevelCostEstimator::PredictAvgPoolGrad(const OpContext& op_context,
23822390
found_unknown_shapes = true;
23832391
}
23842392

2385-
ConvolutionDimensions dims =
2386-
OpDimensionsFromInputs(x_shape, op_info, &found_unknown_shapes);
2393+
TF_ASSIGN_OR_RETURN(
2394+
ConvolutionDimensions dims,
2395+
OpDimensionsFromInputs(x_shape, op_info, &found_unknown_shapes));
23872396

23882397
int64_t ops = 0;
23892398
if (dims.kx <= dims.sx && dims.ky <= dims.sy) {
@@ -2409,8 +2418,9 @@ Status OpLevelCostEstimator::PredictFusedBatchNorm(
24092418
// offset: op_info.inputs(2)
24102419
// mean: op_info.inputs(3) --> only for inference
24112420
// variance: op_info.inputs(4) --> only for inference
2412-
ConvolutionDimensions dims = OpDimensionsFromInputs(
2413-
op_info.inputs(0).shape(), op_info, &found_unknown_shapes);
2421+
TF_ASSIGN_OR_RETURN(ConvolutionDimensions dims,
2422+
OpDimensionsFromInputs(op_info.inputs(0).shape(), op_info,
2423+
&found_unknown_shapes));
24142424
const bool is_training = IsTraining(op_info);
24152425

24162426
int64_t ops = 0;
@@ -2459,8 +2469,9 @@ Status OpLevelCostEstimator::PredictFusedBatchNormGrad(
24592469
// scale: op_info.inputs(2)
24602470
// mean: op_info.inputs(3)
24612471
// variance or inverse of variance: op_info.inputs(4)
2462-
ConvolutionDimensions dims = OpDimensionsFromInputs(
2463-
op_info.inputs(1).shape(), op_info, &found_unknown_shapes);
2472+
TF_ASSIGN_OR_RETURN(ConvolutionDimensions dims,
2473+
OpDimensionsFromInputs(op_info.inputs(1).shape(), op_info,
2474+
&found_unknown_shapes));
24642475

24652476
int64_t ops = 0;
24662477
const auto rsqrt_cost = Eigen::internal::functor_traits<

tensorflow/core/grappler/costs/op_level_cost_estimator.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ class OpLevelCostEstimator {
290290
bool* found_unknown_shapes);
291291

292292
// For Pooling, FusedBatchNorm, and their grad ops.
293-
static ConvolutionDimensions OpDimensionsFromInputs(
293+
static StatusOr<ConvolutionDimensions> OpDimensionsFromInputs(
294294
const TensorShapeProto& original_image_shape, const OpInfo& op_info,
295295
bool* found_unknown_shapes);
296296

tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc

Lines changed: 57 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ limitations under the License.
2424
#include "tensorflow/core/framework/tensor_shape.h"
2525
#include "tensorflow/core/framework/tensor_shape.pb.h"
2626
#include "tensorflow/core/framework/types.h"
27+
#include "tensorflow/core/platform/status_matchers.h"
2728
#include "tensorflow/core/platform/test.h"
2829
#include "tensorflow/core/protobuf/device_properties.pb.h"
2930

@@ -558,9 +559,10 @@ class OpLevelCostEstimatorTest : public ::testing::Test {
558559
}
559560

560561
bool found_unknown_shapes;
561-
auto dims = OpLevelCostEstimator::OpDimensionsFromInputs(
562-
op_context.op_info.inputs(0).shape(), op_context.op_info,
563-
&found_unknown_shapes);
562+
TF_ASSERT_OK_AND_ASSIGN(
563+
auto dims, OpLevelCostEstimator::OpDimensionsFromInputs(
564+
op_context.op_info.inputs(0).shape(), op_context.op_info,
565+
&found_unknown_shapes));
564566
Padding padding_enum;
565567
if (padding == "VALID") {
566568
padding_enum = Padding::VALID;
@@ -581,6 +583,38 @@ class OpLevelCostEstimatorTest : public ::testing::Test {
581583
EXPECT_EQ(padding_enum, dims.padding);
582584
}
583585

586+
StatusOr<OpLevelCostEstimator::ConvolutionDimensions>
587+
CallOpDimensionsFromInputs(const int n, const int h, const int w, const int c,
588+
const int kx, const int ky, const int sx,
589+
const int sy, const string& data_format,
590+
const string& padding) {
591+
OpContext op_context;
592+
593+
const std::vector<int> x = {n, h, w, c};
594+
const std::vector<int> ksize = {1, kx, ky, 1};
595+
std::vector<int> strides;
596+
if (data_format == "NHWC") {
597+
strides = {1, sy, sx, 1};
598+
} else {
599+
strides = {1, 1, sy, sx};
600+
}
601+
602+
auto& op_info = op_context.op_info;
603+
SetCpuDevice(&op_info);
604+
op_info.set_op("MaxPool");
605+
606+
DescribeTensor4D(x[0], x[1], x[2], x[3], op_info.add_inputs());
607+
auto* attr = op_info.mutable_attr();
608+
SetAttrValue(data_format, &(*attr)["data_format"]);
609+
SetAttrValue(padding, &(*attr)["padding"]);
610+
SetAttrValue(strides, &(*attr)["strides"]);
611+
SetAttrValue(ksize, &(*attr)["ksize"]);
612+
bool found_unknown_shapes;
613+
return OpLevelCostEstimator::OpDimensionsFromInputs(
614+
op_context.op_info.inputs(0).shape(), op_context.op_info,
615+
&found_unknown_shapes);
616+
}
617+
584618
OpLevelCostEstimator estimator_;
585619
};
586620

@@ -1383,6 +1417,26 @@ TEST_F(OpLevelCostEstimatorTest, OpDimensionsFromInputs) {
13831417
}
13841418
}
13851419

1420+
TEST_F(OpLevelCostEstimatorTest, OpDimensionsFromInputsError) {
1421+
std::vector<string> paddings = {"VALID", "SAME"};
1422+
std::vector<string> formats = {"NHWC", "NCHW"};
1423+
for (const auto& p : paddings) {
1424+
for (const auto& f : formats) {
1425+
// n, h, w, c, kx, ky, sx, sy, data_format, padding.
1426+
ASSERT_THAT(
1427+
CallOpDimensionsFromInputs(10, 14, 14, 3840, 3, 3, 0, 2, f, p),
1428+
testing::StatusIs(
1429+
error::INVALID_ARGUMENT,
1430+
"Stride must be > 0 for Height and Width, but got (2, 0)"));
1431+
ASSERT_THAT(
1432+
CallOpDimensionsFromInputs(10, 14, 14, 3840, 3, 3, 2, 0, f, p),
1433+
testing::StatusIs(
1434+
error::INVALID_ARGUMENT,
1435+
"Stride must be > 0 for Height and Width, but got (0, 2)"));
1436+
}
1437+
}
1438+
}
1439+
13861440
TEST_F(OpLevelCostEstimatorTest, PredictMaxPool) {
13871441
auto predict_max_pool = [this](const int n, const int in, const int c,
13881442
const int k, const int s,

0 commit comments

Comments
 (0)