Skip to content

Commit fea55e1

Browse files
ebrevdo authored and Vijay Vasudevan committed
Breaking change in TF RNN Python API: Return the final state instead of the
list of states when calling tf.nn.rnn() and tf.nn.state_saving_rnn(). This is necessary for further cleanup of RNN state propagation code (currently, dynamic RNN calculations when passing sequence_length do not return the proper final state; this change is needed to make that fix efficient). Change: 113203893
1 parent e594939 commit fea55e1

6 files changed

Lines changed: 113 additions & 139 deletions

File tree

RELEASE.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@
3131
* ASSERT_OK / EXPECT_OK macros conflicted with external projects, so they were
3232
renamed TF_ASSERT_OK, TF_EXPECT_OK. The existing macros are currently
3333
maintained for short-term compatibility but will be removed.
34+
* The non-public `nn.rnn` and the various `nn.seq2seq` methods now return
35+
just the final state instead of the list of all states.
36+
3437

3538
## Bug fixes
3639

tensorflow/models/rnn/ptb/ptb_word_lm.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -117,16 +117,14 @@ def __init__(self, is_training, config):
117117
# from tensorflow.models.rnn import rnn
118118
# inputs = [tf.squeeze(input_, [1])
119119
# for input_ in tf.split(1, num_steps, inputs)]
120-
# outputs, states = rnn.rnn(cell, inputs, initial_state=self._initial_state)
120+
# outputs, state = rnn.rnn(cell, inputs, initial_state=self._initial_state)
121121
outputs = []
122-
states = []
123122
state = self._initial_state
124123
with tf.variable_scope("RNN"):
125124
for time_step in range(num_steps):
126125
if time_step > 0: tf.get_variable_scope().reuse_variables()
127126
(cell_output, state) = cell(inputs[:, time_step, :], state)
128127
outputs.append(cell_output)
129-
states.append(state)
130128

131129
output = tf.reshape(tf.concat(1, outputs), [-1, size])
132130
softmax_w = tf.get_variable("softmax_w", [size, vocab_size])
@@ -137,7 +135,7 @@ def __init__(self, is_training, config):
137135
[tf.ones([batch_size * num_steps])],
138136
vocab_size)
139137
self._cost = cost = tf.reduce_sum(loss) / batch_size
140-
self._final_state = states[-1]
138+
self._final_state = state
141139

142140
if not is_training:
143141
return

tensorflow/python/kernel_tests/rnn_test.py

Lines changed: 19 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -68,15 +68,15 @@ def testRNN(self):
6868
max_length = 8 # unrolled up to this length
6969
inputs = max_length * [
7070
tf.placeholder(tf.float32, shape=(batch_size, input_size))]
71-
outputs, states = tf.nn.rnn(cell, inputs, dtype=tf.float32)
71+
outputs, state = tf.nn.rnn(cell, inputs, dtype=tf.float32)
7272
self.assertEqual(len(outputs), len(inputs))
7373
for out, inp in zip(outputs, inputs):
7474
self.assertEqual(out.get_shape(), inp.get_shape())
7575
self.assertEqual(out.dtype, inp.dtype)
7676

