Skip to content

Commit c55d845

Browse files
added testing loader
1 parent cb51c43 commit c55d845

File tree

1 file changed

+26
-2
lines changed

1 file changed

+26
-2
lines changed

data_loader.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,14 @@ class DataLoader(object):
66
def __init__(self,training_data,batch_size):
77
self.training_data=training_data
88
self.batch_size=batch_size
9+
10+
def get_data(self):
11+
data=random.choices(self.training_data,k=self.batch_size)
912

13+
return data
14+
1015
def __load_next__(self):
11-
data=random.choices(self.training_data,k=self.batch_size)
16+
data=self.get_data()
1217

1318
max_query_len,max_doc_len,max_cand_len,max_word_len=0,0,0,0
1419
ans=[]
@@ -83,4 +88,23 @@ def __load_next__(self):
8388
index+=1
8489

8590
return docs,doc_char,docs_mask,queries,query_char,queries_mask, \
86-
char_type,char_type_mask,answers,clozes,cands,cand_mask,qe_comm
91+
char_type,char_type_mask,answers,clozes,cands,cand_mask,qe_comm
92+
93+
94+
class TestLoader(DataLoader):
95+
def __init__(self,data,num_examples,batch_size=2):
96+
self.data=data
97+
self.examples=num_examples
98+
self.counter=0
99+
self.batch_size=batch_size
100+
101+
def reset_counter(self):
102+
self.counter=0
103+
104+
def get_data(self):
105+
data=self.data[self.counter:self.count+2]
106+
self.counter+=2
107+
if self.counter==self.examples:
108+
self.reset_counter()
109+
110+
return data

0 commit comments

Comments
 (0)