Skip to content

Commit 965b97e

Browse files
Properly validate sparse tensor in SparseTensorSliceDataset
Existing validation was incomplete. PiperOrigin-RevId: 415375048 Change-Id: I14cd18f29ede73286f3ffac35171bd15828997e9
1 parent c41e88b commit 965b97e

File tree

2 files changed

+39
-19
lines changed

2 files changed

+39
-19
lines changed

tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -240,28 +240,29 @@ class SparseTensorSliceDatasetOp : public DatasetOpKernel {
240240
OP_REQUIRES_OK(ctx, ctx->input("dense_shape", &dense_shape));
241241

242242
OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(indices->shape()),
243-
errors::InvalidArgument(
244-
"Input indices should be a matrix but received shape ",
245-
indices->shape().DebugString()));
246-
247-
const auto num_indices = indices->NumElements();
248-
const auto num_values = values->NumElements();
249-
if (num_indices == 0 || num_values == 0) {
250-
OP_REQUIRES(ctx, num_indices == num_values,
251-
errors::InvalidArgument(
252-
"If indices or values are empty, the other one must also "
253-
"be. Got indices of shape ",
254-
indices->shape().DebugString(), " and values of shape ",
255-
values->shape().DebugString()));
256-
}
243+
errors::InvalidArgument("Input indices must be a matrix. Got: ",
244+
indices->shape().DebugString()));
257245
OP_REQUIRES(ctx, TensorShapeUtils::IsVector(values->shape()),
258-
errors::InvalidArgument(
259-
"Input values should be a vector but received shape ",
260-
indices->shape().DebugString()));
246+
errors::InvalidArgument("Input values must be a vector. Got: ",
247+
values->shape().DebugString()));
261248
OP_REQUIRES(ctx, TensorShapeUtils::IsVector(dense_shape->shape()),
249+
errors::InvalidArgument("Input shape must be a vector. Got: ",
250+
dense_shape->shape().DebugString()));
251+
OP_REQUIRES(
252+
ctx, values->shape().dim_size(0) == indices->shape().dim_size(0),
253+
errors::InvalidArgument(
254+
"Number of values must match first dimension of indices. ", "Got ",
255+
values->shape().dim_size(0),
256+
" values, indices shape: ", indices->shape().DebugString()));
257+
OP_REQUIRES(
258+
ctx, dense_shape->shape().dim_size(0) == indices->shape().dim_size(1),
259+
errors::InvalidArgument(
260+
"Number of dimensions must match second dimension of indices. ",
261+
"Got ", dense_shape->shape().dim_size(0),
262+
" dimensions, indices shape: ", indices->shape().DebugString()));
263+
OP_REQUIRES(ctx, dense_shape->NumElements() > 0,
262264
errors::InvalidArgument(
263-
"Input shape should be a vector but received shape ",
264-
dense_shape->shape().DebugString()));
265+
"The shape argument requires at least one element."));
265266

266267
// We currently ensure that `sparse_tensor` is ordered in the
267268
// batch dimension.

tensorflow/python/data/kernel_tests/from_sparse_tensor_slices_test.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,25 @@ def testEmptySparseTensorSlicesInvalid(self):
134134
with self.assertRaises(errors.InvalidArgumentError):
135135
sess.run(init_op, feed_dict={st: sparse_feed})
136136

137+
@combinations.generate(combinations.combine(tf_api_version=1, mode=["graph"]))
138+
def testEmptySparseTensorSlicesInvalid2(self):
139+
"""Test a dataset based on invalid `tf.sparse.SparseTensor`."""
140+
st = array_ops.sparse_placeholder(dtypes.float64)
141+
iterator = dataset_ops.make_initializable_iterator(
142+
dataset_ops.Dataset.from_sparse_tensor_slices(st))
143+
init_op = iterator.initializer
144+
145+
with self.cached_session() as sess:
146+
# Test with an empty sparse tensor but with non empty values.
147+
empty_indices = [[]]
148+
empty_values = []
149+
dense_shape = [1, 1]
150+
sparse_feed = sparse_tensor.SparseTensorValue(empty_indices, empty_values,
151+
dense_shape)
152+
# Here, we expect the test to fail when running the feed.
153+
with self.assertRaises(errors.InvalidArgumentError):
154+
sess.run(init_op, feed_dict={st: sparse_feed})
155+
137156
@combinations.generate(combinations.combine(tf_api_version=2, mode=["eager"]))
138157
def testFromSparseTensorSlicesError(self):
139158
with self.assertRaises(AttributeError):

0 commit comments

Comments
 (0)