File tree Expand file tree Collapse file tree 2 files changed +16
-1
lines changed
Expand file tree Collapse file tree 2 files changed +16
-1
lines changed Original file line number Diff line number Diff line change @@ -30,11 +30,16 @@ class SparseSplitOp : public OpKernel {
3030 }
3131
3232 void Compute (OpKernelContext* context) override {
33- const int64_t axis_input = context->input (0 ). scalar < int64_t >()( );
33+ const Tensor& input_axis = context->input (0 );
3434 const Tensor& input_indices = context->input (1 );
3535 const Tensor& input_values = context->input (2 );
3636 const Tensor& input_shape = context->input (3 );
3737
38+ OP_REQUIRES (context, TensorShapeUtils::IsScalar (input_axis.shape ()),
39+ errors::InvalidArgument (
40+ " Input axis should be a scalar but received shape " ,
41+ input_axis.shape ().DebugString ()),
42+ done);
3843 OP_REQUIRES (context, TensorShapeUtils::IsMatrix (input_indices.shape ()),
3944 errors::InvalidArgument (
4045 " Input indices should be a matrix but received shape " ,
@@ -48,6 +53,7 @@ class SparseSplitOp : public OpKernel {
4853 " Input shape should be a vector but received shape " ,
4954 input_shape.shape ().DebugString ()));
5055
56+ const int64_t axis_input = input_axis.scalar <int64_t >()();
5157 const int64_t input_rank = input_shape.vec <int64_t >().size ();
5258 const int64_t axis =
5359 (axis_input < 0 ) ? input_rank + axis_input : axis_input;
Original file line number Diff line number Diff line change @@ -257,6 +257,15 @@ def testArgumentErrors(self):
257257 with self .assertRaisesRegex (ValueError , 'axis is required' ):
258258 sparse_ops .sparse_split (num_split = 2 , sp_input = 1 )
259259
260+ def testInvalidArgumentError (self ):
261+ # Test case for GitHub issue 53660.
262+ axis = [1 , 2 ]
263+ with self .assertRaisesRegexp (errors .InvalidArgumentError ,
264+ r'axis should be a scalar' ):
265+ self .evaluate (
266+ sparse_ops .sparse_split (
267+ sp_input = self ._SparseTensor_4x6 (), num_split = 3 , axis = axis ))
268+
260269
261270if __name__ == '__main__' :
262271 test .main ()
You can’t perform that action at this time.
0 commit comments