Skip to content

Commit feae294

Browse files
ebrevdotensorflower-gardener
authored andcommitted
Add "output_layer" argument to the BasicDecoder.
Change: 147845195
1 parent 30651a9 commit feae294

2 files changed

Lines changed: 59 additions & 9 deletions

File tree

tensorflow/contrib/seq2seq/python/kernel_tests/basic_decoder_test.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,35 +35,44 @@
3535
from tensorflow.python.framework import constant_op
3636
from tensorflow.python.framework import dtypes
3737
from tensorflow.python.framework import tensor_shape
38+
from tensorflow.python.layers import core as layers_core
3839
from tensorflow.python.ops import variables
3940
from tensorflow.python.platform import test
4041
# pylint: enable=g-import-not-at-top
4142

4243

4344
class BasicDecoderTest(test.TestCase):
4445

45-
def testStepWithTrainingHelper(self):
46+
def _testStepWithTrainingHelper(self, use_output_layer):
4647
sequence_length = [3, 4, 3, 1, 0]
4748
batch_size = 5
4849
max_time = 8
4950
input_depth = 7
5051
cell_depth = 10
52+
output_layer_depth = 3
5153

5254
with self.test_session() as sess:
5355
inputs = np.random.randn(batch_size, max_time,
5456
input_depth).astype(np.float32)
5557
cell = core_rnn_cell.LSTMCell(cell_depth)
5658
helper = helper_py.TrainingHelper(
5759
inputs, sequence_length, time_major=False)
60+
if use_output_layer:
61+
output_layer = layers_core.Dense(output_layer_depth, use_bias=False)
62+
expected_output_depth = output_layer_depth
63+
else:
64+
output_layer = None
65+
expected_output_depth = cell_depth
5866
my_decoder = basic_decoder.BasicDecoder(
5967
cell=cell,
6068
helper=helper,
6169
initial_state=cell.zero_state(
62-
dtype=dtypes.float32, batch_size=batch_size))
70+
dtype=dtypes.float32, batch_size=batch_size),
71+
output_layer=output_layer)
6372
output_size = my_decoder.output_size
6473
output_dtype = my_decoder.output_dtype
6574
self.assertEqual(
66-
basic_decoder.BasicDecoderOutput(cell_depth,
75+
basic_decoder.BasicDecoderOutput(expected_output_depth,
6776
tensor_shape.TensorShape([])),
6877
output_size)
6978
self.assertEqual(
@@ -80,13 +89,18 @@ def testStepWithTrainingHelper(self):
8089
self.assertTrue(isinstance(step_state, core_rnn_cell.LSTMStateTuple))
8190
self.assertTrue(
8291
isinstance(step_outputs, basic_decoder.BasicDecoderOutput))
83-
self.assertEqual((batch_size, cell_depth), step_outputs[0].get_shape())
92+
self.assertEqual((batch_size, expected_output_depth),
93+
step_outputs[0].get_shape())
8494
self.assertEqual((batch_size,), step_outputs[1].get_shape())
8595
self.assertEqual((batch_size, cell_depth), first_state[0].get_shape())
8696
self.assertEqual((batch_size, cell_depth), first_state[1].get_shape())
8797
self.assertEqual((batch_size, cell_depth), step_state[0].get_shape())
8898
self.assertEqual((batch_size, cell_depth), step_state[1].get_shape())
8999

100+
if use_output_layer:
101+
# The output layer was accessed
102+
self.assertEqual(len(output_layer.variables), 1)
103+
90104
sess.run(variables.global_variables_initializer())
91105
sess_results = sess.run({
92106
"batch_size": batch_size_t,
@@ -107,6 +121,12 @@ def testStepWithTrainingHelper(self):
107121
np.argmax(sess_results["step_outputs"].rnn_output, -1),
108122
sess_results["step_outputs"].sample_id)
109123

124+
def testStepWithTrainingHelperNoOutputLayer(self):
125+
self._testStepWithTrainingHelper(use_output_layer=False)
126+
127+
def testStepWithTrainingHelperWithOutputLayer(self):
128+
self._testStepWithTrainingHelper(use_output_layer=True)
129+
110130
def testStepWithGreedyEmbeddingHelper(self):
111131
batch_size = 5
112132
vocabulary_size = 7

tensorflow/contrib/seq2seq/python/ops/basic_decoder.py

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from tensorflow.python.framework import dtypes
2828
from tensorflow.python.framework import ops
2929
from tensorflow.python.framework import tensor_shape
30+
from tensorflow.python.layers import base as layers_base
3031
from tensorflow.python.util import nest
3132

3233

@@ -44,35 +45,62 @@ class BasicDecoderOutput(
4445
class BasicDecoder(decoder.Decoder):
4546
"""Basic sampling decoder."""
4647

47-
def __init__(self, cell, helper, initial_state):
48+
def __init__(self, cell, helper, initial_state, output_layer=None):
4849
"""Initialize BasicDecoder.
4950
5051
Args:
5152
cell: An `RNNCell` instance.
5253
helper: A `Helper` instance.
5354
initial_state: A (possibly nested tuple of...) tensors and TensorArrays.
55+
output_layer: (Optional) An instance of `tf.layers.Layer`, i.e.,
56+
`tf.layers.Dense`. Optional layer to apply to the RNN output prior
57+
to storing the result or sampling.
5458
5559
Raises:
56-
TypeError: if `cell` is not an instance of `RNNCell` or `helper`
57-
is not an instance of `Helper`.
60+
TypeError: if `cell` is not an instance of `RNNCell`, `helper`
61+
is not an instance of `Helper`, or `output_layer` is not an instance
62+
of `tf.layers.Layer`.
5863
"""
5964
if not isinstance(cell, core_rnn_cell.RNNCell):
6065
raise TypeError("cell must be an RNNCell, received: %s" % type(cell))
6166
if not isinstance(helper, helper_py.Helper):
6267
raise TypeError("helper must be a Helper, received: %s" % type(helper))
68+
if (output_layer is not None
69+
and not isinstance(output_layer, layers_base._Layer)): # pylint: disable=protected-access
70+
raise TypeError(
71+
"output_layer must be a Layer, received: %s" % type(output_layer))
6372
self._cell = cell
6473
self._helper = helper
6574
self._initial_state = initial_state
75+
self._output_layer = output_layer
6676

6777
@property
6878
def batch_size(self):
6979
return self._helper.batch_size
7080

81+
def _rnn_output_size(self):
82+
size = self._cell.output_size
83+
if self._output_layer is None:
84+
return size
85+
else:
86+
# To use layer's compute_output_shape, we need to convert the
87+
# RNNCell's output_size entries into shapes with an unknown
88+
# batch size. We then pass this through the layer's
89+
# compute_output_shape and read off all but the first (batch)
90+
# dimensions to get the output size of the rnn with the layer
91+
# applied to the top.
92+
output_shape_with_unknown_batch = nest.map_structure(
93+
lambda s: tensor_shape.TensorShape([None]).concatenate(s),
94+
size)
95+
layer_output_shape = self._output_layer._compute_output_shape( # pylint: disable=protected-access
96+
output_shape_with_unknown_batch)
97+
return nest.map_structure(lambda s: s[1:], layer_output_shape)
98+
7199
@property
72100
def output_size(self):
73101
# Return the cell output and the id
74102
return BasicDecoderOutput(
75-
rnn_output=self._cell.output_size,
103+
rnn_output=self._rnn_output_size(),
76104
sample_id=tensor_shape.TensorShape([]))
77105

78106
@property
@@ -82,7 +110,7 @@ def output_dtype(self):
82110
# Return that structure and int32 (the id)
83111
dtype = nest.flatten(self._initial_state)[0].dtype
84112
return BasicDecoderOutput(
85-
nest.map_structure(lambda _: dtype, self._cell.output_size),
113+
nest.map_structure(lambda _: dtype, self._rnn_output_size()),
86114
dtypes.int32)
87115

88116
def initialize(self, name=None):
@@ -110,6 +138,8 @@ def step(self, time, inputs, state, name=None):
110138
"""
111139
with ops.name_scope(name, "BasicDecoderStep", (time, inputs, state)):
112140
cell_outputs, cell_state = self._cell(inputs, state)
141+
if self._output_layer is not None:
142+
cell_outputs = self._output_layer(cell_outputs)
113143
sample_ids = self._helper.sample(
114144
time=time, outputs=cell_outputs, state=cell_state)
115145
(finished, next_inputs, next_state) = self._helper.next_inputs(

0 commit comments

Comments
 (0)