@@ -33,15 +33,41 @@ limitations under the License.
3333#include " tensorflow/compiler/xla/util.h"
3434#include " tensorflow/core/framework/bounds_check.h"
3535#include " tensorflow/core/framework/op_kernel.h"
36+ #include " tensorflow/core/framework/op_requires.h"
3637#include " tensorflow/core/framework/register_types.h"
3738#include " tensorflow/core/framework/tensor.h"
3839#include " tensorflow/core/platform/errors.h"
3940#include " tensorflow/core/util/determinism.h"
4041#include " tensorflow/core/util/tensor_format.h"
42+ #include " tensorflow/tsl/platform/errors.h"
4143
4244namespace tensorflow {
4345namespace {
4446
47+ template <typename T>
48+ static Status ValidateKernelSizes (const T& ksizes) {
49+ for (size_t i = 0 ; i < ksizes.size (); ++i) {
50+ if (ksizes[i] <= 0 ) {
51+ return errors::InvalidArgument (
52+ " Sliding window ksize field for dimension " , i,
53+ " must be positive but is " , ksizes[i]);
54+ }
55+ }
56+ return OkStatus ();
57+ }
58+
59+ template <typename T>
60+ static Status ValidateStrides (const T& strides) {
61+ for (size_t i = 0 ; i < strides.size (); ++i) {
62+ if (strides[i] <= 0 ) {
63+ return errors::InvalidArgument (
64+ " Sliding window stride field for dimension " , i,
65+ " must be positive but is " , strides[i]);
66+ }
67+ }
68+ return OkStatus ();
69+ }
70+
4571// Superclass of pooling ops.
4672class PoolingOp : public XlaOpKernel {
4773 public:
@@ -83,50 +109,54 @@ class PoolingOp : public XlaOpKernel {
83109
84110 protected:
85111 StatusOr<std::vector<int64_t >> GetKernelSize (XlaOpKernelContext* ctx) {
86- if (ctx->num_inputs () == 1 ) {
87- return ksize_;
88- }
89- const TensorShape ksize_shape = ctx->InputShape (1 );
90- // Validate input sizes.
91- if (!TensorShapeUtils::IsVector (ksize_shape)) {
92- return errors::InvalidArgument (" ksize must be a vector, not shape " ,
93- ksize_shape.DebugString ());
94- }
95- if (ksize_shape.num_elements () != num_dims ()) {
96- return errors::InvalidArgument (
97- " Sliding window ksize field must "
98- " specify " ,
99- num_dims (), " dimensions" );
100- }
101112 std::vector<int64_t > ksize;
102- auto status = ctx->ConstantInputAsIntVector (1 , &ksize);
103- if (!status.ok ()) {
104- return status;
113+ if (ctx->num_inputs () == 1 ) {
114+ ksize = ksize_;
115+ } else {
116+ const TensorShape ksize_shape = ctx->InputShape (1 );
117+ // Validate input sizes.
118+ if (!TensorShapeUtils::IsVector (ksize_shape)) {
119+ return errors::InvalidArgument (" ksize must be a vector, not shape " ,
120+ ksize_shape.DebugString ());
121+ }
122+ if (ksize_shape.num_elements () != num_dims ()) {
123+ return errors::InvalidArgument (
124+ " Sliding window ksize field must "
125+ " specify " ,
126+ num_dims (), " dimensions" );
127+ }
128+ auto status = ctx->ConstantInputAsIntVector (1 , &ksize);
129+ if (!status.ok ()) {
130+ return status;
131+ }
105132 }
133+ TF_RETURN_IF_ERROR (ValidateKernelSizes (ksize));
106134 return ksize;
107135 }
108136
109137 StatusOr<std::vector<int64_t >> GetStride (XlaOpKernelContext* ctx) {
110- if (ctx->num_inputs () == 1 ) {
111- return stride_;
112- }
113- const TensorShape stride_shape = ctx->InputShape (2 );
114- // Validate input sizes.
115- if (!TensorShapeUtils::IsVector (stride_shape)) {
116- return errors::InvalidArgument (" stride must be a vector, not shape " ,
117- stride_shape.DebugString ());
118- }
119- if (stride_shape.num_elements () != num_dims ()) {
120- return errors::InvalidArgument (
121- " Sliding window stride field must "
122- " specify " ,
123- num_dims (), " dimensions" );
124- }
125138 std::vector<int64_t > stride;
126- auto status = ctx->ConstantInputAsIntVector (2 , &stride);
127- if (!status.ok ()) {
128- return status;
139+ if (ctx->num_inputs () == 1 ) {
140+ stride = stride_;
141+ } else {
142+ const TensorShape stride_shape = ctx->InputShape (2 );
143+ // Validate input sizes.
144+ if (!TensorShapeUtils::IsVector (stride_shape)) {
145+ return errors::InvalidArgument (" stride must be a vector, not shape " ,
146+ stride_shape.DebugString ());
147+ }
148+ if (stride_shape.num_elements () != num_dims ()) {
149+ return errors::InvalidArgument (
150+ " Sliding window stride field must "
151+ " specify " ,
152+ num_dims (), " dimensions" );
153+ }
154+ auto status = ctx->ConstantInputAsIntVector (2 , &stride);
155+ if (!status.ok ()) {
156+ return status;
157+ }
129158 }
159+ TF_RETURN_IF_ERROR (ValidateStrides (stride));
130160 return stride;
131161 }
132162
@@ -355,10 +385,12 @@ class MaxPoolGradOp : public XlaOpKernel {
355385 errors::InvalidArgument (" Sliding window ksize field must "
356386 " specify " ,
357387 num_dims (), " dimensions" ));
388+ OP_REQUIRES_OK (ctx, ValidateKernelSizes (ksize_));
358389 OP_REQUIRES (ctx, stride_.size () == num_dims (),
359390 errors::InvalidArgument (" Sliding window strides field must "
360391 " specify " ,
361392 num_dims (), " dimensions" ));
393+ OP_REQUIRES_OK (ctx, ValidateStrides (stride_));
362394
363395 const TensorShape tensor_in_shape = ctx->InputShape (0 );
364396 const TensorShape tensor_out_shape = ctx->InputShape (1 );
@@ -446,11 +478,13 @@ class AvgPoolGradOp : public XlaOpKernel {
446478 errors::InvalidArgument (" Sliding window ksize field must "
447479 " specify " ,
448480 num_dims (), " dimensions" ));
481+ OP_REQUIRES_OK (ctx, ValidateKernelSizes (ksize_));
449482 OP_REQUIRES_OK (ctx, ctx->GetAttr (" strides" , &stride_));
450483 OP_REQUIRES (ctx, stride_.size () == num_dims (),
451484 errors::InvalidArgument (" Sliding window strides field must "
452485 " specify " ,
453486 num_dims (), " dimensions" ));
487+ OP_REQUIRES_OK (ctx, ValidateStrides (stride_));
454488 OP_REQUIRES_OK (ctx, ctx->GetAttr (" padding" , &padding_));
455489 OP_REQUIRES (ctx, padding_ != EXPLICIT,
456490 errors::Unimplemented (
@@ -579,10 +613,12 @@ class MaxPoolGradGradOp : public XlaOpKernel {
579613 errors::InvalidArgument (" Sliding window ksize field must "
580614 " specify " ,
581615 num_dims (), " dimensions" ));
616+ OP_REQUIRES_OK (ctx, ValidateKernelSizes (ksize_));
582617 OP_REQUIRES (ctx, stride_.size () == num_dims (),
583618 errors::InvalidArgument (" Sliding window strides field must "
584619 " specify " ,
585620 num_dims (), " dimensions" ));
621+ OP_REQUIRES_OK (ctx, ValidateStrides (stride_));
586622
587623 const TensorShape tensor_in_shape = ctx->InputShape (0 );
588624 const TensorShape tensor_out_shape = ctx->InputShape (1 );
0 commit comments