@@ -358,7 +358,8 @@ def read_arrs_from_tfrecords(tfrecords_filename, data_shapes, dtype=np.float32):
358358
359359
360360 @staticmethod
361- def read_and_decode (tfrecords_filename_list , data_shapes , batch_size , dtype = tf .float32 ):
361+ def read_and_decode (tfrecords_filename_list , data_shapes , batch_size , dtype = tf .float32 ,
362+ capacity = None , min_after_dequeue = None ):
362363 '''
363364 tfrecords_filename_list (list): list of tfrecords paths
364365 data_shapes (dict): dictionary of the record name and shape example: {'X':[32,32], 'y':[10]}
@@ -382,10 +383,12 @@ def read_and_decode(tfrecords_filename_list, data_shapes, batch_size, dtype=tf.f
382383 data = tf .reshape (data , data_shapes [name ])
383384 records .append (data )
384385 names .append (name )
386+ capacity = capacity if capacity else 10 * batch_size
387+ min_after_dequeue = min_after_dequeue if min_after_dequeue else 5 * batch_size
385388 batch_records = tf .train .shuffle_batch (records , batch_size = batch_size ,
386- capacity = 10 * batch_size ,
389+ capacity = capacity ,
387390 num_threads = 4 ,
388- min_after_dequeue = 5 * batch_size )
391+ min_after_dequeue = min_after_dequeue )
389392 if not isinstance (batch_records , (list , tuple )):
390393 batch_records = [batch_records ]
391394 return zip (names , batch_records )
0 commit comments