@@ -225,7 +225,7 @@ struct BerserkModel : ChessModel {
225225 const size_t n_l2 = 32 ;
226226 const size_t n_out = 1 ;
227227
228- BerserkModel (size_t n_ft, float lambda)
228+ BerserkModel (size_t n_ft, float lambda, size_t save_rate )
229229 : ChessModel(lambda) {
230230
231231 in1 = add<SparseInput>(n_features, 32 );
@@ -259,9 +259,10 @@ struct BerserkModel : ChessModel {
259259 1e-8 ,
260260 5 * 16384 ));
261261
262+ set_save_frequency (save_rate);
262263 add_quantization (Quantizer {
263264 " quant" ,
264- 50 ,
265+ save_rate ,
265266 QuantizerEntry<int16_t >(&ft->weights .values , quant_one, true ),
266267 QuantizerEntry<int16_t >(&ft->bias .values , quant_one),
267268 QuantizerEntry<int8_t >(&l1->weights .values , quant_two),
@@ -368,13 +369,20 @@ int main(int argc, char* argv[]) {
368369 program.add_argument (" data" ).required ().help (" Directory containing training files" );
369370 program.add_argument (" --output" ).required ().help (" Output directory for network files" );
370371 program.add_argument (" --resume" ).help (" Weights file to resume from" );
372+ program.add_argument (" --epochs" )
373+ .default_value <int >(1000 )
374+ .help (" Total number of epochs to train for" );
375+ program.add_argument (" --save-rate" )
376+ .default_value <size_t >(50 )
377+ .help (" How frequently to save quantized networks + weights" );
371378 program.add_argument (" --ft-size" )
372379 .default_value <size_t >(1024 )
373380 .help (" Number of neurons in the Feature Transformer" );
374381 program.add_argument (" --lambda" )
375382 .default_value <float >(0.0 )
376383 .help (" Ratio of evaluation score to use while training" );
377- program.add_argument (" --lr" ).default_value <float >(1e-3 ).help (" Initial learning rate" );
384+ program.add_argument (" --lr" ).default_value <float >(1e-3 ).help (
385+ " The starting learning rate for the optimizer" );
378386 program.add_argument (" --batch-size" )
379387 .default_value <int >(16384 )
380388 .help (" Number of positions in a mini-batch during training" );
@@ -418,14 +426,18 @@ int main(int argc, char* argv[]) {
418426 std::cout << " Loading a total of " << files.size () << " files with " << total_positions
419427 << " total position(s)" << std::endl;
420428
429+ const int total_epochs = program.get <int >(" --epochs" );
430+ const size_t save_rate = program.get <size_t >(" --save-rate" );
421431 const size_t ft_size = program.get <size_t >(" --ft-size" );
422432 const float lambda = program.get <float >(" --lambda" );
423433 const float lr = program.get <float >(" --lr" );
424434 const int batch_size = program.get <int >(" --batch-size" );
425435 const int lr_drop_epoch = program.get <int >(" --lr-drop-epoch" );
426436 const float lr_drop_ratio = program.get <float >(" --lr-drop-ratio" );
427437
428- std::cout << " FT Size: " << ft_size << " \n "
438+ std::cout << " Epochs: " << total_epochs << " \n "
439+ << " Save Rate: " << save_rate << " \n "
440+ << " FT Size: " << ft_size << " \n "
429441 << " Lambda: " << lambda << " \n "
430442 << " LR: " << lr << " \n "
431443 << " Batch: " << batch_size << " \n "
@@ -435,7 +447,7 @@ int main(int argc, char* argv[]) {
435447 dataset::BatchLoader<chess::Position> loader {files, batch_size};
436448 loader.start ();
437449
438- BerserkModel model {ft_size, lambda};
450+ BerserkModel model {ft_size, lambda, save_rate };
439451 model.set_loss (MPE {2.5 , true });
440452 model.set_lr_schedule (StepDecayLRSchedule {lr, lr_drop_ratio, lr_drop_epoch});
441453
@@ -451,7 +463,7 @@ int main(int argc, char* argv[]) {
451463 std::cout << " Loaded weights from previous " << *previous << std::endl;
452464 }
453465
454- model.train (loader, 1000 );
466+ model.train (loader, total_epochs );
455467
456468 loader.kill ();
457469
0 commit comments