66
77using namespace nn ;
88using namespace data ;
9- namespace fs = std::filesystem;
109
1110int main (int argc, char * argv[]) {
1211 argparse::ArgumentParser program (" Grapheus" );
1312
1413 program.add_argument (" data" ).required ().help (" Directory containing training files" );
14+ program.add_argument (" --val-data" ).required ().help (" Directory containing validation files" );
1515 program.add_argument (" --output" ).required ().help (" Output directory for network files" );
1616 program.add_argument (" --resume" ).help (" Weights file to resume from" );
1717 program.add_argument (" --epochs" )
1818 .default_value (1000 )
1919 .help (" Total number of epochs to train for" )
2020 .scan <' i' , int >();
21+ program.add_argument (" --epoch-size" )
22+ .default_value (100000000 )
23+ .help (" Total positions in each epoch" )
24+ .scan <' i' , int >();
2125 program.add_argument (" --save-rate" )
2226 .default_value (50 )
2327 .help (" How frequently to save quantized networks + weights" )
@@ -59,28 +63,43 @@ int main(int argc, char* argv[]) {
5963
6064 init ();
6165
62- std::vector<std::string> files {};
63-
64- for (const auto & entry : fs::directory_iterator (program.get (" data" ))) {
65- const std::string path = entry.path ().string ();
66- files.push_back (path);
66+ // Fetch training dataset paths
67+ std::vector<std::string> train_files = dataset::fetch_dataset_paths (program.get (" data" ));
68+
69+ // Fetch validation dataset paths
70+ std::vector<std::string> val_files = dataset::fetch_dataset_paths (program.get (" --val-data" ));
71+
72+ // Print training dataset file list if files are found
73+ if (!train_files.empty ()) {
74+ std::cout << " Training Dataset Files:" << std::endl;
75+ for (const auto & file : train_files) {
76+ std::cout << file << std::endl;
77+ }
78+ std::cout << " Total training files: " << train_files.size () << std::endl;
79+ std::cout << " Total training positions: " << dataset::count_total_positions (train_files)
80+ << std::endl
81+ << std::endl;
82+ } else {
83+ std::cout << " No training files found in " << program.get (" data" ) << std::endl << std::endl;
84+ exit (0 );
6785 }
6886
69- uint64_t total_positions = 0 ;
70- for (const auto & file_path : files) {
71- FILE* fin = fopen (file_path.c_str (), " rb" );
72-
73- dataset::DataSetHeader h {};
74- fread (&h, sizeof (dataset::DataSetHeader), 1 , fin);
75-
76- total_positions += h.entry_count ;
77- fclose (fin);
87+ // Print validation dataset file list if files are found
88+ if (!val_files.empty ()) {
89+ std::cout << " Validation Dataset Files:" << std::endl;
90+ for (const auto & file : val_files) {
91+ std::cout << file << std::endl;
92+ }
93+ std::cout << " Total validation files: " << val_files.size () << std::endl;
94+ std::cout << " Total validation positions: " << dataset::count_total_positions (val_files)
95+ << std::endl;
96+ } else {
97+ std::cout << " No validation files found in " << program.get (" --val-data" ) << std::endl;
98+ exit (0 );
7899 }
79100
80- std::cout << " Loading a total of " << files.size () << " files with " << total_positions
81- << " total position(s)" << std::endl;
82-
83101 const int total_epochs = program.get <int >(" --epochs" );
102+ const int epoch_size = program.get <int >(" --epoch-size" );
84103 const int save_rate = program.get <int >(" --save-rate" );
85104 const int ft_size = program.get <int >(" --ft-size" );
86105 const float lambda = program.get <float >(" --lambda" );
@@ -90,6 +109,7 @@ int main(int argc, char* argv[]) {
90109 const float lr_drop_ratio = program.get <float >(" --lr-drop-ratio" );
91110
92111 std::cout << " Epochs: " << total_epochs << " \n "
112+ << " Epochs Size: " << epoch_size << " \n "
93113 << " Save Rate: " << save_rate << " \n "
94114 << " FT Size: " << ft_size << " \n "
95115 << " Lambda: " << lambda << " \n "
@@ -98,8 +118,13 @@ int main(int argc, char* argv[]) {
98118 << " LR Drop @ " << lr_drop_epoch << " \n "
99119 << " LR Drop R " << lr_drop_ratio << std::endl;
100120
101- dataset::BatchLoader<chess::Position> loader {files, batch_size};
102- loader.start ();
121+ using BatchLoader = dataset::BatchLoader<chess::Position>;
122+
123+ BatchLoader train_loader {train_files, batch_size};
124+ BatchLoader val_loader {val_files, batch_size};
125+
126+ train_loader.start ();
127+ val_loader.start ();
103128
104129 model::BerserkModel model {static_cast <size_t >(ft_size), lambda, static_cast <size_t >(save_rate)};
105130 model.set_loss (MPE {2.5 , true });
@@ -117,9 +142,10 @@ int main(int argc, char* argv[]) {
117142 std::cout << " Loaded weights from previous " << *previous << std::endl;
118143 }
119144
120- model.train (loader, total_epochs);
145+ model.train (train_loader, val_loader, total_epochs, epoch_size );
121146
122- loader.kill ();
147+ train_loader.kill ();
148+ val_loader.kill ();
123149
124150 close ();
125151 return 0 ;
0 commit comments