@@ -70,7 +70,7 @@ def StackedRNN(inners, num_layers, lstm=False, dropout=0, train=True):
     num_directions = len(inners)
     total_layers = num_layers * num_directions
 
-    def forward(input, hidden, weight):
+    def forward(input, hidden, weight, batch_sizes):
         assert(len(weight) == total_layers)
         next_hidden = []
 
@@ -82,7 +82,7 @@ def forward(input, hidden, weight):
             for j, inner in enumerate(inners):
                 l = i * num_directions + j
 
-                hy, output = inner(input, hidden[l], weight[l])
+                hy, output = inner(input, hidden[l], weight[l], batch_sizes)
                 next_hidden.append(hy)
                 all_output.append(output)
 
@@ -107,7 +107,7 @@ def forward(input, hidden, weight):
 
 
 def Recurrent(inner, reverse=False):
-    def forward(input, hidden, weight):
+    def forward(input, hidden, weight, batch_sizes):
         output = []
         steps = range(input.size(0) - 1, -1, -1) if reverse else range(input.size(0))
         for i in steps:
@@ -124,17 +124,16 @@ def forward(input, hidden, weight):
     return forward
 
 
-def variable_recurrent_factory(batch_sizes):
-    def fac(inner, reverse=False):
-        if reverse:
-            return VariableRecurrentReverse(batch_sizes, inner)
-        else:
-            return VariableRecurrent(batch_sizes, inner)
-    return fac
+def variable_recurrent_factory(inner, reverse=False):
+    if reverse:
+        return VariableRecurrentReverse(inner)
+    else:
+        return VariableRecurrent(inner)
 
 
-def VariableRecurrent(batch_sizes, inner):
-    def forward(input, hidden, weight):
+def VariableRecurrent(inner):
+    def forward(input, hidden, weight, batch_sizes):
+
         output = []
         input_offset = 0
         last_batch_size = batch_sizes[0]
@@ -172,8 +171,8 @@ def forward(input, hidden, weight):
     return forward
 
 
-def VariableRecurrentReverse(batch_sizes, inner):
-    def forward(input, hidden, weight):
+def VariableRecurrentReverse(inner):
+    def forward(input, hidden, weight, batch_sizes):
         output = []
         input_offset = input.size(0)
         last_batch_size = batch_sizes[-1]
@@ -209,7 +208,7 @@ def forward(input, hidden, weight):
 
 
 def AutogradRNN(mode, input_size, hidden_size, num_layers=1, batch_first=False,
-                dropout=0, train=True, bidirectional=False, batch_sizes=None,
+                dropout=0, train=True, bidirectional=False, variable_length=False,
                 dropout_state=None, flat_weight=None):
 
     if mode == 'RNN_RELU':
@@ -223,10 +222,7 @@ def AutogradRNN(mode, input_size, hidden_size, num_layers=1, batch_first=False,
     else:
         raise Exception('Unknown mode: {}'.format(mode))
 
-    if batch_sizes is None:
-        rec_factory = Recurrent
-    else:
-        rec_factory = variable_recurrent_factory(batch_sizes)
+    rec_factory = variable_recurrent_factory if variable_length else Recurrent
 
     if bidirectional:
         layer = (rec_factory(cell), rec_factory(cell, reverse=True))
@@ -239,13 +235,13 @@ def AutogradRNN(mode, input_size, hidden_size, num_layers=1, batch_first=False,
                       dropout=dropout,
                       train=train)
 
-    def forward(input, weight, hidden):
-        if batch_first and batch_sizes is None:
+    def forward(input, weight, hidden, batch_sizes):
+        if batch_first and not variable_length:
             input = input.transpose(0, 1)
 
-        nexth, output = func(input, hidden, weight)
+        nexth, output = func(input, hidden, weight, batch_sizes)
 
-        if batch_first and batch_sizes is None:
+        if batch_first and not variable_length:
             output = output.transpose(0, 1)
 
         return output, nexth
@@ -255,7 +251,7 @@ def forward(input, weight, hidden):
 
 def CudnnRNN(mode, input_size, hidden_size, num_layers=1,
              batch_first=False, dropout=0, train=True, bidirectional=False,
-             batch_sizes=None, dropout_state=None, flat_weight=None):
+             variable_length=False, dropout_state=None, flat_weight=None):
     if dropout_state is None:
         dropout_state = {}
     mode = cudnn.rnn.get_cudnn_mode(mode)
@@ -266,7 +262,7 @@ def CudnnRNN(mode, input_size, hidden_size, num_layers=1,
                       "at every call, possibly greatly increasing memory usage. "
                       "To compact weights again call flatten_parameters().", stacklevel=5)
 
-    def forward(input, weight, hx):
+    def forward(input, weight, hx, batch_sizes):
         if mode == cudnn.CUDNN_LSTM:
             hx, cx = hx
         else:
@@ -284,7 +280,7 @@ def forward(input, weight, hx):
             hx, cx,
             mode, hidden_size, num_layers,
             batch_first, dropout, train, bool(bidirectional),
-            batch_sizes if batch_sizes else (),
+            list(batch_sizes.data) if variable_length else (),
             Variable(dropout_desc.state) if dropout_desc.state is not None else None)
 
         if cx is not None:
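
The gist of the patch: `batch_sizes` is no longer baked into `variable_recurrent_factory` as a closure variable; instead every generated `forward` takes it as an explicit argument, and `AutogradRNN`/`CudnnRNN` receive only a `variable_length` flag. Below is a minimal, self-contained sketch of that closure-to-argument refactor in plain Python; `toy_cell`, the list-based inputs, and the example batch sizes are illustrative assumptions, not code from this diff (direction handling and the cuDNN path are omitted).

```python
def toy_cell(step_input, hidden, weight):
    # Stand-in for an RNN cell: elementwise update of the hidden state.
    return [weight * h + x for h, x in zip(hidden, step_input)]


# Old shape: batch_sizes is fixed when the factory is built, so the forward
# closures must be rebuilt for every packed batch.
def variable_recurrent_factory_old(batch_sizes):
    def fac(inner, reverse=False):  # reverse handling omitted in this sketch
        def forward(input, hidden, weight):
            for t, bs in enumerate(batch_sizes):
                hidden[:bs] = inner(input[t][:bs], hidden[:bs], weight)
            return hidden
        return forward
    return fac


# New shape: the factory only fixes the cell; batch_sizes travels with every
# call, mirroring the updated Recurrent/StackedRNN/AutogradRNN signatures.
def variable_recurrent_factory_new(inner, reverse=False):  # reverse omitted here too
    def forward(input, hidden, weight, batch_sizes):
        for t, bs in enumerate(batch_sizes):
            hidden[:bs] = inner(input[t][:bs], hidden[:bs], weight)
        return hidden
    return forward


# Packed batch of three sequences with lengths 3, 2, 1 -> batch_sizes = [3, 2, 1].
steps = [[1.0, 2.0, 3.0], [4.0, 5.0], [6.0]]

old_layer = variable_recurrent_factory_old([3, 2, 1])(toy_cell)
new_layer = variable_recurrent_factory_new(toy_cell)

print(old_layer(steps, [0.0, 0.0, 0.0], 0.5))
print(new_layer(steps, [0.0, 0.0, 0.0], 0.5, [3, 2, 1]))  # same result, reusable layer
```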