@@ -35,19 +35,18 @@ def testRNNDecoder(self):
3535 with self .test_session () as sess :
3636 with tf .variable_scope ("root" , initializer = tf .constant_initializer (0.5 )):
3737 inp = [tf .constant (0.5 , shape = [2 , 2 ]) for _ in xrange (2 )]
38- _ , enc_states = tf .nn .rnn (
38+ _ , enc_state = tf .nn .rnn (
3939 tf .nn .rnn_cell .GRUCell (2 ), inp , dtype = tf .float32 )
4040 dec_inp = [tf .constant (0.4 , shape = [2 , 2 ]) for _ in xrange (3 )]
4141 cell = tf .nn .rnn_cell .OutputProjectionWrapper (
4242 tf .nn .rnn_cell .GRUCell (2 ), 4 )
43- dec , mem = tf .nn .seq2seq .rnn_decoder (dec_inp , enc_states [ - 1 ] , cell )
43+ dec , mem = tf .nn .seq2seq .rnn_decoder (dec_inp , enc_state , cell )
4444 sess .run ([tf .initialize_all_variables ()])
4545 res = sess .run (dec )
4646 self .assertEqual (len (res ), 3 )
4747 self .assertEqual (res [0 ].shape , (2 , 4 ))
4848
49- res = sess .run (mem )
50- self .assertEqual (len (res ), 4 )
49+ res = sess .run ([mem ])
5150 self .assertEqual (res [0 ].shape , (2 , 2 ))
5251
5352 def testBasicRNNSeq2Seq (self ):
@@ -63,8 +62,7 @@ def testBasicRNNSeq2Seq(self):
6362 self .assertEqual (len (res ), 3 )
6463 self .assertEqual (res [0 ].shape , (2 , 4 ))
6564
66- res = sess .run (mem )
67- self .assertEqual (len (res ), 4 )
65+ res = sess .run ([mem ])
6866 self .assertEqual (res [0 ].shape , (2 , 2 ))
6967
7068 def testTiedRNNSeq2Seq (self ):
@@ -80,26 +78,26 @@ def testTiedRNNSeq2Seq(self):
8078 self .assertEqual (len (res ), 3 )
8179 self .assertEqual (res [0 ].shape , (2 , 4 ))
8280
83- res = sess .run (mem )
84- self .assertEqual (len (res ), 4 )
81+ res = sess .run ([ mem ] )
82+ self .assertEqual (len (res ), 1 )
8583 self .assertEqual (res [0 ].shape , (2 , 2 ))
8684
8785 def testEmbeddingRNNDecoder (self ):
8886 with self .test_session () as sess :
8987 with tf .variable_scope ("root" , initializer = tf .constant_initializer (0.5 )):
9088 inp = [tf .constant (0.5 , shape = [2 , 2 ]) for _ in xrange (2 )]
9189 cell = tf .nn .rnn_cell .BasicLSTMCell (2 )
92- _ , enc_states = tf .nn .rnn (cell , inp , dtype = tf .float32 )
90+ _ , enc_state = tf .nn .rnn (cell , inp , dtype = tf .float32 )
9391 dec_inp = [tf .constant (i , tf .int32 , shape = [2 ]) for i in xrange (3 )]
94- dec , mem = tf .nn .seq2seq .embedding_rnn_decoder (dec_inp , enc_states [ - 1 ] ,
92+ dec , mem = tf .nn .seq2seq .embedding_rnn_decoder (dec_inp , enc_state ,
9593 cell , 4 )
9694 sess .run ([tf .initialize_all_variables ()])
9795 res = sess .run (dec )
9896 self .assertEqual (len (res ), 3 )
9997 self .assertEqual (res [0 ].shape , (2 , 2 ))
10098
101- res = sess .run (mem )
102- self .assertEqual (len (res ), 4 )
99+ res = sess .run ([ mem ] )
100+ self .assertEqual (len (res ), 1 )
103101 self .assertEqual (res [0 ].shape , (2 , 4 ))
104102
105103 def testEmbeddingRNNSeq2Seq (self ):
@@ -115,8 +113,7 @@ def testEmbeddingRNNSeq2Seq(self):
115113 self .assertEqual (len (res ), 3 )
116114 self .assertEqual (res [0 ].shape , (2 , 5 ))
117115
118- res = sess .run (mem )
119- self .assertEqual (len (res ), 4 )
116+ res = sess .run ([mem ])
120117 self .assertEqual (res [0 ].shape , (2 , 4 ))
121118
122119 # Test externally provided output projection.
@@ -161,8 +158,7 @@ def testEmbeddingTiedRNNSeq2Seq(self):
161158 self .assertEqual (len (res ), 3 )
162159 self .assertEqual (res [0 ].shape , (2 , 5 ))
163160
164- res = sess .run (mem )
165- self .assertEqual (len (res ), 4 )
161+ res = sess .run ([mem ])
166162 self .assertEqual (res [0 ].shape , (2 , 4 ))
167163
168164 # Test externally provided output projection.
@@ -198,64 +194,61 @@ def testAttentionDecoder1(self):
198194 with tf .variable_scope ("root" , initializer = tf .constant_initializer (0.5 )):
199195 cell = tf .nn .rnn_cell .GRUCell (2 )
200196 inp = [tf .constant (0.5 , shape = [2 , 2 ]) for _ in xrange (2 )]
201- enc_outputs , enc_states = tf .nn .rnn (cell , inp , dtype = tf .float32 )
197+ enc_outputs , enc_state = tf .nn .rnn (cell , inp , dtype = tf .float32 )
202198 attn_states = tf .concat (1 , [tf .reshape (e , [- 1 , 1 , cell .output_size ])
203199 for e in enc_outputs ])
204200 dec_inp = [tf .constant (0.4 , shape = [2 , 2 ]) for _ in xrange (3 )]
205201 dec , mem = tf .nn .seq2seq .attention_decoder (
206- dec_inp , enc_states [ - 1 ] ,
202+ dec_inp , enc_state ,
207203 attn_states , cell , output_size = 4 )
208204 sess .run ([tf .initialize_all_variables ()])
209205 res = sess .run (dec )
210206 self .assertEqual (len (res ), 3 )
211207 self .assertEqual (res [0 ].shape , (2 , 4 ))
212208
213- res = sess .run (mem )
214- self .assertEqual (len (res ), 4 )
209+ res = sess .run ([mem ])
215210 self .assertEqual (res [0 ].shape , (2 , 2 ))
216211
217212 def testAttentionDecoder2 (self ):
218213 with self .test_session () as sess :
219214 with tf .variable_scope ("root" , initializer = tf .constant_initializer (0.5 )):
220215 cell = tf .nn .rnn_cell .GRUCell (2 )
221216 inp = [tf .constant (0.5 , shape = [2 , 2 ]) for _ in xrange (2 )]
222- enc_outputs , enc_states = tf .nn .rnn (cell , inp , dtype = tf .float32 )
217+ enc_outputs , enc_state = tf .nn .rnn (cell , inp , dtype = tf .float32 )
223218 attn_states = tf .concat (1 , [tf .reshape (e , [- 1 , 1 , cell .output_size ])
224219 for e in enc_outputs ])
225220 dec_inp = [tf .constant (0.4 , shape = [2 , 2 ]) for _ in xrange (3 )]
226221 dec , mem = tf .nn .seq2seq .attention_decoder (
227- dec_inp , enc_states [ - 1 ] ,
222+ dec_inp , enc_state ,
228223 attn_states , cell , output_size = 4 ,
229224 num_heads = 2 )
230225 sess .run ([tf .initialize_all_variables ()])
231226 res = sess .run (dec )
232227 self .assertEqual (len (res ), 3 )
233228 self .assertEqual (res [0 ].shape , (2 , 4 ))
234229
235- res = sess .run (mem )
236- self .assertEqual (len (res ), 4 )
230+ res = sess .run ([mem ])
237231 self .assertEqual (res [0 ].shape , (2 , 2 ))
238232
239233 def testEmbeddingAttentionDecoder (self ):
240234 with self .test_session () as sess :
241235 with tf .variable_scope ("root" , initializer = tf .constant_initializer (0.5 )):
242236 inp = [tf .constant (0.5 , shape = [2 , 2 ]) for _ in xrange (2 )]
243237 cell = tf .nn .rnn_cell .GRUCell (2 )
244- enc_outputs , enc_states = tf .nn .rnn (cell , inp , dtype = tf .float32 )
238+ enc_outputs , enc_state = tf .nn .rnn (cell , inp , dtype = tf .float32 )
245239 attn_states = tf .concat (1 , [tf .reshape (e , [- 1 , 1 , cell .output_size ])
246240 for e in enc_outputs ])
247241 dec_inp = [tf .constant (i , tf .int32 , shape = [2 ]) for i in xrange (3 )]
248242 dec , mem = tf .nn .seq2seq .embedding_attention_decoder (
249- dec_inp , enc_states [ - 1 ] ,
243+ dec_inp , enc_state ,
250244 attn_states , cell , 4 ,
251245 output_size = 3 )
252246 sess .run ([tf .initialize_all_variables ()])
253247 res = sess .run (dec )
254248 self .assertEqual (len (res ), 3 )
255249 self .assertEqual (res [0 ].shape , (2 , 3 ))
256250
257- res = sess .run (mem )
258- self .assertEqual (len (res ), 4 )
251+ res = sess .run ([mem ])
259252 self .assertEqual (res [0 ].shape , (2 , 2 ))
260253
261254 def testEmbeddingAttentionSeq2Seq (self ):
@@ -271,8 +264,7 @@ def testEmbeddingAttentionSeq2Seq(self):
271264 self .assertEqual (len (res ), 3 )
272265 self .assertEqual (res [0 ].shape , (2 , 5 ))
273266
274- res = sess .run (mem )
275- self .assertEqual (len (res ), 4 )
267+ res = sess .run ([mem ])
276268 self .assertEqual (res [0 ].shape , (2 , 4 ))
277269
278270 # Test externally provided output projection.
0 commit comments