Skip to content

Commit 58b34c6

Browse files
isharktensorflower-gardener
authored andcommitted
Fix integer overflow leading to divide by zero error in Unravel index kernel when dimensions product exceeds max int value.
PiperOrigin-RevId: 413250052 Change-Id: I9450b6e8acecd2e881a64b882e2b7c70e8e9289a
1 parent 4d00cd5 commit 58b34c6

File tree

2 files changed

+40
-1
lines changed

2 files changed

+40
-1
lines changed

tensorflow/core/kernels/unravel_index_op.cc

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@ See the License for the specific language governing permissions and
1313
limitations under the License.
1414
==============================================================================*/
1515

16+
#include <cstdint>
17+
18+
#include "tensorflow/core/framework/types.pb.h"
19+
#include "tensorflow/core/platform/types.h"
1620
#define EIGEN_USE_THREADS
1721

1822
#include "tensorflow/core/framework/op_kernel.h"
@@ -35,7 +39,8 @@ typedef Eigen::ThreadPoolDevice CPUDevice;
3539
template <typename Tidx>
3640
class UnravelIndexOp : public OpKernel {
3741
public:
38-
explicit UnravelIndexOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
42+
explicit UnravelIndexOp(OpKernelConstruction* ctx)
43+
: OpKernel(ctx), dtidx_(DataTypeToEnum<Tidx>::v()) {}
3944

4045
void Compute(OpKernelContext* ctx) override {
4146
const Tensor& indices_tensor = ctx->input(0);
@@ -54,12 +59,31 @@ class UnravelIndexOp : public OpKernel {
5459

5560
auto dims = dims_tensor.vec<Tidx>();
5661
// Make sure dims does not contain a zero
62+
double prod = 1;
63+
uint64_t limit;
64+
if (dtidx_ == DataType::DT_INT64) {
65+
limit = kint64max;
66+
} else {
67+
limit = kint32max;
68+
}
69+
5770
for (int i = 0; i < dims.size(); i++) {
5871
OP_REQUIRES(
5972
ctx, dims(i) != 0,
6073
errors::InvalidArgument("Input dims cannot contain a dim of zero, "
6174
"but dims contains zero at index ",
6275
i));
76+
OP_REQUIRES(ctx, dims(i) > 0,
77+
errors::InvalidArgument(
78+
"Input dims cannot be negative. Got dim = ", dims(i),
79+
" at index ", i));
80+
// Check interger overflow
81+
OP_REQUIRES(
82+
ctx, prod <= limit / dims(i),
83+
errors::InvalidArgument("Input dims product is causing integer "
84+
"overflow: (",
85+
dims, ")"));
86+
prod = (prod * dims(i));
6387
}
6488

6589
// Check to make sure indices is not out of boundary
@@ -132,6 +156,7 @@ class UnravelIndexOp : public OpKernel {
132156
strides_shifted.reshape(reshape).broadcast(bcast);
133157
}
134158
}
159+
const DataType dtidx_;
135160
};
136161

137162
#define REGISTER_KERNEL(type) \

tensorflow/python/kernel_tests/array_ops/array_ops_test.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1580,6 +1580,20 @@ def testUnravelIndexZeroDim(self):
15801580
dims = constant_op.constant([3, 0], dtype=dtype)
15811581
self.evaluate(array_ops.unravel_index(indices=indices, dims=dims))
15821582

1583+
def testUnravelIndexIntegerOverflow(self):
1584+
with self.cached_session():
1585+
for dtype in [dtypes.int32, dtypes.int64]:
1586+
with self.assertRaisesRegex(
1587+
errors.InvalidArgumentError,
1588+
r"Input dims product is causing integer overflow"):
1589+
indices = constant_op.constant(-0x100000, dtype=dtype)
1590+
if dtype == dtypes.int32:
1591+
value = 0x10000000
1592+
else:
1593+
value = 0x7FFFFFFFFFFFFFFF
1594+
dims = constant_op.constant([value, value], dtype=dtype)
1595+
self.evaluate(array_ops.unravel_index(indices=indices, dims=dims))
1596+
15831597

15841598
class GuaranteeConstOpTest(test_util.TensorFlowTestCase):
15851599

0 commit comments

Comments
 (0)