1+ import random
2+ import torch
3+
4+
class DataLoader(object):
    """Batches tokenized reading-comprehension instances into padded tensors.

    Each training instance is expected to be a 7-tuple
    ``(doc, query, doc_char, query_char, cand, ans, cloze)`` where ``doc`` and
    ``query`` are lists of token ids, ``doc_char``/``query_char`` are lists of
    per-word character-id lists, ``cand`` is the list of candidate answer
    token ids, and ``ans``/``cloze`` are one-element lists.
    (Schema inferred from the unpacking below — confirm against the caller.)
    """

    def __init__(self, training_data, batch_size):
        # training_data: sequence of instance tuples (see class docstring).
        self.training_data = training_data
        # batch_size: number of instances sampled per call to __load_next__.
        self.batch_size = batch_size

    def __load_next__(self):
        """Sample a random batch (with replacement) and return padded tensors.

        Returns a 13-tuple of ``torch.long`` tensors:
        ``docs, doc_char, docs_mask, queries, query_char, queries_mask,
        char_type, char_type_mask, answers, clozes, cands, cand_mask,
        qe_comm`` — token-id matrices with their 0/1 padding masks, the
        unique character-level word-type table, answers remapped to
        candidate indices, cloze positions, candidate one-hot marks per
        document position, and the query-evidence common-word feature.
        """
        data = random.choices(self.training_data, k=self.batch_size)

        # First pass: find max lengths for padding and collect the distinct
        # character-level word types so each unique word is encoded once.
        max_query_len, max_doc_len, max_cand_len, max_word_len = 0, 0, 0, 0
        ans = []
        clozes = []
        word_types = {}  # tuple(char ids) -> [(is_doc, batch_idx, word_idx)]
        for i, instance in enumerate(data):
            doc, query, doc_char, query_char, cand, ans_, cloze_ = instance
            max_doc_len = max(max_doc_len, len(doc))
            max_query_len = max(max_query_len, len(query))
            max_cand_len = max(max_cand_len, len(cand))
            ans.append(ans_[0])
            clozes.append(cloze_[0])

            # Flag 1 marks a document occurrence, 0 a query occurrence.
            for index, word in enumerate(doc_char):
                max_word_len = max(max_word_len, len(word))
                word_types.setdefault(tuple(word), []).append((1, i, index))
            for index, word in enumerate(query_char):
                max_word_len = max(max_word_len, len(word))
                word_types.setdefault(tuple(word), []).append((0, i, index))

        docs = torch.zeros(self.batch_size, max_doc_len, dtype=torch.long)
        queries = torch.zeros(self.batch_size, max_query_len, dtype=torch.long)
        cands = torch.zeros(self.batch_size, max_doc_len, max_cand_len, dtype=torch.long)
        docs_mask = torch.zeros(self.batch_size, max_doc_len, dtype=torch.long)
        queries_mask = torch.zeros(self.batch_size, max_query_len, dtype=torch.long)
        cand_mask = torch.zeros(self.batch_size, max_doc_len, dtype=torch.long)
        qe_comm = torch.zeros(self.batch_size, max_doc_len, dtype=torch.long)
        answers = torch.tensor(ans, dtype=torch.long)
        clozes = torch.tensor(clozes, dtype=torch.long)

        # Second pass: fill token ids, masks and alignment features.
        for i, instance in enumerate(data):
            doc, query, doc_char, query_char, cand, ans_, cloze_ = instance
            docs[i, :len(doc)] = torch.tensor(doc)
            queries[i, :len(query)] = torch.tensor(query)
            docs_mask[i, :len(doc)] = 1
            queries_mask[i, :len(query)] = 1

            # Hoist per-instance lookups out of the document scan: a map
            # token -> candidate slots and the query's token set turn the
            # scan into O(len(doc)) instead of
            # O(len(doc) * (len(cand) + len(query))).
            cand_positions = {}
            for j, cand_token in enumerate(cand):
                cand_positions.setdefault(cand_token, []).append(j)
            query_tokens = set(query)

            for k, token in enumerate(doc):
                for j in cand_positions.get(token, ()):
                    # One-hot mark: document position k holds candidate j.
                    cands[i, k, j] = 1
                    cand_mask[i, k] = 1
                if token in query_tokens:
                    # Query-evidence common-word feature.
                    qe_comm[i, k] = 1

            # Re-map the answer from a token id to its candidate index.
            # NOTE(review): if the answer token is absent from cand the raw
            # token id is kept — this matches the original behaviour.
            for x, cl in enumerate(cand):
                if cl == answers[i]:
                    answers[i] = x
                    break

        # Character-level tensors: each word position stores the row index
        # of its unique word type in char_type.
        doc_char = torch.zeros(self.batch_size, max_doc_len, dtype=torch.long)
        query_char = torch.zeros(self.batch_size, max_query_len, dtype=torch.long)
        char_type = torch.zeros(len(word_types), max_word_len, dtype=torch.long)
        char_type_mask = torch.zeros(len(word_types), max_word_len, dtype=torch.long)

        index = 0
        for word, word_list in word_types.items():
            char_type[index, :len(word)] = torch.tensor(list(word))
            char_type_mask[index, :len(word)] = 1
            for (is_doc, j, k) in word_list:
                if is_doc == 1:
                    doc_char[j, k] = index
                else:
                    query_char[j, k] = index
            index += 1

        return docs, doc_char, docs_mask, queries, query_char, queries_mask, \
            char_type, char_type_mask, answers, clozes, cands, cand_mask, qe_comm