@@ -6,9 +6,14 @@ class DataLoader(object):
66 def __init__ (self ,training_data ,batch_size ):
77 self .training_data = training_data
88 self .batch_size = batch_size
9+
def get_data(self):
    """Return ``batch_size`` examples drawn from ``training_data``.

    Sampling uses ``random.choices``, i.e. it is done with replacement, so
    the same example may appear more than once in a batch.
    """
    return random.choices(self.training_data, k=self.batch_size)
14+
1015 def __load_next__ (self ):
11- data = random . choices ( self .training_data , k = self . batch_size )
16+ data = self .get_data ( )
1217
1318 max_query_len ,max_doc_len ,max_cand_len ,max_word_len = 0 ,0 ,0 ,0
1419 ans = []
@@ -83,4 +88,23 @@ def __load_next__(self):
8388 index += 1
8489
8590 return docs ,doc_char ,docs_mask ,queries ,query_char ,queries_mask , \
86- char_type ,char_type_mask ,answers ,clozes ,cands ,cand_mask ,qe_comm
91+ char_type ,char_type_mask ,answers ,clozes ,cands ,cand_mask ,qe_comm
92+
93+
class TestLoader(DataLoader):
    """Loader that serves consecutive slices of ``data`` in order.

    Overrides ``get_data`` so that batches are taken sequentially instead of
    being sampled at random, and wraps back to the first example once
    ``num_examples`` items have been served.
    """

    def __init__(self, data, num_examples, batch_size=2):
        """Store the data and position the loader at the first example.

        Args:
            data: Sliceable sequence of examples.
            num_examples: Number of examples to serve before wrapping around.
            batch_size: Number of examples returned by each ``get_data`` call.
        """
        # Initialise the base class as well so the inherited attributes
        # (``training_data``, ``batch_size``) exist on this instance.
        super().__init__(data, batch_size)
        self.data = data
        self.examples = num_examples
        self.counter = 0

    def reset_counter(self):
        """Rewind so the next ``get_data`` call starts at the first example."""
        self.counter = 0

    def get_data(self):
        """Return the next ``batch_size`` examples and advance the position.

        The final slice may be shorter than ``batch_size`` when the data does
        not divide evenly.
        """
        start = self.counter
        stop = start + self.batch_size
        batch = self.data[start:stop]
        self.counter = stop
        # ``>=`` rather than ``==``: when ``examples`` is not a multiple of
        # ``batch_size`` the counter steps past it and would never reset.
        if self.counter >= self.examples:
            self.reset_counter()

        return batch
0 commit comments