@@ -25,10 +25,8 @@ struct ChessModel : nn::Model {
2525
2626 // train function
2727 void train (dataset::BatchLoader<chess::Position>& loader,
28- dataset::BatchLoader<chess::Position>& validation_loader,
2928 int epochs = 1500 ,
30- int epoch_size = 1e8 ,
31- int validation_size = 1e7 ) {
29+ int epoch_size = 1e8 ) {
3230 this ->compile (loader.batch_size );
3331
3432 Timer t {};
@@ -37,7 +35,6 @@ struct ChessModel : nn::Model {
3735
3836 uint64_t prev_print_tm = 0 ;
3937 float total_epoch_loss = 0 ;
40- float total_validation_loss = 0 ;
4138
4239 for (int b = 1 ; b <= epoch_size / loader.batch_size ; b++) {
4340 auto * ds = loader.next ();
@@ -67,20 +64,7 @@ struct ChessModel : nn::Model {
6764 std::cout << std::endl;
6865
6966 float epoch_loss = total_epoch_loss / (epoch_size / loader.batch_size );
70-
71- for (int b = 1 ; b <= validation_size / validation_loader.batch_size ; b++) {
72- auto * ds = validation_loader.next ();
73- setup_inputs_and_outputs (ds, 0.5 );
74-
75- total_validation_loss += loss ();
76- }
77-
78- float validation_loss =
79- total_validation_loss / (validation_size / validation_loader.batch_size );
80- printf (" ep = [%4d], valid_loss = [%1.8f]" , i, validation_loss);
81- std::cout << std::endl;
82-
83- next_epoch (epoch_loss, validation_loss);
67+ next_epoch (epoch_loss, 0.0 );
8468 }
8569 }
8670
@@ -391,35 +375,34 @@ int main(int argc, char* argv[]) {
391375 std::exit (1 );
392376 }
393377
394-
395378 math::seed (0 );
396379
397380 init ();
398381
399382 std::vector<std::string> files {};
400- std::vector<std::string> validation_files {};
401383
402384 for (const auto & entry : fs::directory_iterator (program.get (" data" ))) {
403385 const std::string path = entry.path ().string ();
404- if (path.find (" valid" ) != std::string::npos) {
405- std::cout << " Specifying " << path << " as validation data!" << std::endl;
406- validation_files.push_back (path);
407- } else {
408- std::cout << " Specifying " << path << " as training data!" << std::endl;
409- files.push_back (path);
410- }
386+ files.push_back (path);
387+ }
388+
389+ uint64_t total_positions = 0 ;
390+ for (const auto & file_path : files) {
391+ FILE* fin = fopen (file_path.c_str (), " rb" );
392+
393+ dataset::DataSetHeader h {};
394+ fread (&h, sizeof (dataset::DataSetHeader), 1 , fin);
395+
396+ total_positions += h.entry_count ;
397+ fclose (fin);
411398 }
412399
413- if (validation_files.empty ())
414- validation_files.push_back (files.at (0 ));
400+ std::cout << " Loading a total of " << files.size () << " files with " << total_positions << " total position(s)" << std::endl;
415401
416402 const int batch_size = 16384 ;
417403 dataset::BatchLoader<chess::Position> loader {files, batch_size};
418404 loader.start ();
419405
420- dataset::BatchLoader<chess::Position> validation_loader {validation_files, batch_size};
421- validation_loader.start ();
422-
423406 BerserkModel model {};
424407 model.set_loss (MPE {2.5 , true });
425408 model.set_lr_schedule (StepDecayLRSchedule {1e-3 , 1.0 / 40.0 , 500 });
@@ -433,10 +416,9 @@ int main(int argc, char* argv[]) {
433416 model.load_weights (*previous);
434417 }
435418
436- model.train (loader, validation_loader, 1000 );
419+ model.train (loader, 1000 );
437420
438421 loader.kill ();
439- validation_loader.kill ();
440422
441423 close ();
442424 return 0 ;
0 commit comments