7777
with self.test_session(use_gpu=False) as sess:
7878
input_value = np.random.randn(batch_size, input_size)
79-
values = sess.run(outputs + [states[-1]],
79+
values = sess.run(outputs + [state],
8080
feed_dict={inputs[0]: input_value})
8181

8282
# Outputs
@@ -98,7 +98,7 @@ def testDropout(self):
9898
inputs = max_length * [
9999
tf.placeholder(tf.float32, shape=(batch_size, input_size))]
100100
with tf.variable_scope("share_scope"):
101-
outputs, states = tf.nn.rnn(cell, inputs, dtype=tf.float32)
101+
outputs, state = tf.nn.rnn(cell, inputs, dtype=tf.float32)
102102
with tf.variable_scope("drop_scope"):
103103
dropped_outputs, _ = tf.nn.rnn(
104104
full_dropout_cell, inputs, dtype=tf.float32)
@@ -109,7 +109,7 @@ def testDropout(self):
109109

110110
with self.test_session(use_gpu=False) as sess:
111111
input_value = np.random.randn(batch_size, input_size)
112-
values = sess.run(outputs + [states[-1]],
112+
values = sess.run(outputs + [state],
113113
feed_dict={inputs[0]: input_value})
114114
full_dropout_values = sess.run(dropped_outputs,
115115
feed_dict={inputs[0]: input_value})
@@ -128,31 +128,29 @@ def testDynamicCalculation(self):
128128
inputs = max_length * [
129129
tf.placeholder(tf.float32, shape=(batch_size, input_size))]
130130
with tf.variable_scope("drop_scope"):
131-
dynamic_outputs, dynamic_states = tf.nn.rnn(
131+
dynamic_outputs, dynamic_state = tf.nn.rnn(
132132
cell, inputs, sequence_length=sequence_length, dtype=tf.float32)
133133
self.assertEqual(len(dynamic_outputs), len(inputs))
134-
self.assertEqual(len(dynamic_states), len(inputs))
135134

136135
with self.test_session(use_gpu=False) as sess:
137136
input_value = np.random.randn(batch_size, input_size)
138137
dynamic_values = sess.run(dynamic_outputs,
139138
feed_dict={inputs[0]: input_value,
140139
sequence_length: [2, 3]})
141-
dynamic_state_values = sess.run(dynamic_states,
140+
dynamic_state_values = sess.run([dynamic_state],
142141
feed_dict={inputs[0]: input_value,
143142
sequence_length: [2, 3]})
144143

145144
# fully calculated for t = 0, 1, 2
146145
for v in dynamic_values[:3]:
147146
self.assertAllClose(v, input_value + 1.0)
148-
for vi, v in enumerate(dynamic_state_values[:3]):
149-
self.assertAllEqual(v, 1.0 * (vi + 1) *
150-
np.ones((batch_size, input_size)))
151147
# zeros for t = 3+
152148
for v in dynamic_values[3:]:
153149
self.assertAllEqual(v, np.zeros_like(input_value))
154-
for v in dynamic_state_values[3:]:
155-
self.assertAllEqual(v, np.zeros_like(input_value))
150+
# final state is frozen from the state at time index max(sequence_length) - 1 == 2
151+
self.assertAllEqual(
152+
dynamic_state_values[0],
153+
1.0 * (2 + 1) * np.ones((batch_size, input_size)))
156154

157155

158156
class LSTMTest(tf.test.TestCase):
@@ -219,7 +217,7 @@ def _testNoProjNoShardingSimpleStateSaver(self, use_gpu):
219217
inputs = max_length * [
220218
tf.placeholder(tf.float32, shape=(batch_size, input_size))]
221219
with tf.variable_scope("share_scope"):
222-
outputs, states = tf.nn.state_saving_rnn(
220+
outputs, state = tf.nn.state_saving_rnn(
223221
cell, inputs, state_saver=state_saver, state_name="save_lstm")
224222
self.assertEqual(len(outputs), len(inputs))
225223
for out in outputs:
@@ -228,7 +226,7 @@ def _testNoProjNoShardingSimpleStateSaver(self, use_gpu):
228226
tf.initialize_all_variables().run()
229227
input_value = np.random.randn(batch_size, input_size)
230228
(last_state_value, saved_state_value) = sess.run(
231-
[states[-1], state_saver.saved_state],
229+
[state, state_saver.saved_state],
232230
feed_dict={inputs[0]: input_value})
233231
self.assertAllEqual(last_state_value, saved_state_value)
234232

@@ -340,10 +338,10 @@ def _testShardNoShardEquivalentOutput(self, use_gpu):
340338
initializer=initializer, num_proj=num_proj)
341339

342340
with tf.variable_scope("noshard_scope"):
343-
outputs_noshard, states_noshard = tf.nn.rnn(
341+
outputs_noshard, state_noshard = tf.nn.rnn(
344342
cell_noshard, inputs, dtype=tf.float32)
345343
with tf.variable_scope("shard_scope"):
346-
outputs_shard, states_shard = tf.nn.rnn(
344+
outputs_shard, state_shard = tf.nn.rnn(
347345
cell_shard, inputs, dtype=tf.float32)
348346

349347
self.assertEqual(len(outputs_noshard), len(inputs))
@@ -354,8 +352,8 @@ def _testShardNoShardEquivalentOutput(self, use_gpu):
354352
feeds = dict((x, input_value) for x in inputs)
355353
values_noshard = sess.run(outputs_noshard, feed_dict=feeds)
356354
values_shard = sess.run(outputs_shard, feed_dict=feeds)
357-
state_values_noshard = sess.run(states_noshard, feed_dict=feeds)
358-
state_values_shard = sess.run(states_shard, feed_dict=feeds)
355+
state_values_noshard = sess.run([state_noshard], feed_dict=feeds)
356+
state_values_shard = sess.run([state_shard], feed_dict=feeds)
359357
self.assertEqual(len(values_noshard), len(values_shard))
360358
self.assertEqual(len(state_values_noshard), len(state_values_shard))
361359
for (v_noshard, v_shard) in zip(values_noshard, values_shard):
@@ -389,22 +387,21 @@ def _testDoubleInputWithDropoutAndDynamicCalculation(
389387
initializer=initializer)
390388
dropout_cell = tf.nn.rnn_cell.DropoutWrapper(cell, 0.5, seed=0)
391389

392-
outputs, states = tf.nn.rnn(
390+
outputs, state = tf.nn.rnn(
393391
dropout_cell, inputs, sequence_length=sequence_length,
394392
initial_state=cell.zero_state(batch_size, tf.float64))
395393

396394
self.assertEqual(len(outputs), len(inputs))
397-
self.assertEqual(len(outputs), len(states))
398395

399396
tf.initialize_all_variables().run(feed_dict={sequence_length: [2, 3]})
400397
input_value = np.asarray(np.random.randn(batch_size, input_size),
401398
dtype=np.float64)
402399
values = sess.run(outputs, feed_dict={inputs[0]: input_value,
403400
sequence_length: [2, 3]})
404-
state_values = sess.run(states, feed_dict={inputs[0]: input_value,
401+
state_value = sess.run([state], feed_dict={inputs[0]: input_value,
405402
sequence_length: [2, 3]})
406403
self.assertEqual(values[0].dtype, input_value.dtype)
407-
self.assertEqual(state_values[0].dtype, input_value.dtype)
404+
self.assertEqual(state_value[0].dtype, input_value.dtype)
408405

409406
def testSharingWeightsWithReuse(self):
410407
num_units = 3

tensorflow/python/kernel_tests/seq2seq_test.py

Lines changed: 22 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -35,19 +35,18 @@ def testRNNDecoder(self):
3535
with self.test_session() as sess:
3636
with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
3737
inp = [tf.constant(0.5, shape=[2, 2]) for _ in xrange(2)]
38-
_, enc_states = tf.nn.rnn(
38+
_, enc_state = tf.nn.rnn(
3939
tf.nn.rnn_cell.GRUCell(2), inp, dtype=tf.float32)
4040
dec_inp = [tf.constant(0.4, shape=[2, 2]) for _ in xrange(3)]
4141
cell = tf.nn.rnn_cell.OutputProjectionWrapper(
4242
tf.nn.rnn_cell.GRUCell(2), 4)
43-
dec, mem = tf.nn.seq2seq.rnn_decoder(dec_inp, enc_states[-1], cell)
43+
dec, mem = tf.nn.seq2seq.rnn_decoder(dec_inp, enc_state, cell)
4444
sess.run([tf.initialize_all_variables()])
4545
res = sess.run(dec)
4646
self.assertEqual(len(res), 3)
4747
self.assertEqual(res[0].shape, (2, 4))
4848

49-
res = sess.run(mem)
50-
self.assertEqual(len(res), 4)
49+
res = sess.run([mem])
5150
self.assertEqual(res[0].shape, (2, 2))
5251

5352
def testBasicRNNSeq2Seq(self):
@@ -63,8 +62,7 @@ def testBasicRNNSeq2Seq(self):
6362
self.assertEqual(len(res), 3)
6463
self.assertEqual(res[0].shape, (2, 4))
6564

66-
res = sess.run(mem)
67-
self.assertEqual(len(res), 4)
65+
res = sess.run([mem])
6866
self.assertEqual(res[0].shape, (2, 2))
6967

7068
def testTiedRNNSeq2Seq(self):
@@ -80,26 +78,26 @@ def testTiedRNNSeq2Seq(self):
8078
self.assertEqual(len(res), 3)
8179
self.assertEqual(res[0].shape, (2, 4))
8280

83-
res = sess.run(mem)
84-
self.assertEqual(len(res), 4)
81+
res = sess.run([mem])
82+
self.assertEqual(len(res), 1)
8583
self.assertEqual(res[0].shape, (2, 2))
8684

8785
def testEmbeddingRNNDecoder(self):
8886
with self.test_session() as sess:
8987
with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
9088
inp = [tf.constant(0.5, shape=[2, 2]) for _ in xrange(2)]
9189
cell = tf.nn.rnn_cell.BasicLSTMCell(2)
92-
_, enc_states = tf.nn.rnn(cell, inp, dtype=tf.float32)
90+
_, enc_state = tf.nn.rnn(cell, inp, dtype=tf.float32)
9391
dec_inp = [tf.constant(i, tf.int32, shape=[2]) for i in xrange(3)]
94-
dec, mem = tf.nn.seq2seq.embedding_rnn_decoder(dec_inp, enc_states[-1],
92+
dec, mem = tf.nn.seq2seq.embedding_rnn_decoder(dec_inp, enc_state,
9593
cell, 4)
9694
sess.run([tf.initialize_all_variables()])
9795
res = sess.run(dec)
9896
self.assertEqual(len(res), 3)
9997
self.assertEqual(res[0].shape, (2, 2))
10098

101-
res = sess.run(mem)
102-
self.assertEqual(len(res), 4)
99+
res = sess.run([mem])
100+
self.assertEqual(len(res), 1)
103101
self.assertEqual(res[0].shape, (2, 4))
104102

105103
def testEmbeddingRNNSeq2Seq(self):
@@ -115,8 +113,7 @@ def testEmbeddingRNNSeq2Seq(self):
115113
self.assertEqual(len(res), 3)
116114
self.assertEqual(res[0].shape, (2, 5))
117115

118-
res = sess.run(mem)
119-
self.assertEqual(len(res), 4)
116+
res = sess.run([mem])
120117
self.assertEqual(res[0].shape, (2, 4))
121118

122119
# Test externally provided output projection.
@@ -161,8 +158,7 @@ def testEmbeddingTiedRNNSeq2Seq(self):
161158
self.assertEqual(len(res), 3)
162159
self.assertEqual(res[0].shape, (2, 5))
163160

164-
res = sess.run(mem)
165-
self.assertEqual(len(res), 4)
161+
res = sess.run([mem])
166162
self.assertEqual(res[0].shape, (2, 4))
167163

168164
# Test externally provided output projection.
@@ -198,64 +194,61 @@ def testAttentionDecoder1(self):
198194
with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
199195
cell = tf.nn.rnn_cell.GRUCell(2)
200196
inp = [tf.constant(0.5, shape=[2, 2]) for _ in xrange(2)]
201-
enc_outputs, enc_states = tf.nn.rnn(cell, inp, dtype=tf.float32)
197+
enc_outputs, enc_state = tf.nn.rnn(cell, inp, dtype=tf.float32)
202198
attn_states = tf.concat(1, [tf.reshape(e, [-1, 1, cell.output_size])
203199
for e in enc_outputs])
204200
dec_inp = [tf.constant(0.4, shape=[2, 2]) for _ in xrange(3)]
205201
dec, mem = tf.nn.seq2seq.attention_decoder(
206-
dec_inp, enc_states[-1],
202+
dec_inp, enc_state,
207203
attn_states, cell, output_size=4)
208204
sess.run([tf.initialize_all_variables()])
209205
res = sess.run(dec)
210206
self.assertEqual(len(res), 3)
211207
self.assertEqual(res[0].shape, (2, 4))
212208

213-
res = sess.run(mem)
214-
self.assertEqual(len(res), 4)
209+
res = sess.run([mem])
215210
self.assertEqual(res[0].shape, (2, 2))
216211

217212
def testAttentionDecoder2(self):
218213
with self.test_session() as sess:
219214
with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
220215
cell = tf.nn.rnn_cell.GRUCell(2)
221216
inp = [tf.constant(0.5, shape=[2, 2]) for _ in xrange(2)]
222-
enc_outputs, enc_states = tf.nn.rnn(cell, inp, dtype=tf.float32)
217+
enc_outputs, enc_state = tf.nn.rnn(cell, inp, dtype=tf.float32)
223218
attn_states = tf.concat(1, [tf.reshape(e, [-1, 1, cell.output_size])
224219
for e in enc_outputs])
225220
dec_inp = [tf.constant(0.4, shape=[2, 2]) for _ in xrange(3)]
226221
dec, mem = tf.nn.seq2seq.attention_decoder(
227-
dec_inp, enc_states[-1],
222+
dec_inp, enc_state,
228223
attn_states, cell, output_size=4,
229224
num_heads=2)
230225
sess.run([tf.initialize_all_variables()])
231226
res = sess.run(dec)
232227
self.assertEqual(len(res), 3)
233228
self.assertEqual(res[0].shape, (2, 4))
234229

235-
res = sess.run(mem)
236-
self.assertEqual(len(res), 4)
230+
res = sess.run([mem])
237231
self.assertEqual(res[0].shape, (2, 2))
238232

239233
def testEmbeddingAttentionDecoder(self):
240234
with self.test_session() as sess:
241235
with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
242236
inp = [tf.constant(0.5, shape=[2, 2]) for _ in xrange(2)]
243237
cell = tf.nn.rnn_cell.GRUCell(2)
244-
enc_outputs, enc_states = tf.nn.rnn(cell, inp, dtype=tf.float32)
238+
enc_outputs, enc_state = tf.nn.rnn(cell, inp, dtype=tf.float32)
245239
attn_states = tf.concat(1, [tf.reshape(e, [-1, 1, cell.output_size])
246240
for e in enc_outputs])
247241
dec_inp = [tf.constant(i, tf.int32, shape=[2]) for i in xrange(3)]
248242
dec, mem = tf.nn.seq2seq.embedding_attention_decoder(
249-
dec_inp, enc_states[-1],
243+
dec_inp, enc_state,
250244
attn_states, cell, 4,
251245
output_size=3)
252246
sess.run([tf.initialize_all_variables()])
253247
res = sess.run(dec)
254248
self.assertEqual(len(res), 3)
255249
self.assertEqual(res[0].shape, (2, 3))
256250

257-
res = sess.run(mem)
258-
self.assertEqual(len(res), 4)
251+
res = sess.run([mem])
259252
self.assertEqual(res[0].shape, (2, 2))
260253

261254
def testEmbeddingAttentionSeq2Seq(self):
@@ -271,8 +264,7 @@ def testEmbeddingAttentionSeq2Seq(self):
271264
self.assertEqual(len(res), 3)
272265
self.assertEqual(res[0].shape, (2, 5))
273266

274-
res = sess.run(mem)
275-
self.assertEqual(len(res), 4)
267+
res = sess.run([mem])
276268
self.assertEqual(res[0].shape, (2, 4))
277269

278270
# Test externally provided output projection.

0 commit comments

Comments
 (0)