Skip to content

Commit f00aefd

Browse files
Vijay Vasudevantensorflower-gardener
authored andcommitted
TensorFlow: implement C++ shape functions for Batch/DepthToSpace and
SpaceToBatch/Depth. Change array_ops_test to use TF_ASSERT_OK instead of TF_CHECK_OK. Change: 129667873
1 parent 88686b3 commit f00aefd

2 files changed

Lines changed: 390 additions & 99 deletions

File tree

tensorflow/core/ops/array_ops.cc

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2393,6 +2393,66 @@ REGISTER_OP("SpaceToBatch")
23932393
.Output("output: T")
23942394
.Attr("T: type")
23952395
.Attr("block_size: int >= 2")
2396+
.SetShapeFn([](InferenceContext* c) {
2397+
const Shape* input;
2398+
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
2399+
2400+
const Shape* paddings;
2401+
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &paddings));
2402+
2403+
const Dimension* pad0_dim = c->Dim(paddings, 0);
2404+
const Dimension* pad1_dim = c->Dim(paddings, 1);
2405+
2406+
if (!c->ValueKnown(pad0_dim) || !c->ValueKnown(pad1_dim)) {
2407+
return shape_inference::UnknownShape(c);
2408+
}
2409+
2410+
int64 pad0 = c->Value(pad0_dim);
2411+
int64 pad1 = c->Value(pad1_dim);
2412+
if (pad0 != 2 || pad1 != 2) {
2413+
return errors::InvalidArgument(
2414+
"SpaceToBatch requires paddings with shape [2,2].");
2415+
}
2416+
2417+
int32 block_size;
2418+
TF_RETURN_IF_ERROR(c->GetAttr("block_size", &block_size));
2419+
2420+
const Dimension* output_height;
2421+
const Dimension* output_width;
2422+
2423+
const Tensor* paddings_t = c->input_tensor(1);
2424+
if (paddings_t == nullptr) {
2425+
output_height = c->UnknownDim();
2426+
output_width = c->UnknownDim();
2427+
} else {
2428+
auto pad_matrix = paddings_t->matrix<int32>();
2429+
const int32 pad_top = pad_matrix(0, 0);
2430+
const int32 pad_bottom = pad_matrix(0, 1);
2431+
const int32 pad_left = pad_matrix(1, 0);
2432+
const int32 pad_right = pad_matrix(1, 1);
2433+
2434+
if (pad_top < 0 || pad_bottom < 0 || pad_left < 0 || pad_right < 0) {
2435+
return errors::InvalidArgument("Paddings cannot be negative.");
2436+
}
2437+
2438+
TF_RETURN_IF_ERROR(
2439+
c->Add(c->Dim(input, 1), pad_top + pad_bottom, &output_height));
2440+
TF_RETURN_IF_ERROR(
2441+
c->Add(c->Dim(input, 2), pad_left + pad_right, &output_width));
2442+
}
2443+
2444+
const Dimension* batch;
2445+
TF_RETURN_IF_ERROR(
2446+
c->Multiply(c->Dim(input, 0), block_size * block_size, &batch));
2447+
2448+
// Will return an error if block_size does not evenly divide.
2449+
TF_RETURN_IF_ERROR(c->Divide(output_height, block_size, &output_height));
2450+
TF_RETURN_IF_ERROR(c->Divide(output_width, block_size, &output_width));
2451+
2452+
c->set_output(0, c->MakeShape({batch, output_height, output_width,
2453+
c->Dim(input, 3)}));
2454+
return Status::OK();
2455+
})
23962456
.Doc(R"doc(
23972457
SpaceToBatch for 4-D tensors of type T.
23982458
@@ -2498,6 +2558,69 @@ REGISTER_OP("BatchToSpace")
24982558
.Output("output: T")
24992559
.Attr("T: type")
25002560
.Attr("block_size: int >= 2")
2561+
.SetShapeFn([](InferenceContext* c) {
2562+
const Shape* input;
2563+
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
2564+
2565+
const Shape* crops;
2566+
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &crops));
2567+
2568+
const Dimension* crops0_dim = c->Dim(crops, 0);
2569+
const Dimension* crops1_dim = c->Dim(crops, 1);
2570+
2571+
if (!c->ValueKnown(crops0_dim) || !c->ValueKnown(crops1_dim)) {
2572+
return shape_inference::UnknownShape(c);
2573+
}
2574+
2575+
int64 crops0 = c->Value(crops0_dim);
2576+
int64 crops1 = c->Value(crops1_dim);
2577+
if (crops0 != 2 || crops1 != 2) {
2578+
return errors::InvalidArgument(
2579+
"BatchToSpace requires crops with shape [2,2].");
2580+
}
2581+
2582+
int32 block_size;
2583+
TF_RETURN_IF_ERROR(c->GetAttr("block_size", &block_size));
2584+
2585+
const Dimension* batch;
2586+
// Will return an error if does not evenly divide
2587+
TF_RETURN_IF_ERROR(
2588+
c->Divide(c->Dim(input, 0), block_size * block_size, &batch));
2589+
2590+
const Dimension* output_height;
2591+
const Dimension* output_width;
2592+
2593+
const Tensor* crops_t = c->input_tensor(1);
2594+
if (crops_t == nullptr) {
2595+
output_height = c->UnknownDim();
2596+
output_width = c->UnknownDim();
2597+
} else {
2598+
auto crops_matrix = crops_t->matrix<int32>();
2599+
const int32 crops_top = crops_matrix(0, 0);
2600+
const int32 crops_bottom = crops_matrix(0, 1);
2601+
const int32 crops_left = crops_matrix(1, 0);
2602+
const int32 crops_right = crops_matrix(1, 1);
2603+
2604+
if (crops_top < 0 || crops_bottom < 0 || crops_left < 0 ||
2605+
crops_right < 0) {
2606+
return errors::InvalidArgument("Croppings cannot be negative.");
2607+
}
2608+
2609+
TF_RETURN_IF_ERROR(
2610+
c->Multiply(c->Dim(input, 1), block_size, &output_height));
2611+
TF_RETURN_IF_ERROR(c->Subtract(
2612+
output_height, (crops_top + crops_bottom), &output_height));
2613+
2614+
TF_RETURN_IF_ERROR(
2615+
c->Multiply(c->Dim(input, 2), block_size, &output_width));
2616+
TF_RETURN_IF_ERROR(c->Subtract(output_width, (crops_left + crops_right),
2617+
&output_width));
2618+
}
2619+
2620+
c->set_output(0, c->MakeShape({batch, output_height, output_width,
2621+
c->Dim(input, 3)}));
2622+
return Status::OK();
2623+
})
25012624
.Doc(R"doc(
25022625
BatchToSpace for 4-D tensors of type T.
25032626
@@ -2593,6 +2716,29 @@ REGISTER_OP("SpaceToDepth")
25932716
.Output("output: T")
25942717
.Attr("T: type")
25952718
.Attr("block_size: int >= 2")
2719+
.SetShapeFn([](InferenceContext* c) {
2720+
const Shape* input;
2721+
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
2722+
2723+
int32 block_size;
2724+
TF_RETURN_IF_ERROR(c->GetAttr("block_size", &block_size));
2725+
2726+
const Dimension* output_height;
2727+
const Dimension* output_width;
2728+
const Dimension* output_depth;
2729+
// Will return an error if does not evenly divide
2730+
TF_RETURN_IF_ERROR(
2731+
c->Divide(c->Dim(input, 1), block_size, &output_height));
2732+
TF_RETURN_IF_ERROR(
2733+
c->Divide(c->Dim(input, 2), block_size, &output_width));
2734+
2735+
TF_RETURN_IF_ERROR(c->Multiply(c->Dim(input, 3), block_size * block_size,
2736+
&output_depth));
2737+
2738+
c->set_output(0, c->MakeShape({c->Dim(input, 0), output_height,
2739+
output_width, output_depth}));
2740+
return Status::OK();
2741+
})
25962742
.Doc(R"doc(
25972743
SpaceToDepth for tensors of type T.
25982744
@@ -2677,6 +2823,27 @@ REGISTER_OP("DepthToSpace")
26772823
.Output("output: T")
26782824
.Attr("T: type")
26792825
.Attr("block_size: int >= 2")
2826+
.SetShapeFn([](InferenceContext* c) {
2827+
const Shape* input;
2828+
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
2829+
2830+
int32 block_size;
2831+
TF_RETURN_IF_ERROR(c->GetAttr("block_size", &block_size));
2832+
2833+
const Dimension* output_height;
2834+
const Dimension* output_width;
2835+
const Dimension* output_depth;
2836+
TF_RETURN_IF_ERROR(
2837+
c->Multiply(c->Dim(input, 1), block_size, &output_height));
2838+
TF_RETURN_IF_ERROR(
2839+
c->Multiply(c->Dim(input, 2), block_size, &output_width));
2840+
TF_RETURN_IF_ERROR(
2841+
c->Divide(c->Dim(input, 3), block_size * block_size, &output_depth));
2842+
2843+
c->set_output(0, c->MakeShape({c->Dim(input, 0), output_height,
2844+
output_width, output_depth}));
2845+
return Status::OK();
2846+
})
26802847
.Doc(R"doc(
26812848
DepthToSpace for tensors of type T.
26822849

0 commit comments

Comments
 (0)