@@ -2393,6 +2393,66 @@ REGISTER_OP("SpaceToBatch")
23932393 .Output(" output: T" )
23942394 .Attr(" T: type" )
23952395 .Attr(" block_size: int >= 2" )
2396+ .SetShapeFn([](InferenceContext* c) {
2397+ const Shape* input;
2398+ TF_RETURN_IF_ERROR (c->WithRank (c->input (0 ), 4 , &input));
2399+
2400+ const Shape* paddings;
2401+ TF_RETURN_IF_ERROR (c->WithRank (c->input (1 ), 2 , &paddings));
2402+
2403+ const Dimension* pad0_dim = c->Dim (paddings, 0 );
2404+ const Dimension* pad1_dim = c->Dim (paddings, 1 );
2405+
2406+ if (!c->ValueKnown (pad0_dim) || !c->ValueKnown (pad1_dim)) {
2407+ return shape_inference::UnknownShape (c);
2408+ }
2409+
2410+ int64 pad0 = c->Value (pad0_dim);
2411+ int64 pad1 = c->Value (pad1_dim);
2412+ if (pad0 != 2 || pad1 != 2 ) {
2413+ return errors::InvalidArgument (
2414+ " SpaceToBatch requires paddings with shape [2,2]." );
2415+ }
2416+
2417+ int32 block_size;
2418+ TF_RETURN_IF_ERROR (c->GetAttr (" block_size" , &block_size));
2419+
2420+ const Dimension* output_height;
2421+ const Dimension* output_width;
2422+
2423+ const Tensor* paddings_t = c->input_tensor (1 );
2424+ if (paddings_t == nullptr ) {
2425+ output_height = c->UnknownDim ();
2426+ output_width = c->UnknownDim ();
2427+ } else {
2428+ auto pad_matrix = paddings_t ->matrix <int32>();
2429+ const int32 pad_top = pad_matrix (0 , 0 );
2430+ const int32 pad_bottom = pad_matrix (0 , 1 );
2431+ const int32 pad_left = pad_matrix (1 , 0 );
2432+ const int32 pad_right = pad_matrix (1 , 1 );
2433+
2434+ if (pad_top < 0 || pad_bottom < 0 || pad_left < 0 || pad_right < 0 ) {
2435+ return errors::InvalidArgument (" Paddings cannot be negative." );
2436+ }
2437+
2438+ TF_RETURN_IF_ERROR (
2439+ c->Add (c->Dim (input, 1 ), pad_top + pad_bottom, &output_height));
2440+ TF_RETURN_IF_ERROR (
2441+ c->Add (c->Dim (input, 2 ), pad_left + pad_right, &output_width));
2442+ }
2443+
2444+ const Dimension* batch;
2445+ TF_RETURN_IF_ERROR (
2446+ c->Multiply (c->Dim (input, 0 ), block_size * block_size, &batch));
2447+
2448+ // Will return an error if block_size does not evenly divide.
2449+ TF_RETURN_IF_ERROR (c->Divide (output_height, block_size, &output_height));
2450+ TF_RETURN_IF_ERROR (c->Divide (output_width, block_size, &output_width));
2451+
2452+ c->set_output (0 , c->MakeShape ({batch, output_height, output_width,
2453+ c->Dim (input, 3 )}));
2454+ return Status::OK ();
2455+ })
23962456 .Doc(R"doc(
23972457SpaceToBatch for 4-D tensors of type T.
23982458
@@ -2498,6 +2558,69 @@ REGISTER_OP("BatchToSpace")
24982558 .Output(" output: T" )
24992559 .Attr(" T: type" )
25002560 .Attr(" block_size: int >= 2" )
2561+ .SetShapeFn([](InferenceContext* c) {
2562+ const Shape* input;
2563+ TF_RETURN_IF_ERROR (c->WithRank (c->input (0 ), 4 , &input));
2564+
2565+ const Shape* crops;
2566+ TF_RETURN_IF_ERROR (c->WithRank (c->input (1 ), 2 , &crops));
2567+
2568+ const Dimension* crops0_dim = c->Dim (crops, 0 );
2569+ const Dimension* crops1_dim = c->Dim (crops, 1 );
2570+
2571+ if (!c->ValueKnown (crops0_dim) || !c->ValueKnown (crops1_dim)) {
2572+ return shape_inference::UnknownShape (c);
2573+ }
2574+
2575+ int64 crops0 = c->Value (crops0_dim);
2576+ int64 crops1 = c->Value (crops1_dim);
2577+ if (crops0 != 2 || crops1 != 2 ) {
2578+ return errors::InvalidArgument (
2579+ " BatchToSpace requires crops with shape [2,2]." );
2580+ }
2581+
2582+ int32 block_size;
2583+ TF_RETURN_IF_ERROR (c->GetAttr (" block_size" , &block_size));
2584+
2585+ const Dimension* batch;
2586+ // Will return an error if does not evenly divide
2587+ TF_RETURN_IF_ERROR (
2588+ c->Divide (c->Dim (input, 0 ), block_size * block_size, &batch));
2589+
2590+ const Dimension* output_height;
2591+ const Dimension* output_width;
2592+
2593+ const Tensor* crops_t = c->input_tensor (1 );
2594+ if (crops_t == nullptr ) {
2595+ output_height = c->UnknownDim ();
2596+ output_width = c->UnknownDim ();
2597+ } else {
2598+ auto crops_matrix = crops_t ->matrix <int32>();
2599+ const int32 crops_top = crops_matrix (0 , 0 );
2600+ const int32 crops_bottom = crops_matrix (0 , 1 );
2601+ const int32 crops_left = crops_matrix (1 , 0 );
2602+ const int32 crops_right = crops_matrix (1 , 1 );
2603+
2604+ if (crops_top < 0 || crops_bottom < 0 || crops_left < 0 ||
2605+ crops_right < 0 ) {
2606+ return errors::InvalidArgument (" Croppings cannot be negative." );
2607+ }
2608+
2609+ TF_RETURN_IF_ERROR (
2610+ c->Multiply (c->Dim (input, 1 ), block_size, &output_height));
2611+ TF_RETURN_IF_ERROR (c->Subtract (
2612+ output_height, (crops_top + crops_bottom), &output_height));
2613+
2614+ TF_RETURN_IF_ERROR (
2615+ c->Multiply (c->Dim (input, 2 ), block_size, &output_width));
2616+ TF_RETURN_IF_ERROR (c->Subtract (output_width, (crops_left + crops_right),
2617+ &output_width));
2618+ }
2619+
2620+ c->set_output (0 , c->MakeShape ({batch, output_height, output_width,
2621+ c->Dim (input, 3 )}));
2622+ return Status::OK ();
2623+ })
25012624 .Doc(R"doc(
25022625BatchToSpace for 4-D tensors of type T.
25032626
@@ -2593,6 +2716,29 @@ REGISTER_OP("SpaceToDepth")
25932716 .Output(" output: T" )
25942717 .Attr(" T: type" )
25952718 .Attr(" block_size: int >= 2" )
2719+ .SetShapeFn([](InferenceContext* c) {
2720+ const Shape* input;
2721+ TF_RETURN_IF_ERROR (c->WithRank (c->input (0 ), 4 , &input));
2722+
2723+ int32 block_size;
2724+ TF_RETURN_IF_ERROR (c->GetAttr (" block_size" , &block_size));
2725+
2726+ const Dimension* output_height;
2727+ const Dimension* output_width;
2728+ const Dimension* output_depth;
2729+ // Will return an error if does not evenly divide
2730+ TF_RETURN_IF_ERROR (
2731+ c->Divide (c->Dim (input, 1 ), block_size, &output_height));
2732+ TF_RETURN_IF_ERROR (
2733+ c->Divide (c->Dim (input, 2 ), block_size, &output_width));
2734+
2735+ TF_RETURN_IF_ERROR (c->Multiply (c->Dim (input, 3 ), block_size * block_size,
2736+ &output_depth));
2737+
2738+ c->set_output (0 , c->MakeShape ({c->Dim (input, 0 ), output_height,
2739+ output_width, output_depth}));
2740+ return Status::OK ();
2741+ })
25962742 .Doc(R"doc(
25972743SpaceToDepth for tensors of type T.
25982744
@@ -2677,6 +2823,27 @@ REGISTER_OP("DepthToSpace")
26772823 .Output(" output: T" )
26782824 .Attr(" T: type" )
26792825 .Attr(" block_size: int >= 2" )
2826+ .SetShapeFn([](InferenceContext* c) {
2827+ const Shape* input;
2828+ TF_RETURN_IF_ERROR (c->WithRank (c->input (0 ), 4 , &input));
2829+
2830+ int32 block_size;
2831+ TF_RETURN_IF_ERROR (c->GetAttr (" block_size" , &block_size));
2832+
2833+ const Dimension* output_height;
2834+ const Dimension* output_width;
2835+ const Dimension* output_depth;
2836+ TF_RETURN_IF_ERROR (
2837+ c->Multiply (c->Dim (input, 1 ), block_size, &output_height));
2838+ TF_RETURN_IF_ERROR (
2839+ c->Multiply (c->Dim (input, 2 ), block_size, &output_width));
2840+ TF_RETURN_IF_ERROR (
2841+ c->Divide (c->Dim (input, 3 ), block_size * block_size, &output_depth));
2842+
2843+ c->set_output (0 , c->MakeShape ({c->Dim (input, 0 ), output_height,
2844+ output_width, output_depth}));
2845+ return Status::OK ();
2846+ })
26802847 .Doc(R"doc(
26812848DepthToSpace for tensors of type T.
26822849
0 commit comments