@@ -416,6 +416,65 @@ class LSTMBlockCellOp : public OpKernel {
416416
417417 const Device& device = ctx->eigen_device <Device>();
418418
419+ // Sanity check that each of the tensors have the required NDIMS.
420+ OP_REQUIRES (ctx, x_tensor->dims () == 2 ,
421+ errors::InvalidArgument (" x_tensor must be rank 2 but is rank " ,
422+ x_tensor->dims (), " ." ));
423+ OP_REQUIRES (
424+ ctx, cs_prev_tensor->dims () == 2 ,
425+ errors::InvalidArgument (" cs_prev_tensor must be rank 2 but is rank " ,
426+ cs_prev_tensor->dims (), " ." ));
427+ OP_REQUIRES (
428+ ctx, h_prev_tensor->dims () == 2 ,
429+ errors::InvalidArgument (" h_prev_tensor must be rank 2 but is rank " ,
430+ h_prev_tensor->dims (), " ." ));
431+ OP_REQUIRES (ctx, w_tensor->dims () == 2 ,
432+ errors::InvalidArgument (" w_tensor must be rank 2 but is rank " ,
433+ w_tensor->dims (), " ." ));
434+ OP_REQUIRES (
435+ ctx, wci_tensor->dims () == 1 ,
436+ errors::InvalidArgument (" wci_tensor must be rank 1 but is rank " ,
437+ wci_tensor->dims (), " ." ));
438+ OP_REQUIRES (
439+ ctx, wcf_tensor->dims () == 1 ,
440+ errors::InvalidArgument (" wcf_tensor must be rank 1 but is rank " ,
441+ wci_tensor->dims (), " ." ));
442+ OP_REQUIRES (
443+ ctx, wco_tensor->dims () == 1 ,
444+ errors::InvalidArgument (" wco_tensor must be rank 1 but is rank " ,
445+ wco_tensor->dims (), " ." ));
446+ OP_REQUIRES (ctx, b_tensor->dims () == 1 ,
447+ errors::InvalidArgument (" b_tensor must be rank 1 but is rank " ,
448+ b_tensor->dims (), " ." ));
449+ OP_REQUIRES (ctx, xh_tensor.dims () == 2 ,
450+ errors::InvalidArgument (" xh_tensor must be rank 2 but is rank " ,
451+ xh_tensor.dims (), " ." ));
452+ OP_REQUIRES (ctx, i_tensor->dims () == 2 ,
453+ errors::InvalidArgument (" i_tensor must be rank 2 but is rank " ,
454+ i_tensor->dims (), " ." ));
455+ OP_REQUIRES (ctx, cs_tensor->dims () == 2 ,
456+ errors::InvalidArgument (" cs_tensor must be rank 2 but is rank " ,
457+ cs_tensor->dims (), " ." ));
458+ OP_REQUIRES (ctx, f_tensor->dims () == 2 ,
459+ errors::InvalidArgument (" f_tensor must be rank 2 but is rank " ,
460+ f_tensor->dims (), " ." ));
461+ OP_REQUIRES (ctx, o_tensor->dims () == 2 ,
462+ errors::InvalidArgument (" o_tensor must be rank 2 but is rank " ,
463+ o_tensor->dims (), " ." ));
464+ OP_REQUIRES (ctx, ci_tensor->dims () == 2 ,
465+ errors::InvalidArgument (" ci_tensor must be rank 2 but is rank " ,
466+ ci_tensor->dims (), " ." ));
467+ OP_REQUIRES (ctx, co_tensor->dims () == 2 ,
468+ errors::InvalidArgument (" co_tensor must be rank 2 but is rank " ,
469+ co_tensor->dims (), " ." ));
470+ OP_REQUIRES (
471+ ctx, gates_tensor.dims () == 2 ,
472+ errors::InvalidArgument (" gates_tensor must be rank 2 but is rank " ,
473+ gates_tensor.dims (), " ." ));
474+ OP_REQUIRES (ctx, h_tensor->dims () == 2 ,
475+ errors::InvalidArgument (" h_tensor must be rank 2 but is rank " ,
476+ h_tensor->dims (), " ." ));
477+
419478 functor::LSTMBlockCellFprop<Device, T, USE_CUBLAS, gate_layout>(
420479 batch_size, input_size, cell_size)(
421480 ctx, device, forget_bias_, cell_clip_, use_peephole_,
0 commit comments