from open_seq2seq.data.text2text.t2t import _read_and_batch_from_files
from open_seq2seq.data.text2text.tokenizer import PAD_ID
1313
# if hasattr(tf.compat, 'v1'):
#   tf.compat.v1.disable_eager_execution()

class SpecialTextTokens(Enum):
  PAD_ID = 0  # special padding token
@@ -162,7 +163,7 @@ def _pad2eight(self, lst, do_pad_eight):
    return lst + [SpecialTextTokens.PAD_ID.value] * (8 - len(lst) % 8)

  def _src_token_to_id(self, line):
    tokens = line.decode("utf-8").split(self._delimiter)  # line.numpy().
    if self._use_start_token:
      return np.array(self._pad2eight([SpecialTextTokens.S_ID.value] + \
        [self.src_seq2idx.get(token, SpecialTextTokens.UNK_ID.value) for token in tokens[:self.max_len - 2]] + \
@@ -173,7 +174,7 @@ def _src_token_to_id(self, line):
        [SpecialTextTokens.EOS_ID.value], self._pad_lengths_to_eight), dtype="int32")

  def _tgt_token_to_id(self, line):
    tokens = line.decode("utf-8").split(self._delimiter)  # line.numpy().
    if self._use_start_token:
      return np.array(self._pad2eight([SpecialTextTokens.S_ID.value] + \
        [self.tgt_seq2idx.get(token, SpecialTextTokens.UNK_ID.value) for token in tokens[:self.max_len - 2]] + \
@@ -197,14 +198,14 @@ def build_graph(self):
    _targets = _targets.shard(num_shards=self._num_workers,
                              index=self._worker_id)

    _sources = _sources.map(lambda line: tf.py_func(func=self._src_token_to_id, inp=[line],
                                                    Tout=[tf.int32], stateful=False),
                            num_parallel_calls=self._map_parallel_calls) \
      .map(lambda tokens: (tokens, tf.size(tokens)),
           num_parallel_calls=self._map_parallel_calls)

    _targets = _targets.map(lambda line: tf.py_func(func=self._tgt_token_to_id, inp=[line],
                                                    Tout=[tf.int32], stateful=False),
                            num_parallel_calls=self._map_parallel_calls) \
      .map(lambda tokens: (tokens, tf.size(tokens)),
           num_parallel_calls=self._map_parallel_calls)