@@ -11,7 +11,7 @@ int main(int argc, char* argv[]) {
1111 argparse::ArgumentParser program (" Grapheus" );
1212
1313 program.add_argument (" data" ).required ().help (" Directory containing training files" );
14- program.add_argument (" --val-data" ).required (). help (" Directory containing validation files" );
14+ program.add_argument (" --val-data" ).help (" Directory containing validation files" );
1515 program.add_argument (" --output" ).required ().help (" Output directory for network files" );
1616 program.add_argument (" --resume" ).help (" Weights file to resume from" );
1717 program.add_argument (" --epochs" )
@@ -66,9 +66,6 @@ int main(int argc, char* argv[]) {
6666 // Fetch training dataset paths
6767 std::vector<std::string> train_files = dataset::fetch_dataset_paths (program.get (" data" ));
6868
69- // Fetch validation dataset paths
70- std::vector<std::string> val_files = dataset::fetch_dataset_paths (program.get (" --val-data" ));
71-
7269 // Print training dataset file list if files are found
7370 if (!train_files.empty ()) {
7471 std::cout << " Training Dataset Files:" << std::endl;
@@ -84,6 +81,13 @@ int main(int argc, char* argv[]) {
8481 exit (0 );
8582 }
8683
84+ // Fetch validation dataset paths
85+ std::vector<std::string> val_files;
86+
87+ if (program.present (" --val-data" )) {
88+ val_files = dataset::fetch_dataset_paths (program.get (" --val-data" ));
89+ }
90+
8791 // Print validation dataset file list if files are found
8892 if (!val_files.empty ()) {
8993 std::cout << " Validation Dataset Files:" << std::endl;
@@ -93,9 +97,6 @@ int main(int argc, char* argv[]) {
9397 std::cout << " Total validation files: " << val_files.size () << std::endl;
9498 std::cout << " Total validation positions: " << dataset::count_total_positions (val_files)
9599 << std::endl;
96- } else {
97- std::cout << " No validation files found in " << program.get (" --val-data" ) << std::endl;
98- exit (0 );
99100 }
100101
101102 const int total_epochs = program.get <int >(" --epochs" );
@@ -121,10 +122,13 @@ int main(int argc, char* argv[]) {
121122 using BatchLoader = dataset::BatchLoader<chess::Position>;
122123
123124 BatchLoader train_loader {train_files, batch_size};
124- BatchLoader val_loader {val_files, batch_size};
125-
126125 train_loader.start ();
127- val_loader.start ();
126+
127+ std::optional<BatchLoader> val_loader;
128+ if (val_files.size () > 0 ) {
129+ val_loader.emplace (val_files, batch_size);
130+ val_loader->start ();
131+ }
128132
129133 model::BerserkModel model {static_cast <size_t >(ft_size), lambda, static_cast <size_t >(save_rate)};
130134 model.set_loss (MPE {2.5 , true });
@@ -145,7 +149,7 @@ int main(int argc, char* argv[]) {
145149 model.train (train_loader, val_loader, total_epochs, epoch_size);
146150
147151 train_loader.kill ();
148- val_loader. kill ();
152+ val_loader-> kill ();
149153
150154 close ();
151155 return 0 ;
0 commit comments