Skip to content

Commit 482da92

Browse files
pak-lauratensorflower-gardener
authored andcommitted
Ensure non-empty padding_value input to tf.raw_ops.MatrixDiagPartV2, if a padding_value is input
PiperOrigin-RevId: 388314614 Change-Id: If0b51ad58d5d8543a6be6ce8f42ae4755c80d55f
1 parent 3b4351c commit 482da92

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

tensorflow/core/kernels/linalg/matrix_diag_op.cc

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,10 @@ class MatrixDiagPartOp : public OpKernel {
8989
upper_diag_index = diag_index.flat<int32>()(1);
9090
}
9191
}
92-
padding_value = context->input(2).flat<T>()(0);
92+
const Tensor& padding_in = context->input(2);
93+
OP_REQUIRES(context, padding_in.NumElements() == 1,
94+
errors::InvalidArgument("Padding must be scalar."));
95+
padding_value = padding_in.flat<T>()(0);
9396
}
9497
const TensorShape& input_shape = input.shape();
9598

0 commit comments

Comments
 (0)