55
66from utils import process_all_files ,load_GloVe ,accuracy_cal
77from model import GA_Reader
8- from data_loader import DataLoader
8+ from data_loader import DataLoader , TestLoader
99
1010def train (epochs ,iterations ,loader_train ,loader_val ,
1111 model ,optimizer ,loss_function ):
@@ -36,20 +36,44 @@ def train(epochs,iterations,loader_train,loader_val,
3636
3737def validate (loader_val ,model ,loss_function ):
3838 model .eval ()
39-
40- doc ,doc_char ,doc_mask ,query ,query_char ,query_mask , \
41- char_type ,char_type_mask ,answer ,cloze ,cand , \
42- cand_mask ,qe_comm = loader_val .__load_next__ ()
43-
44- output = model ( doc ,doc_char ,doc_mask ,query ,query_char ,query_mask ,
45- char_type ,char_type_mask ,answer ,cloze ,cand ,
46- cand_mask ,qe_comm )
39+ return_loss = 0
40+ accuracy = 0
41+
42+ for _ in range (loader_val .examples // loader_val .batch_size ):
43+ doc ,doc_char ,doc_mask ,query ,query_char ,query_mask , \
44+ char_type ,char_type_mask ,answer ,cloze ,cand , \
45+ cand_mask ,qe_comm = loader_val .__load_next__ ()
46+
47+ output = model ( doc ,doc_char ,doc_mask ,query ,query_char ,query_mask ,
48+ char_type ,char_type_mask ,answer ,cloze ,cand ,
49+ cand_mask ,qe_comm )
50+
51+ accuracy += accuracy_cal (output ,answer )
52+ loss = loss_function (output ,answer )
53+ return_loss += loss .item ()
4754
48- accuracy = accuracy_cal ( output , answer )
49- loss = loss_function ( output , answer )
55+ return_loss /= ( loader_val . examples // loader_val . batch_size )
56+ accuracy = 100 * accuracy / loader_val . examples
5057
51- return loss . item () ,accuracy
58+ return return_loss ,accuracy
5259
60+ def test (loader_test ,model ):
61+ model .eval ()
62+ accuracy = 0
63+ for _ in range (loader_test .examples // loader_test .batch_size ):
64+ doc ,doc_char ,doc_mask ,query ,query_char ,query_mask , \
65+ char_type ,char_type_mask ,answer ,cloze ,cand , \
66+ cand_mask ,qe_comm = loader_test .__load_next__ ()
67+
68+ output = model ( doc ,doc_char ,doc_mask ,query ,query_char ,query_mask ,
69+ char_type ,char_type_mask ,answer ,cloze ,cand ,
70+ cand_mask ,qe_comm )
71+
72+ accuracy += accuracy_cal (output ,answer )
73+
74+ accuracy = 100 * accuracy / loader_test .examples
75+ print ('test accuracy=' ,accuracy )
76+
5377def main (args ):
5478 word_to_int ,int_to_word ,char_to_int ,int_to_char , \
5579 training_data = process_all_files (args .train_file )
@@ -62,11 +86,17 @@ def main(args):
6286
6387 optimizer = optim .Adam (model .parameters (),lr = args .lr )
6488 data_loader_train = DataLoader (training_data [:args .training_size ],args .batch_size )
65- data_loader_validate = DataLoader (training_data [args .training_size :],args .batch_size )
89+ data_loader_validate = TestLoader (training_data [args .training_size :args . \
90+ training_size + args .dev_size ],args .dev_size )
91+ data_loader_test = TestLoader (training_data [args . \
92+ training_size_args .dev_size :args . \
93+ training_size + args .dev_size + args .test_size ],args .test_size )
6694
6795 train (args .epochs ,args .iterations ,data_loader_train ,
6896 data_loader_validate ,model ,optimizer ,loss_function )
6997
98+ test (data_loader_test ,model )
99+
70100def setup ():
71101 parser = argparse .ArgumentParser ('argument parser' )
72102 parser .add_argument ('--lr' ,type = float ,default = 0.00005 )
@@ -82,9 +112,9 @@ def setup():
82112 parser .add_argument ('--gru_layers' ,type = int ,default = 3 )
83113 parser .add_argument ('--embed_file' ,type = str ,default = os .getcwd ()+ '/word2vec_glove.text' )
84114 parser .add_argument ('--train_file' ,type = str ,default = os .getcwd ()+ '/train/' )
85- parser .add_argument ('--dev_file ' ,type = str ,default = os . getcwd () + '/validation/' )
86- parser .add_argument ('--test_file ' ,type = str ,default = os . getcwd () + '/test/' )
87- parser .add_argument ('--training_size ' ,type = int ,default = 380 , 298 )
115+ parser .add_argument ('--train_size ' ,type = int ,default = 380298 )
116+ parser .add_argument ('--dev_size ' ,type = int ,default = 3924 )
117+ parser .add_argument ('--test_size ' ,type = int ,default = 3198 )
88118
89119 args = parser .parse_args ()
90120
0 commit comments