|
15 | 15 | """Tests for tensorflow.kernels.edit_distance_op.""" |
16 | 16 |
|
17 | 17 | import numpy as np |
18 | | - |
| 18 | +from tensorflow.python.eager import def_function |
19 | 19 | from tensorflow.python.framework import constant_op |
| 20 | +from tensorflow.python.framework import errors |
20 | 21 | from tensorflow.python.framework import ops |
21 | 22 | from tensorflow.python.framework import sparse_tensor |
22 | 23 | from tensorflow.python.ops import array_ops |
@@ -225,6 +226,66 @@ def testEditDistanceBadIndices(self): |
225 | 226 | "to outside of the buffer for the output tensor|" |
226 | 227 | r"Dimension -\d+ must be >= 0")) |
227 | 228 |
|
| 229 | + def testEmptyShapeWithEditDistanceRaisesError(self): |
| 230 | + para = { |
| 231 | + "hypothesis_indices": [[]], |
| 232 | + "hypothesis_values": ["tmp/"], |
| 233 | + "hypothesis_shape": [], |
| 234 | + "truth_indices": [[]], |
| 235 | + "truth_values": [""], |
| 236 | + "truth_shape": [], |
| 237 | + "normalize": False, |
| 238 | + } |
| 239 | + |
| 240 | + # Check edit distance raw op with empty shape in eager mode. |
| 241 | + with self.assertRaisesRegex( |
| 242 | + (errors.InvalidArgumentError, ValueError), |
| 243 | + ( |
| 244 | + r"Input Hypothesis SparseTensors must have rank at least 2, but" |
| 245 | + " hypothesis_shape rank is: 0|Input SparseTensors must have rank " |
| 246 | + "at least 2, but truth_shape rank is: 0" |
| 247 | + ), |
| 248 | + ): |
| 249 | + array_ops.gen_array_ops.EditDistance(**para) |
| 250 | + |
| 251 | + # Check raw op with tf.function |
| 252 | + @def_function.function |
| 253 | + def TestFunction(): |
| 254 | + """Wrapper function for edit distance call.""" |
| 255 | + array_ops.gen_array_ops.EditDistance(**para) |
| 256 | + |
| 257 | + with self.assertRaisesRegex( |
| 258 | + ValueError, |
| 259 | + ( |
| 260 | + "Input Hypothesis SparseTensors must have rank at least 2, but" |
| 261 | + " hypothesis_shape rank is: 0" |
| 262 | + ), |
| 263 | + ): |
| 264 | + TestFunction() |
| 265 | + |
| 266 | + # Check with python wrapper API |
| 267 | + hypothesis_indices = [[]] |
| 268 | + hypothesis_values = [0] |
| 269 | + hypothesis_shape = [] |
| 270 | + truth_indices = [[]] |
| 271 | + truth_values = [1] |
| 272 | + truth_shape = [] |
| 273 | + expected_output = [] # dummy ignored |
| 274 | + |
| 275 | + with self.assertRaisesRegex( |
| 276 | + ValueError, |
| 277 | + ( |
| 278 | + "Input Hypothesis SparseTensors must have rank at least 2, but" |
| 279 | + " hypothesis_shape rank is: 0" |
| 280 | + ), |
| 281 | + ): |
| 282 | + self._testEditDistance( |
| 283 | + hypothesis=(hypothesis_indices, hypothesis_values, hypothesis_shape), |
| 284 | + truth=(truth_indices, truth_values, truth_shape), |
| 285 | + normalize=False, |
| 286 | + expected_output=expected_output, |
| 287 | + ) |
| 288 | + |
228 | 289 |
|
229 | 290 | if __name__ == "__main__": |
230 | 291 | test.main() |
0 commit comments