Skip to content

Commit 4e25654

Browse files
edlopertensorflower-gardener
authored andcommitted
Fix bug that could cause map_fn to produce incorrect results (rather than an error)
when mapping over a ragged tensor with an inappropriate fn_output_signature. (Note: there are cases where the default value for fn_output_signature is not appropriate, so the user needs to explicitly specify the correct output signature.) PiperOrigin-RevId: 387606546 Change-Id: Ib4ea27b9634e6ab413f211cfe809a69a90f0e2cd
1 parent 2c5a876 commit 4e25654

File tree

2 files changed

+39
-0
lines changed

2 files changed

+39
-0
lines changed

tensorflow/core/kernels/ragged_tensor_from_variant_op.cc

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,23 @@ Status NestedStackRaggedTensors(
174174
auto output_values_flat =
175175
output_ragged->mutable_values()->flat_outer_dims<VALUE_TYPE, 2>();
176176
int values_index = 0;
177+
178+
TensorShape expected_value_shape = component_values_shape;
179+
expected_value_shape.RemoveDim(0);
180+
177181
for (int i = 0; i < ragged_components.size(); i++) {
182+
// Check that the flat_values tensor shape is compatible.
183+
TensorShape value_shape = ragged_components[i].values().shape();
184+
value_shape.RemoveDim(0);
185+
if (value_shape != expected_value_shape) {
186+
return errors::InvalidArgument(
187+
"All flat_values must have compatible shapes. Shape at index 0: ",
188+
expected_value_shape, ". Shape at index ", i, ": ", value_shape,
189+
". If you are using tf.map_fn, then you may need to specify an "
190+
"explicit fn_output_signature with appropriate ragged_rank, and/or "
191+
"convert output tensors to RaggedTensors.");
192+
}
193+
178194
auto component_values_flat =
179195
ragged_components[i].values().flat_outer_dims<VALUE_TYPE, 2>();
180196
int num_inner_elements = ragged_components[i].values().NumElements();

tensorflow/python/ops/ragged/ragged_map_fn_op_test.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,11 @@
2121
import numpy as np
2222

2323
from tensorflow.python.framework import dtypes
24+
from tensorflow.python.framework import errors
2425
from tensorflow.python.framework import sparse_tensor
2526
from tensorflow.python.framework import test_util
2627
from tensorflow.python.ops import array_ops
28+
from tensorflow.python.ops import map_fn as map_fn_lib
2729
from tensorflow.python.ops import math_ops as mo
2830
from tensorflow.python.ops import string_ops
2931
from tensorflow.python.ops.ragged import ragged_factory_ops
@@ -309,6 +311,27 @@ def testMapOnSparseTensor(self):
309311
)
310312
self.assertAllEqual(id_t2, [[0, 5], [0, 4]])
311313

314+
def testRaggedMapWithIncorrectFnOutputSignature(self):
315+
x = ragged_factory_ops.constant([[1, 2, 3, 4], [1]])
316+
with self.assertRaisesRegex(errors.InvalidArgumentError,
317+
'All flat_values must have compatible shapes'):
318+
y = map_fn_lib.map_fn(lambda r: map_fn_lib.map_fn(lambda y: r, r), x)
319+
self.evaluate(y)
320+
321+
def testNestedRaggedMapWithFnOutputSignature(self):
322+
ragged1d = ragged_tensor.RaggedTensorSpec([None], dtypes.int32)
323+
ragged2d = ragged_tensor.RaggedTensorSpec([None, None], dtypes.int32)
324+
325+
x = ragged_factory_ops.constant([[1, 2, 3, 4], [1]])
326+
# pylint: disable=g-long-lambda
327+
y = map_fn_lib.map_fn(
328+
lambda r: map_fn_lib.map_fn(
329+
lambda y: r, r, fn_output_signature=ragged1d),
330+
x,
331+
fn_output_signature=ragged2d)
332+
expected = [[[1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4]], [[1]]]
333+
self.assertAllEqual(y, expected)
334+
312335

313336
if __name__ == '__main__':
314337
googletest.main()

0 commit comments

Comments
 (0)