Skip to content

Commit 1295ae4

Browse files
majnemertensorflower-gardener
authored andcommitted
[tf2xla] Validate that stride and window size are positive
PiperOrigin-RevId: 504866231
1 parent 3b1b9de commit 1295ae4

File tree

3 files changed

+112
-36
lines changed

3 files changed

+112
-36
lines changed

tensorflow/compiler/tests/pooling_ops_test.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@
1818

1919
from tensorflow.compiler.tests import xla_test
2020
from tensorflow.python.framework import dtypes
21+
from tensorflow.python.framework import errors
2122
from tensorflow.python.framework import ops
23+
from tensorflow.python.framework import test_util
2224
from tensorflow.python.ops import array_ops
2325
from tensorflow.python.ops import gen_nn_ops
2426
from tensorflow.python.ops import nn_ops
@@ -560,6 +562,34 @@ def AvgPoolGrad(inputs, outputs, output_gradients, ksize, strides, padding,
560562

561563
self._TestPooling(nn_ops.avg_pool, AvgPoolGrad)
562564

565+
@test_util.disable_mlir_bridge(
566+
"TODO(b/266613412): investigate FPE in AvgPoolGrad for TPU"
567+
)
568+
def testAvgPoolGradSamePaddingZeroStrideZeroSize(self):
569+
output_gradient_vals = np.array([0.39117979], dtype=np.float32)
570+
output_gradient_vals = output_gradient_vals.reshape([1, 1, 1, 1])
571+
with self.session() as sess:
572+
with self.test_scope():
573+
output_gradients = array_ops.placeholder(
574+
dtypes.float32, shape=output_gradient_vals.shape
575+
)
576+
t = gen_nn_ops.avg_pool_grad(
577+
orig_input_shape=[1, 0, 0, 0],
578+
grad=output_gradients,
579+
ksize=[1, 0, 0, 0],
580+
strides=[1, 0, 0, 0],
581+
padding="SAME",
582+
data_format="NCHW",
583+
)
584+
with self.assertRaisesRegex(
585+
errors.InvalidArgumentError,
586+
(
587+
"Sliding window ksize field for dimension 1 must be positive but"
588+
" is 0"
589+
),
590+
):
591+
sess.run(t, {output_gradients: output_gradient_vals})
592+
563593
# The CPU implementation of AvgPoolGrad doesn't accept kernels smaller than
564594
# the stride size, so we only run the following tests on MaxPoolGrad.
565595

tensorflow/compiler/tf2xla/kernels/pooling_ops.cc

Lines changed: 72 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -33,15 +33,41 @@ limitations under the License.
3333
#include "tensorflow/compiler/xla/util.h"
3434
#include "tensorflow/core/framework/bounds_check.h"
3535
#include "tensorflow/core/framework/op_kernel.h"
36+
#include "tensorflow/core/framework/op_requires.h"
3637
#include "tensorflow/core/framework/register_types.h"
3738
#include "tensorflow/core/framework/tensor.h"
3839
#include "tensorflow/core/platform/errors.h"
3940
#include "tensorflow/core/util/determinism.h"
4041
#include "tensorflow/core/util/tensor_format.h"
42+
#include "tensorflow/tsl/platform/errors.h"
4143

4244
namespace tensorflow {
4345
namespace {
4446

47+
template <typename T>
48+
static Status ValidateKernelSizes(const T& ksizes) {
49+
for (size_t i = 0; i < ksizes.size(); ++i) {
50+
if (ksizes[i] <= 0) {
51+
return errors::InvalidArgument(
52+
"Sliding window ksize field for dimension ", i,
53+
" must be positive but is ", ksizes[i]);
54+
}
55+
}
56+
return OkStatus();
57+
}
58+
59+
template <typename T>
60+
static Status ValidateStrides(const T& strides) {
61+
for (size_t i = 0; i < strides.size(); ++i) {
62+
if (strides[i] <= 0) {
63+
return errors::InvalidArgument(
64+
"Sliding window stride field for dimension ", i,
65+
" must be positive but is ", strides[i]);
66+
}
67+
}
68+
return OkStatus();
69+
}
70+
4571
// Superclass of pooling ops.
4672
class PoolingOp : public XlaOpKernel {
4773
public:
@@ -83,50 +109,54 @@ class PoolingOp : public XlaOpKernel {
83109

84110
protected:
85111
StatusOr<std::vector<int64_t>> GetKernelSize(XlaOpKernelContext* ctx) {
86-
if (ctx->num_inputs() == 1) {
87-
return ksize_;
88-
}
89-
const TensorShape ksize_shape = ctx->InputShape(1);
90-
// Validate input sizes.
91-
if (!TensorShapeUtils::IsVector(ksize_shape)) {
92-
return errors::InvalidArgument("ksize must be a vector, not shape ",
93-
ksize_shape.DebugString());
94-
}
95-
if (ksize_shape.num_elements() != num_dims()) {
96-
return errors::InvalidArgument(
97-
"Sliding window ksize field must "
98-
"specify ",
99-
num_dims(), " dimensions");
100-
}
101112
std::vector<int64_t> ksize;
102-
auto status = ctx->ConstantInputAsIntVector(1, &ksize);
103-
if (!status.ok()) {
104-
return status;
113+
if (ctx->num_inputs() == 1) {
114+
ksize = ksize_;
115+
} else {
116+
const TensorShape ksize_shape = ctx->InputShape(1);
117+
// Validate input sizes.
118+
if (!TensorShapeUtils::IsVector(ksize_shape)) {
119+
return errors::InvalidArgument("ksize must be a vector, not shape ",
120+
ksize_shape.DebugString());
121+
}
122+
if (ksize_shape.num_elements() != num_dims()) {
123+
return errors::InvalidArgument(
124+
"Sliding window ksize field must "
125+
"specify ",
126+
num_dims(), " dimensions");
127+
}
128+
auto status = ctx->ConstantInputAsIntVector(1, &ksize);
129+
if (!status.ok()) {
130+
return status;
131+
}
105132
}
133+
TF_RETURN_IF_ERROR(ValidateKernelSizes(ksize));
106134
return ksize;
107135
}
108136

109137
StatusOr<std::vector<int64_t>> GetStride(XlaOpKernelContext* ctx) {
110-
if (ctx->num_inputs() == 1) {
111-
return stride_;
112-
}
113-
const TensorShape stride_shape = ctx->InputShape(2);
114-
// Validate input sizes.
115-
if (!TensorShapeUtils::IsVector(stride_shape)) {
116-
return errors::InvalidArgument("stride must be a vector, not shape ",
117-
stride_shape.DebugString());
118-
}
119-
if (stride_shape.num_elements() != num_dims()) {
120-
return errors::InvalidArgument(
121-
"Sliding window stride field must "
122-
"specify ",
123-
num_dims(), " dimensions");
124-
}
125138
std::vector<int64_t> stride;
126-
auto status = ctx->ConstantInputAsIntVector(2, &stride);
127-
if (!status.ok()) {
128-
return status;
139+
if (ctx->num_inputs() == 1) {
140+
stride = stride_;
141+
} else {
142+
const TensorShape stride_shape = ctx->InputShape(2);
143+
// Validate input sizes.
144+
if (!TensorShapeUtils::IsVector(stride_shape)) {
145+
return errors::InvalidArgument("stride must be a vector, not shape ",
146+
stride_shape.DebugString());
147+
}
148+
if (stride_shape.num_elements() != num_dims()) {
149+
return errors::InvalidArgument(
150+
"Sliding window stride field must "
151+
"specify ",
152+
num_dims(), " dimensions");
153+
}
154+
auto status = ctx->ConstantInputAsIntVector(2, &stride);
155+
if (!status.ok()) {
156+
return status;
157+
}
129158
}
159+
TF_RETURN_IF_ERROR(ValidateStrides(stride));
130160
return stride;
131161
}
132162

@@ -355,10 +385,12 @@ class MaxPoolGradOp : public XlaOpKernel {
355385
errors::InvalidArgument("Sliding window ksize field must "
356386
"specify ",
357387
num_dims(), " dimensions"));
388+
OP_REQUIRES_OK(ctx, ValidateKernelSizes(ksize_));
358389
OP_REQUIRES(ctx, stride_.size() == num_dims(),
359390
errors::InvalidArgument("Sliding window strides field must "
360391
"specify ",
361392
num_dims(), " dimensions"));
393+
OP_REQUIRES_OK(ctx, ValidateStrides(stride_));
362394

363395
const TensorShape tensor_in_shape = ctx->InputShape(0);
364396
const TensorShape tensor_out_shape = ctx->InputShape(1);
@@ -446,11 +478,13 @@ class AvgPoolGradOp : public XlaOpKernel {
446478
errors::InvalidArgument("Sliding window ksize field must "
447479
"specify ",
448480
num_dims(), " dimensions"));
481+
OP_REQUIRES_OK(ctx, ValidateKernelSizes(ksize_));
449482
OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_));
450483
OP_REQUIRES(ctx, stride_.size() == num_dims(),
451484
errors::InvalidArgument("Sliding window strides field must "
452485
"specify ",
453486
num_dims(), " dimensions"));
487+
OP_REQUIRES_OK(ctx, ValidateStrides(stride_));
454488
OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_));
455489
OP_REQUIRES(ctx, padding_ != EXPLICIT,
456490
errors::Unimplemented(
@@ -579,10 +613,12 @@ class MaxPoolGradGradOp : public XlaOpKernel {
579613
errors::InvalidArgument("Sliding window ksize field must "
580614
"specify ",
581615
num_dims(), " dimensions"));
616+
OP_REQUIRES_OK(ctx, ValidateKernelSizes(ksize_));
582617
OP_REQUIRES(ctx, stride_.size() == num_dims(),
583618
errors::InvalidArgument("Sliding window strides field must "
584619
"specify ",
585620
num_dims(), " dimensions"));
621+
OP_REQUIRES_OK(ctx, ValidateStrides(stride_));
586622

587623
const TensorShape tensor_in_shape = ctx->InputShape(0);
588624
const TensorShape tensor_out_shape = ctx->InputShape(1);

tensorflow/compiler/xla/client/padding.cc

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,16 @@ Status ValidatePaddingValues(absl::Span<const int64_t> input_dimensions,
3535
input_dimensions.size(), window_dimensions.size(),
3636
window_strides.size());
3737
}
38+
for (size_t i = 0; i < input_dimensions.size(); ++i) {
39+
if (window_dimensions[i] <= 0) {
40+
return InvalidArgument("Window dimension %u has non-positive size %d", i,
41+
window_dimensions[i]);
42+
}
43+
if (window_strides[i] <= 0) {
44+
return InvalidArgument("Window dimension %u has non-positive stride %d",
45+
i, window_strides[i]);
46+
}
47+
}
3848
return OkStatus();
3949
}
4050

0 commit comments

Comments
 (0)