2727from tensorflow .python .framework import dtypes
2828from tensorflow .python .framework import ops
2929from tensorflow .python .framework import tensor_shape
30+ from tensorflow .python .layers import base as layers_base
3031from tensorflow .python .util import nest
3132
3233
@@ -44,35 +45,62 @@ class BasicDecoderOutput(
4445class BasicDecoder (decoder .Decoder ):
4546 """Basic sampling decoder."""
4647
def __init__(self, cell, helper, initial_state, output_layer=None):
  """Build a `BasicDecoder` from a cell, a helper and a starting state.

  Args:
    cell: An `RNNCell` instance.
    helper: A `Helper` instance.
    initial_state: A (possibly nested tuple of...) tensors and TensorArrays.
    output_layer: (Optional) An instance of `tf.layers.Layer`, e.g.,
      `tf.layers.Dense`. Optional layer to apply to the RNN output prior
      to storing the result or sampling.

  Raises:
    TypeError: if `cell` is not an instance of `RNNCell`, `helper`
      is not an instance of `Helper`, or `output_layer` is neither `None`
      nor an instance of `tf.layers.Layer`.
  """
  # Arguments are validated in declaration order, so the first offending
  # argument is the one named in the error.
  cell_ok = isinstance(cell, core_rnn_cell.RNNCell)
  if not cell_ok:
    raise TypeError("cell must be an RNNCell, received: %s" % type(cell))

  helper_ok = isinstance(helper, helper_py.Helper)
  if not helper_ok:
    raise TypeError("helper must be a Helper, received: %s" % type(helper))

  # `None` means "no projection"; anything else has to be a layer.
  layer_ok = (
      output_layer is None
      or isinstance(output_layer, layers_base._Layer))  # pylint: disable=protected-access
  if not layer_ok:
    raise TypeError(
        "output_layer must be a Layer, received: %s" % type(output_layer))

  self._cell = cell
  self._helper = helper
  self._initial_state = initial_state
  self._output_layer = output_layer
6676
@property
def batch_size(self):
  """Batch size of this decoder, read from the helper it wraps."""
  helper = self._helper
  return helper.batch_size
7080
def _rnn_output_size(self):
  """Return the RNN output size, taking the optional output layer into account.

  Returns:
    `self._cell.output_size` unchanged when no output layer is set.
    Otherwise, the structure of shapes that the output layer's
    `_compute_output_shape` reports for the cell output, with the leading
    (batch) dimension removed from each entry.
  """
  cell_output_size = self._cell.output_size
  projection = self._output_layer
  if projection is None:
    return cell_output_size

  def _with_unknown_batch(entry):
    # The layer's shape computation works on full shapes, so prepend an
    # unknown batch dimension to every entry of the cell's output_size.
    return tensor_shape.TensorShape([None]).concatenate(entry)

  def _without_batch(shape):
    # Drop the first (batch) dimension again to get back a per-step size.
    return shape[1:]

  batched_shapes = nest.map_structure(_with_unknown_batch, cell_output_size)
  projected_shapes = projection._compute_output_shape(  # pylint: disable=protected-access
      batched_shapes)
  return nest.map_structure(_without_batch, projected_shapes)
98+
@property
def output_size(self):
  """Sizes of the decoder output: the RNN output and the sample id."""
  # The RNN part comes from the cell (after the optional output layer);
  # the sample id is given an empty, i.e. scalar, shape.
  rnn_size = self._rnn_output_size()
  id_shape = tensor_shape.TensorShape([])
  return BasicDecoderOutput(rnn_output=rnn_size, sample_id=id_shape)
77105
78106 @property
@@ -82,7 +110,7 @@ def output_dtype(self):
82110 # Return that structure and int32 (the id)
83111 dtype = nest .flatten (self ._initial_state )[0 ].dtype
84112 return BasicDecoderOutput (
85- nest .map_structure (lambda _ : dtype , self ._cell . output_size ),
113+ nest .map_structure (lambda _ : dtype , self ._rnn_output_size () ),
86114 dtypes .int32 )
87115
88116 def initialize (self , name = None ):
@@ -110,6 +138,8 @@ def step(self, time, inputs, state, name=None):
110138 """
111139 with ops .name_scope (name , "BasicDecoderStep" , (time , inputs , state )):
112140 cell_outputs , cell_state = self ._cell (inputs , state )
141+ if self ._output_layer is not None :
142+ cell_outputs = self ._output_layer (cell_outputs )
113143 sample_ids = self ._helper .sample (
114144 time = time , outputs = cell_outputs , state = cell_state )
115145 (finished , next_inputs , next_state ) = self ._helper .next_inputs (
0 commit comments