Skip to content

Commit 5043e09

Browse files
committed
update read and decode
1 parent 5d57edf commit 5043e09

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

tensorgraph/utils.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -358,7 +358,8 @@ def read_arrs_from_tfrecords(tfrecords_filename, data_shapes, dtype=np.float32):
358358

359359

360360
@staticmethod
361-
def read_and_decode(tfrecords_filename_list, data_shapes, batch_size, dtype=tf.float32):
361+
def read_and_decode(tfrecords_filename_list, data_shapes, batch_size, dtype=tf.float32,
362+
capacity=None, min_after_dequeue=None):
362363
'''
363364
tfrecords_filename_list (list): list of tfrecords paths
364365
data_shapes (dict): dictionary of the record name and shape example: {'X':[32,32], 'y':[10]}
@@ -382,10 +383,12 @@ def read_and_decode(tfrecords_filename_list, data_shapes, batch_size, dtype=tf.f
382383
data = tf.reshape(data, data_shapes[name])
383384
records.append(data)
384385
names.append(name)
386+
capacity = capacity if capacity else 10*batch_size
387+
min_after_dequeue = min_after_dequeue if min_after_dequeue else 5*batch_size
385388
batch_records = tf.train.shuffle_batch(records, batch_size=batch_size,
386-
capacity=10*batch_size,
389+
capacity=capacity,
387390
num_threads=4,
388-
min_after_dequeue=5*batch_size)
391+
min_after_dequeue=min_after_dequeue)
389392
if not isinstance(batch_records, (list, tuple)):
390393
batch_records = [batch_records]
391394
return zip(names, batch_records)

0 commit comments

Comments
 (0)