Skip to content

Commit 61bf91e

Browse files
Merge pull request #53695 from yongtang:53660-tf.sparse.split-crash
PiperOrigin-RevId: 420811652 Change-Id: I83742482770ba0bf7c3ccd57508c40fb9cdbe2f7
2 parents 5bb1cb3 + 5a5a4ed commit 61bf91e

File tree

2 files changed

+16
-1
lines changed

2 files changed

+16
-1
lines changed

tensorflow/core/kernels/sparse_split_op.cc

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,11 +74,16 @@ void SparseSplitOpImpl(OpKernelContext* context, int num_split,
7474
done = [] {};
7575
}
7676

77-
const int64_t axis_input = context->input(0).scalar<int64_t>()();
77+
const Tensor& input_axis = context->input(0);
7878
const Tensor& input_indices = context->input(1);
7979
const Tensor& input_values = context->input(2);
8080
const Tensor& input_shape = context->input(3);
8181

82+
OP_REQUIRES_ASYNC(context, TensorShapeUtils::IsScalar(input_axis.shape()),
83+
errors::InvalidArgument(
84+
"Input axis should be a scalar but received shape ",
85+
input_axis.shape().DebugString()),
86+
done);
8287
OP_REQUIRES_ASYNC(context, TensorShapeUtils::IsMatrix(input_indices.shape()),
8388
errors::InvalidArgument(
8489
"Input indices should be a matrix but received shape ",
@@ -95,6 +100,7 @@ void SparseSplitOpImpl(OpKernelContext* context, int num_split,
95100
input_shape.shape().DebugString()),
96101
done);
97102

103+
const int64_t axis_input = input_axis.scalar<int64_t>()();
98104
const int64_t input_rank = input_shape.vec<int64_t>().size();
99105
const int64_t axis = (axis_input < 0) ? input_rank + axis_input : axis_input;
100106

tensorflow/python/kernel_tests/sparse_ops/sparse_split_op_test.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,15 @@ def testSplitEmpty(self):
279279
self.assertAllEqual(sparse_splits1[1].values, [])
280280
self.assertAllEqual(sparse_splits1[1].dense_shape, [4, 3])
281281

282+
def testInvalidArgumentError(self):
283+
# Test case for GitHub issue 53660.
284+
axis = [1, 2]
285+
with self.assertRaisesRegexp(errors.InvalidArgumentError,
286+
r'axis should be a scalar'):
287+
self.evaluate(
288+
sparse_ops.sparse_split(
289+
sp_input=self._SparseTensor_4x6(), num_split=3, axis=axis))
290+
282291

283292
if __name__ == '__main__':
284293
test.main()

0 commit comments

Comments
 (0)