Skip to content

Commit 894ba6b

Browse files
Merge pull request #53695 from yongtang:53660-tf.sparse.split-crash
PiperOrigin-RevId: 420811652 Change-Id: I83742482770ba0bf7c3ccd57508c40fb9cdbe2f7
1 parent ea9b53a commit 894ba6b

File tree

2 files changed

+16
-1
lines changed

2 files changed

+16
-1
lines changed

tensorflow/core/kernels/sparse_split_op.cc

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,16 @@ class SparseSplitOp : public OpKernel {
3030
}
3131

3232
void Compute(OpKernelContext* context) override {
33-
const int64_t axis_input = context->input(0).scalar<int64_t>()();
33+
const Tensor& input_axis = context->input(0);
3434
const Tensor& input_indices = context->input(1);
3535
const Tensor& input_values = context->input(2);
3636
const Tensor& input_shape = context->input(3);
3737

38+
OP_REQUIRES(context, TensorShapeUtils::IsScalar(input_axis.shape()),
39+
errors::InvalidArgument(
40+
"Input axis should be a scalar but received shape ",
41+
input_axis.shape().DebugString()),
42+
done);
3843
OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_indices.shape()),
3944
errors::InvalidArgument(
4045
"Input indices should be a matrix but received shape ",
@@ -48,6 +53,7 @@ class SparseSplitOp : public OpKernel {
4853
"Input shape should be a vector but received shape ",
4954
input_shape.shape().DebugString()));
5055

56+
const int64_t axis_input = input_axis.scalar<int64_t>()();
5157
const int64_t input_rank = input_shape.vec<int64_t>().size();
5258
const int64_t axis =
5359
(axis_input < 0) ? input_rank + axis_input : axis_input;

tensorflow/python/kernel_tests/sparse_split_op_test.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,15 @@ def testArgumentErrors(self):
257257
with self.assertRaisesRegex(ValueError, 'axis is required'):
258258
sparse_ops.sparse_split(num_split=2, sp_input=1)
259259

260+
def testInvalidArgumentError(self):
261+
# Test case for GitHub issue 53660.
262+
axis = [1, 2]
263+
with self.assertRaisesRegexp(errors.InvalidArgumentError,
264+
r'axis should be a scalar'):
265+
self.evaluate(
266+
sparse_ops.sparse_split(
267+
sp_input=self._SparseTensor_4x6(), num_split=3, axis=axis))
268+
260269

261270
if __name__ == '__main__':
262271
test.main()

0 commit comments

Comments
 (0)