@@ -22,6 +22,11 @@ int main(int argc, char* argv[]) {
2222 .default_value (100000000 )
2323 .help (" Total positions in each epoch" )
2424 .scan <' i' , int >();
25+
26+ program.add_argument (" --val--size" )
27+ .default_value (10000000 )
28+ .help (" Total positions for each validation epoch" )
29+ .scan <' i' , int >();
2530 program.add_argument (" --save-rate" )
2631 .default_value (50 )
2732 .help (" How frequently to save quantized networks + weights" )
@@ -99,15 +104,16 @@ int main(int argc, char* argv[]) {
99104 << std::endl;
100105 }
101106
102- const int total_epochs = program.get <int >(" --epochs" );
103- const int epoch_size = program.get <int >(" --epoch-size" );
104- const int save_rate = program.get <int >(" --save-rate" );
105- const int ft_size = program.get <int >(" --ft-size" );
106- const float lambda = program.get <float >(" --lambda" );
107- const float lr = program.get <float >(" --lr" );
108- const int batch_size = program.get <int >(" --batch-size" );
109- const int lr_drop_epoch = program.get <int >(" --lr-drop-epoch" );
110- const float lr_drop_ratio = program.get <float >(" --lr-drop-ratio" );
107+ const int total_epochs = program.get <int >(" --epochs" );
108+ const int epoch_size = program.get <int >(" --epoch-size" );
109+ const int val_epoch_size = program.get <int >(" --val-size" );
110+ const int save_rate = program.get <int >(" --save-rate" );
111+ const int ft_size = program.get <int >(" --ft-size" );
112+ const float lambda = program.get <float >(" --lambda" );
113+ const float lr = program.get <float >(" --lr" );
114+ const int batch_size = program.get <int >(" --batch-size" );
115+ const int lr_drop_epoch = program.get <int >(" --lr-drop-epoch" );
116+ const float lr_drop_ratio = program.get <float >(" --lr-drop-ratio" );
111117
112118 std::cout << " Epochs: " << total_epochs << " \n "
113119 << " Epochs Size: " << epoch_size << " \n "
@@ -146,7 +152,7 @@ int main(int argc, char* argv[]) {
146152 std::cout << " Loaded weights from previous " << *previous << std::endl;
147153 }
148154
149- model.train (train_loader, val_loader, total_epochs, epoch_size);
155+ model.train (train_loader, val_loader, total_epochs, epoch_size, val_epoch_size );
150156
151157 train_loader.kill ();
152158 val_loader->kill ();
0 commit comments