Skip to content

Commit 60ed7ce

Browse files
pak-lauratensorflow-jenkins
authored andcommitted
Re-enable testTensorListReserveWithNonScalarNumElements to work with mlir as well.
PiperOrigin-RevId: 466460987
1 parent 23cb0d3 commit 60ed7ce

File tree

2 files changed

+17
-0
lines changed

2 files changed

+17
-0
lines changed

tensorflow/core/kernels/list_kernels.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,11 @@ limitations under the License.
3131
#include "tensorflow/core/framework/allocator.h"
3232
#include "tensorflow/core/framework/op_kernel.h"
3333
#include "tensorflow/core/framework/register_types.h"
34+
#include "tensorflow/core/framework/tensor_shape.h"
3435
#include "tensorflow/core/framework/tensor_types.h"
3536
#include "tensorflow/core/framework/variant.h"
3637
#include "tensorflow/core/framework/variant_op_registry.h"
38+
#include "tensorflow/core/platform/errors.h"
3739

3840
namespace tensorflow {
3941

@@ -322,6 +324,11 @@ class TensorListReserve : public OpKernel {
322324
void Compute(OpKernelContext* c) override {
323325
PartialTensorShape element_shape;
324326
OP_REQUIRES_OK(c, TensorShapeFromTensor(c->input(0), &element_shape));
327+
OP_REQUIRES(
328+
c, TensorShapeUtils::IsScalar(c->input(1).shape()),
329+
errors::InvalidArgument(
330+
"The num_elements to reserve must be a tensor size 1, but got ",
331+
c->input(1).shape()));
325332
int32_t num_elements = c->input(1).scalar<int32>()();
326333
OP_REQUIRES(c, num_elements >= 0,
327334
errors::InvalidArgument("The num_elements to reserve must be a "

tensorflow/python/kernel_tests/data_structures/list_ops_test.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,16 @@ def testPopFromEmptyTensorListFails(self, max_num_elements):
9494
l = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
9595
self.evaluate(l)
9696

97+
def testTensorListReserveWithNonScalarNumElements(self):
98+
# list_kernels.cc in tf/core/kernels raises InvalidArgumentError, and
99+
# tf_ops_n_z.cc in tf/compiler/mlir/tf/ir raises UnknownError.
100+
with self.assertRaises((errors.InvalidArgumentError, errors.UnknownError)):
101+
l = list_ops.tensor_list_reserve(
102+
element_dtype=dtypes.float32,
103+
element_shape=[2, 3],
104+
num_elements=constant_op.constant([1, 1]))
105+
self.evaluate(l)
106+
97107
def testPopUninitializedTensorUseListElementShape(self):
98108
l = list_ops.tensor_list_reserve(
99109
element_dtype=dtypes.float32, element_shape=[2, 3], num_elements=3)

0 commit comments

Comments
 (0)