 
 class DynamicLSTM(Template):
 
-    def __init__(self, num_units, state_is_tuple=True, scope=None):
+    def __init__(self, num_units, const_seq_len=False, state_is_tuple=True, scope=None):
         '''
         DESCRIPTION:
             DynamicLSTM is for sequences with dynamic length.
         PARAMS:
             scope (str): scope for the cells. For RNN with the same scope name,
                 the rnn cell will be reused.
+            const_seq_len (bool): if True, a constant sequence length equal
+                to max_time is used, and state_below is a single tensor
+                instead of a tuple.
         '''
         if scope is None:
             self.scope = self.__class__.__name__
         else:
             self.scope = scope
         with tf.variable_scope(self.scope):
             self.lstm = tf.contrib.rnn.LSTMCell(num_units=num_units, state_is_tuple=state_is_tuple)
-
+        self.const_seq_len = const_seq_len
 
 
     def _train_fprop(self, state_below):
@@ -36,9 +39,12 @@ def _train_fprop(self, state_below):
             if state_is_tuple = False:
                 return tf.concat(1, [C, h]) of dimension [batchsize, 2 * num_units]
         '''
-        X_sb, seqlen_sb = state_below
-        with tf.variable_scope(self.scope) as scope:
+        if self.const_seq_len:
+            X_sb, seqlen_sb = state_below, None
+        else:
+            X_sb, seqlen_sb = state_below
 
+        with tf.variable_scope(self.scope) as scope:
             try:
                 bef = set(tf.global_variables())
                 outputs, last_states = tf.nn.dynamic_rnn(cell=self.lstm,
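
Taken together, the two hunks above change DynamicLSTM's input contract. A minimal usage sketch of the new flag, assuming the TF1-era API shown in this diff; the placeholder names, shapes, and the direct _train_fprop call are illustrative, not part of the commit:

    import tensorflow as tf

    # hypothetical inputs: [batch, max_time, features] and one length per row
    X_ph = tf.placeholder('float32', [None, 20, 50])
    seqlen_ph = tf.placeholder('int64', [None])

    # default: state_below is a (data, lengths) tuple
    lstm = DynamicLSTM(num_units=64)
    out = lstm._train_fprop((X_ph, seqlen_ph))

    # const_seq_len=True: a single tensor, every row spans the full max_time
    lstm_const = DynamicLSTM(num_units=64, const_seq_len=True, scope='ConstLSTM')
    out_const = lstm_const._train_fprop(X_ph)
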
@@ -63,7 +69,7 @@ def _variables(self):
 
 class LSTM(Template):
 
-    def __init__(self, num_units, return_idx, initial_state=None, state_is_tuple=True, scope=None):
+    def __init__(self, num_units, return_idx=[0, 1, 2], initial_state=None, state_is_tuple=True, scope=None):
         '''
         DESCRIPTION:
             LSTM is for sequences with fixed length.
@@ -75,7 +81,9 @@ def __init__(self, num_units, return_idx, initial_state=None, state_is_tuple=Tru
                 [batch_size, cell.state_size]. If cell.state_size is a tuple,
                 this should be a tuple of tensors having shapes [batch_size, s]
                 for s in cell.state_size.
-            return_idx (list): list of index from the rnn outputs to return
+            return_idx (list): list of indexes into the rnn outputs
+                [outputs, context, last_hid] to return; each index has to
+                fall into [0, 1, 2].
         '''
         if scope is None:
             self.scope = self.__class__.__name__
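
The new default makes return_idx optional. A sketch of how the indexes map onto the rnn outputs, assuming _train_fprop hands back the selected tensors in index order (X_ph as in the earlier sketch):

    # default return_idx=[0, 1, 2] returns all of [outputs, context, last_hid]
    lstm = LSTM(num_units=64)
    outputs, context, last_hid = lstm._train_fprop(X_ph)

    # keep only the last hidden state
    lstm_h = LSTM(num_units=64, return_idx=[2], scope='LastHid')
    last_hid, = lstm_h._train_fprop(X_ph)
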
@@ -86,11 +94,11 @@ def __init__(self, num_units, return_idx, initial_state=None, state_is_tuple=Tru
         self.initial_state = initial_state
         self.state_is_tuple = state_is_tuple
         self.return_idx = return_idx
-        assert max(self.return_idx) <= 2 and min(self.return_idx) >= 0
+        assert max(self.return_idx) <= 2 and min(self.return_idx) >= 0, \
+            'indexes do not fall into [outputs, context, last_hid]'
         assert isinstance(self.return_idx, list)
 
 
-
     def _train_fprop(self, state_below):
         '''
         PARAMS:
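
With the message attached, an out-of-range index now fails fast with a readable error:

    LSTM(num_units=64, return_idx=[0, 3])
    # AssertionError: indexes do not fall into [outputs, context, last_hid]
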
@@ -135,13 +143,16 @@ def _variables(self):
 
 class DynamicBiLSTM(Template):
 
-    def __init__(self, fw_num_units, bw_num_units, state_is_tuple=True, scope=None):
+    def __init__(self, fw_num_units, bw_num_units, const_seq_len=False, state_is_tuple=True, scope=None):
         '''
         DESCRIPTION:
             DynamicBiLSTM is for sequences with dynamic length.
         PARAMS:
             scope (str): scope for the cells. For RNN with the same scope name,
                 the rnn cell will be reused.
+            const_seq_len (bool): if True, a constant sequence length equal
+                to max_time is used, and state_below is a single tensor
+                instead of a tuple.
         '''
         if scope is None:
             self.scope = self.__class__.__name__
@@ -150,6 +161,7 @@ def __init__(self, fw_num_units, bw_num_units, state_is_tuple=True, scope=None):
         with tf.variable_scope(self.scope):
             self.fw_lstm = tf.contrib.rnn.LSTMCell(num_units=fw_num_units, state_is_tuple=state_is_tuple)
             self.bw_lstm = tf.contrib.rnn.LSTMCell(num_units=bw_num_units, state_is_tuple=state_is_tuple)
+        self.const_seq_len = const_seq_len
 
 
     def _train_fprop(self, state_below):
@@ -167,7 +179,11 @@ def _train_fprop(self, state_below):
             if state_is_tuple = False:
                 return tf.concat(1, [C, h]) of dimension [batchsize, 2 * num_units]
         '''
-        X_sb, seqlen_sb = state_below
+        if self.const_seq_len:
+            seqlen_sb = None
+            X_sb = state_below
+        else:
+            X_sb, seqlen_sb = state_below
 
         with tf.variable_scope(self.scope) as scope:
             try:
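
DynamicBiLSTM gains the same flag with the same contract as DynamicLSTM above; a sketch under the same assumptions:

    # tuple input with per-row lengths, or a single tensor when const_seq_len=True
    bilstm = DynamicBiLSTM(fw_num_units=32, bw_num_units=32, const_seq_len=True)
    out = bilstm._train_fprop(X_ph)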