@@ -11,7 +11,7 @@ int main(int argc, char* argv[]) {
1111 argparse::ArgumentParser program (" Grapheus" );
1212
1313 program.add_argument (" data" ).required ().help (" Directory containing training files" );
14- program.add_argument (" --val-data" ).required (). help (" Directory containing validation files" );
14+ program.add_argument (" --val-data" ).help (" Directory containing validation files" );
1515 program.add_argument (" --output" ).required ().help (" Output directory for network files" );
1616 program.add_argument (" --resume" ).help (" Weights file to resume from" );
1717 program.add_argument (" --epochs" )
@@ -22,6 +22,11 @@ int main(int argc, char* argv[]) {
2222 .default_value (100000000 )
2323 .help (" Total positions in each epoch" )
2424 .scan <' i' , int >();
25+
26+ program.add_argument (" --val-size" )  // name must match the program.get<int>("--val-size") lookup below
27+ .default_value (10000000 )
28+ .help (" Total positions for each validation epoch" )
29+ .scan <' i' , int >();
2530 program.add_argument (" --save-rate" )
2631 .default_value (50 )
2732 .help (" How frequently to save quantized networks + weights" )
@@ -66,9 +71,6 @@ int main(int argc, char* argv[]) {
6671 // Fetch training dataset paths
6772 std::vector<std::string> train_files = dataset::fetch_dataset_paths (program.get (" data" ));
6873
69- // Fetch validation dataset paths
70- std::vector<std::string> val_files = dataset::fetch_dataset_paths (program.get (" --val-data" ));
71-
7274 // Print training dataset file list if files are found
7375 if (!train_files.empty ()) {
7476 std::cout << " Training Dataset Files:" << std::endl;
@@ -84,6 +86,13 @@ int main(int argc, char* argv[]) {
8486 exit (0 );
8587 }
8688
89+ // Fetch validation dataset paths
90+ std::vector<std::string> val_files;
91+
92+ if (program.present (" --val-data" )) {
93+ val_files = dataset::fetch_dataset_paths (program.get (" --val-data" ));
94+ }
95+
8796 // Print validation dataset file list if files are found
8897 if (!val_files.empty ()) {
8998 std::cout << " Validation Dataset Files:" << std::endl;
@@ -93,20 +102,18 @@ int main(int argc, char* argv[]) {
93102 std::cout << " Total validation files: " << val_files.size () << std::endl;
94103 std::cout << " Total validation positions: " << dataset::count_total_positions (val_files)
95104 << std::endl;
96- } else {
97- std::cout << " No validation files found in " << program.get (" --val-data" ) << std::endl;
98- exit (0 );
99105 }
100106
101- const int total_epochs = program.get <int >(" --epochs" );
102- const int epoch_size = program.get <int >(" --epoch-size" );
103- const int save_rate = program.get <int >(" --save-rate" );
104- const int ft_size = program.get <int >(" --ft-size" );
105- const float lambda = program.get <float >(" --lambda" );
106- const float lr = program.get <float >(" --lr" );
107- const int batch_size = program.get <int >(" --batch-size" );
108- const int lr_drop_epoch = program.get <int >(" --lr-drop-epoch" );
109- const float lr_drop_ratio = program.get <float >(" --lr-drop-ratio" );
107+ const int total_epochs = program.get <int >(" --epochs" );
108+ const int epoch_size = program.get <int >(" --epoch-size" );
109+ const int val_epoch_size = program.get <int >(" --val-size" );
110+ const int save_rate = program.get <int >(" --save-rate" );
111+ const int ft_size = program.get <int >(" --ft-size" );
112+ const float lambda = program.get <float >(" --lambda" );
113+ const float lr = program.get <float >(" --lr" );
114+ const int batch_size = program.get <int >(" --batch-size" );
115+ const int lr_drop_epoch = program.get <int >(" --lr-drop-epoch" );
116+ const float lr_drop_ratio = program.get <float >(" --lr-drop-ratio" );
110117
111118 std::cout << " Epochs: " << total_epochs << " \n "
112119 << " Epochs Size: " << epoch_size << " \n "
@@ -121,10 +128,13 @@ int main(int argc, char* argv[]) {
121128 using BatchLoader = dataset::BatchLoader<chess::Position>;
122129
123130 BatchLoader train_loader {train_files, batch_size};
124- BatchLoader val_loader {val_files, batch_size};
125-
126131 train_loader.start ();
127- val_loader.start ();
132+
133+ std::optional<BatchLoader> val_loader;
134+ if (val_files.size () > 0 ) {
135+ val_loader.emplace (val_files, batch_size);
136+ val_loader->start ();
137+ }
128138
129139 model::BerserkModel model {static_cast <size_t >(ft_size), lambda, static_cast <size_t >(save_rate)};
130140 model.set_loss (MPE {2.5 , true });
@@ -142,10 +152,10 @@ int main(int argc, char* argv[]) {
142152 std::cout << " Loaded weights from previous " << *previous << std::endl;
143153 }
144154
145- model.train (train_loader, val_loader, total_epochs, epoch_size);
155+ model.train (train_loader, val_loader, total_epochs, epoch_size, val_epoch_size );
146156
147157 train_loader.kill ();
148- val_loader. kill ();
158+ if (val_loader) val_loader->kill ();  // optional is only engaged when validation files were supplied
149159
150160 close ();
151161 return 0 ;
0 commit comments