Skip to content

Commit 02cc160

Browse files
Prevent nullptr deref in SparseTensorSliceDataset
The arguments must determine a valid sparse tensor. This means that when indices are empty then the values must be empty too (and the reverse). Also added test, by modifying existing test with empty sparse tensor to now run with an invalid sparse tensor input. PiperOrigin-RevId: 388562757 Change-Id: Id8b54cd7c2316025b4f9a77292c8fb5344d17609
1 parent 234a51a commit 02cc160

File tree

2 files changed

+31
-0
lines changed

2 files changed

+31
-0
lines changed

tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,17 @@ class SparseTensorSliceDatasetOp : public DatasetOpKernel {
241241
errors::InvalidArgument(
242242
"Input indices should be a matrix but received shape ",
243243
indices->shape().DebugString()));
244+
245+
const auto num_indices = indices->NumElements();
246+
const auto num_values = values->NumElements();
247+
if (num_indices == 0 || num_values == 0) {
248+
OP_REQUIRES(ctx, num_indices == num_values,
249+
errors::InvalidArgument(
250+
"If indices or values are empty, the other one must also "
251+
"be. Got indices of shape ",
252+
indices->shape().DebugString(), " and values of shape ",
253+
values->shape().DebugString()));
254+
}
244255
OP_REQUIRES(ctx, TensorShapeUtils::IsVector(values->shape()),
245256
errors::InvalidArgument(
246257
"Input values should be a vector but received shape ",

tensorflow/python/data/kernel_tests/from_sparse_tensor_slices_test.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,26 @@ def testEmptySparseTensorSlices(self):
118118
with self.assertRaises(errors.OutOfRangeError):
119119
sess.run(get_next)
120120

121+
@combinations.generate(combinations.combine(tf_api_version=1, mode=["graph"]))
122+
def testEmptySparseTensorSlicesInvalid(self):
123+
"""Test a dataset based on invalid `tf.sparse.SparseTensor`."""
124+
st = array_ops.sparse_placeholder(dtypes.float64)
125+
iterator = dataset_ops.make_initializable_iterator(
126+
dataset_ops.Dataset.from_sparse_tensor_slices(st))
127+
init_op = iterator.initializer
128+
129+
with self.cached_session() as sess:
130+
# Test with an empty sparse tensor but with non empty values.
131+
empty_indices = np.empty((0, 4), dtype=np.int64)
132+
non_empty_values = [1, 2, 3, 4]
133+
empty_dense_shape = [0, 4, 37, 9]
134+
sparse_feed = sparse_tensor.SparseTensorValue(empty_indices,
135+
non_empty_values,
136+
empty_dense_shape)
137+
# Here, we expect the test to fail when running the feed.
138+
with self.assertRaises(errors.InvalidArgumentError):
139+
sess.run(init_op, feed_dict={st: sparse_feed})
140+
121141
@combinations.generate(combinations.combine(tf_api_version=2, mode=["eager"]))
122142
def testFromSparseTensorSlicesError(self):
123143
with self.assertRaises(AttributeError):

0 commit comments

Comments
 (0)