Skip to content

Commit 08b8e18

Browse files
isharktensorflower-gardener
authored andcommitted
Fix security vulnerability in EditDistance op shape function.
PiperOrigin-RevId: 504367470
1 parent 7533da4 commit 08b8e18

File tree

2 files changed

+75
-2
lines changed

2 files changed

+75
-2
lines changed

tensorflow/core/ops/array_ops.cc

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ limitations under the License.
2525
#include "tensorflow/core/framework/types.h"
2626
#include "tensorflow/core/framework/types.pb.h"
2727
#include "tensorflow/core/lib/core/errors.h"
28+
#include "tensorflow/core/platform/status.h"
2829
#include "tensorflow/core/platform/types.h"
2930
#include "tensorflow/core/util/mirror_pad_mode.h"
3031
#include "tensorflow/core/util/padding.h"
@@ -1072,13 +1073,24 @@ REGISTER_OP("EditDistance")
10721073
// or else the output shape is unknown.
10731074
return shape_inference::UnknownShape(c);
10741075
}
1075-
10761076
if (hypothesis_shape_t->NumElements() != truth_shape_t->NumElements()) {
10771077
return errors::InvalidArgument(
10781078
"Num elements of hypothesis_shape does not match truth_shape: ",
10791079
hypothesis_shape_t->NumElements(), " vs. ",
10801080
truth_shape_t->NumElements());
10811081
}
1082+
if (hypothesis_shape_t->NumElements() < 2) {
1083+
return errors::InvalidArgument(
1084+
"Input Hypothesis SparseTensors must have rank at least 2, but "
1085+
"hypothesis_shape rank is: ",
1086+
hypothesis_shape_t->NumElements());
1087+
}
1088+
if (truth_shape_t->NumElements() < 2) {
1089+
return errors::InvalidArgument(
1090+
"Input Truth SparseTensors must have rank at least 2, but "
1091+
"truth_shape rank is: ",
1092+
truth_shape_t->NumElements());
1093+
}
10821094

10831095
auto h_values = hypothesis_shape_t->flat<int64_t>();
10841096
auto t_values = truth_shape_t->flat<int64_t>();

tensorflow/python/kernel_tests/array_ops/edit_distance_op_test.py

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,9 @@
1515
"""Tests for tensorflow.kernels.edit_distance_op."""
1616

1717
import numpy as np
18-
18+
from tensorflow.python.eager import def_function
1919
from tensorflow.python.framework import constant_op
20+
from tensorflow.python.framework import errors
2021
from tensorflow.python.framework import ops
2122
from tensorflow.python.framework import sparse_tensor
2223
from tensorflow.python.ops import array_ops
@@ -225,6 +226,66 @@ def testEditDistanceBadIndices(self):
225226
"to outside of the buffer for the output tensor|"
226227
r"Dimension -\d+ must be >= 0"))
227228

229+
def testEmptyShapeWithEditDistanceRaisesError(self):
230+
para = {
231+
"hypothesis_indices": [[]],
232+
"hypothesis_values": ["tmp/"],
233+
"hypothesis_shape": [],
234+
"truth_indices": [[]],
235+
"truth_values": [""],
236+
"truth_shape": [],
237+
"normalize": False,
238+
}
239+
240+
# Check edit distance raw op with empty shape in eager mode.
241+
with self.assertRaisesRegex(
242+
(errors.InvalidArgumentError, ValueError),
243+
(
244+
r"Input Hypothesis SparseTensors must have rank at least 2, but"
245+
" hypothesis_shape rank is: 0|Input SparseTensors must have rank "
246+
"at least 2, but truth_shape rank is: 0"
247+
),
248+
):
249+
array_ops.gen_array_ops.EditDistance(**para)
250+
251+
# Check raw op with tf.function
252+
@def_function.function
253+
def TestFunction():
254+
"""Wrapper function for edit distance call."""
255+
array_ops.gen_array_ops.EditDistance(**para)
256+
257+
with self.assertRaisesRegex(
258+
ValueError,
259+
(
260+
"Input Hypothesis SparseTensors must have rank at least 2, but"
261+
" hypothesis_shape rank is: 0"
262+
),
263+
):
264+
TestFunction()
265+
266+
# Check with python wrapper API
267+
hypothesis_indices = [[]]
268+
hypothesis_values = [0]
269+
hypothesis_shape = []
270+
truth_indices = [[]]
271+
truth_values = [1]
272+
truth_shape = []
273+
expected_output = [] # dummy ignored
274+
275+
with self.assertRaisesRegex(
276+
ValueError,
277+
(
278+
"Input Hypothesis SparseTensors must have rank at least 2, but"
279+
" hypothesis_shape rank is: 0"
280+
),
281+
):
282+
self._testEditDistance(
283+
hypothesis=(hypothesis_indices, hypothesis_values, hypothesis_shape),
284+
truth=(truth_indices, truth_values, truth_shape),
285+
normalize=False,
286+
expected_output=expected_output,
287+
)
288+
228289

229290
if __name__ == "__main__":
230291
test.main()

0 commit comments

Comments
 (0)