Skip to content

Commit 6ed2bd0

Browse files
sagunbtensorflow-jenkins
authored andcommitted
Fix security vulnerability with LSTMBlockCellOp
PiperOrigin-RevId: 446028341
1 parent 2667271 commit 6ed2bd0

File tree

2 files changed

+90
-0
lines changed

2 files changed

+90
-0
lines changed

tensorflow/core/kernels/rnn/lstm_ops.cc

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -416,6 +416,65 @@ class LSTMBlockCellOp : public OpKernel {
416416

417417
const Device& device = ctx->eigen_device<Device>();
418418

419+
// Sanity check that each of the tensors have the required NDIMS.
420+
OP_REQUIRES(ctx, x_tensor->dims() == 2,
421+
errors::InvalidArgument("x_tensor must be rank 2 but is rank ",
422+
x_tensor->dims(), "."));
423+
OP_REQUIRES(
424+
ctx, cs_prev_tensor->dims() == 2,
425+
errors::InvalidArgument("cs_prev_tensor must be rank 2 but is rank ",
426+
cs_prev_tensor->dims(), "."));
427+
OP_REQUIRES(
428+
ctx, h_prev_tensor->dims() == 2,
429+
errors::InvalidArgument("h_prev_tensor must be rank 2 but is rank ",
430+
h_prev_tensor->dims(), "."));
431+
OP_REQUIRES(ctx, w_tensor->dims() == 2,
432+
errors::InvalidArgument("w_tensor must be rank 2 but is rank ",
433+
w_tensor->dims(), "."));
434+
OP_REQUIRES(
435+
ctx, wci_tensor->dims() == 1,
436+
errors::InvalidArgument("wci_tensor must be rank 1 but is rank ",
437+
wci_tensor->dims(), "."));
438+
OP_REQUIRES(
439+
ctx, wcf_tensor->dims() == 1,
440+
errors::InvalidArgument("wcf_tensor must be rank 1 but is rank ",
441+
wci_tensor->dims(), "."));
442+
OP_REQUIRES(
443+
ctx, wco_tensor->dims() == 1,
444+
errors::InvalidArgument("wco_tensor must be rank 1 but is rank ",
445+
wco_tensor->dims(), "."));
446+
OP_REQUIRES(ctx, b_tensor->dims() == 1,
447+
errors::InvalidArgument("b_tensor must be rank 1 but is rank ",
448+
b_tensor->dims(), "."));
449+
OP_REQUIRES(ctx, xh_tensor.dims() == 2,
450+
errors::InvalidArgument("xh_tensor must be rank 2 but is rank ",
451+
xh_tensor.dims(), "."));
452+
OP_REQUIRES(ctx, i_tensor->dims() == 2,
453+
errors::InvalidArgument("i_tensor must be rank 2 but is rank ",
454+
i_tensor->dims(), "."));
455+
OP_REQUIRES(ctx, cs_tensor->dims() == 2,
456+
errors::InvalidArgument("cs_tensor must be rank 2 but is rank ",
457+
cs_tensor->dims(), "."));
458+
OP_REQUIRES(ctx, f_tensor->dims() == 2,
459+
errors::InvalidArgument("f_tensor must be rank 2 but is rank ",
460+
f_tensor->dims(), "."));
461+
OP_REQUIRES(ctx, o_tensor->dims() == 2,
462+
errors::InvalidArgument("o_tensor must be rank 2 but is rank ",
463+
o_tensor->dims(), "."));
464+
OP_REQUIRES(ctx, ci_tensor->dims() == 2,
465+
errors::InvalidArgument("ci_tensor must be rank 2 but is rank ",
466+
ci_tensor->dims(), "."));
467+
OP_REQUIRES(ctx, co_tensor->dims() == 2,
468+
errors::InvalidArgument("co_tensor must be rank 2 but is rank ",
469+
co_tensor->dims(), "."));
470+
OP_REQUIRES(
471+
ctx, gates_tensor.dims() == 2,
472+
errors::InvalidArgument("gates_tensor must be rank 2 but is rank ",
473+
gates_tensor.dims(), "."));
474+
OP_REQUIRES(ctx, h_tensor->dims() == 2,
475+
errors::InvalidArgument("h_tensor must be rank 2 but is rank ",
476+
h_tensor->dims(), "."));
477+
419478
functor::LSTMBlockCellFprop<Device, T, USE_CUBLAS, gate_layout>(
420479
batch_size, input_size, cell_size)(
421480
ctx, device, forget_bias_, cell_clip_, use_peephole_,

tensorflow/python/kernel_tests/nn_ops/rnn_cell_test.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from tensorflow.python.framework import test_util
3434
from tensorflow.python.ops import array_ops
3535
from tensorflow.python.ops import control_flow_ops
36+
from tensorflow.python.ops import gen_rnn_ops
3637
from tensorflow.python.ops import gradients_impl
3738
from tensorflow.python.ops import init_ops
3839
from tensorflow.python.ops import math_ops
@@ -1323,6 +1324,36 @@ def testDynamicEquivalentToStaticRNN(self):
13231324
def testDynamicEquivalentToStaticRNNWithSequenceLength(self):
13241325
self._testDynamicEquivalentToStaticRNN(use_sequence_length=True)
13251326

1327+
@test_util.run_in_graph_and_eager_modes
1328+
def testLSTMBlockCellErrorHandling(self):
1329+
forget_bias = 1
1330+
cell_clip = 0
1331+
use_peephole = False
1332+
x = constant_op.constant(0.837607, shape=[28, 29], dtype=dtypes.float32)
1333+
cs_prev = constant_op.constant(0, shape=[28, 17], dtype=dtypes.float32)
1334+
h_prev = constant_op.constant(
1335+
0.592631638, shape=[28, 17], dtype=dtypes.float32)
1336+
w = constant_op.constant(0.887386262, shape=[46, 68], dtype=dtypes.float32)
1337+
wci = constant_op.constant(0, shape=[], dtype=dtypes.float32)
1338+
wcf = constant_op.constant(0, shape=[17], dtype=dtypes.float32)
1339+
wco = constant_op.constant(
1340+
0.592631638, shape=[28, 17], dtype=dtypes.float32)
1341+
b = constant_op.constant(0.75259006, shape=[68], dtype=dtypes.float32)
1342+
with self.assertRaises(errors_impl.InvalidArgumentError):
1343+
self.evaluate(
1344+
gen_rnn_ops.lstm_block_cell(
1345+
x=x,
1346+
cs_prev=cs_prev,
1347+
h_prev=h_prev,
1348+
w=w,
1349+
wci=wci,
1350+
wcf=wcf,
1351+
wco=wco,
1352+
b=b,
1353+
forget_bias=forget_bias,
1354+
cell_clip=cell_clip,
1355+
use_peephole=use_peephole))
1356+
13261357

13271358
class BidirectionalRNNTest(test.TestCase):
13281359

0 commit comments

Comments
 (0)