Skip to content

Commit f2a673b

Browse files
Add missing validation to matrix_diag_op.cc
PiperOrigin-RevId: 387923533 Change-Id: Idfffeb328d5f9c6748d992d28a56d6e9e45103a0
1 parent ff88940 commit f2a673b

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

tensorflow/core/kernels/linalg/matrix_diag_op.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,9 @@ class MatrixDiagPartOp : public OpKernel {
7373
errors::InvalidArgument(
7474
"diag_index must be a scalar or vector, received shape: ",
7575
diag_index.shape().DebugString()));
76+
OP_REQUIRES(context, diag_index.NumElements() > 0,
77+
errors::InvalidArgument(
78+
"Expected diag_index to have at least 1 element"));
7679
lower_diag_index = diag_index.flat<int32>()(0);
7780
upper_diag_index = lower_diag_index;
7881
if (TensorShapeUtils::IsVector(diag_index.shape())) {
@@ -179,6 +182,9 @@ class MatrixDiagOp : public OpKernel {
179182
errors::InvalidArgument(
180183
"diag_index must be a scalar or vector, received shape: ",
181184
diag_index.shape().DebugString()));
185+
OP_REQUIRES(context, diag_index.NumElements() > 0,
186+
errors::InvalidArgument(
187+
"Expected diag_index to have at least 1 element"));
182188
lower_diag_index = diag_index.flat<int32>()(0);
183189
upper_diag_index = lower_diag_index;
184190
if (TensorShapeUtils::IsVector(diag_index.shape())) {

0 commit comments

Comments
 (0)