Skip to content

Commit 2dfb805

Browse files
committed
More CLI args
1 parent edb1f08 commit 2dfb805

File tree

1 file changed

+34
-25
lines changed

1 file changed

+34
-25
lines changed

src/main.cu

Lines changed: 34 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -370,28 +370,37 @@ int main(int argc, char* argv[]) {
370370
program.add_argument("--output").required().help("Output directory for network files");
371371
program.add_argument("--resume").help("Weights file to resume from");
372372
program.add_argument("--epochs")
373-
.default_value<int>(1000)
374-
.help("Total number of epochs to train for");
373+
.default_value(1000)
374+
.help("Total number of epochs to train for")
375+
.scan<'i', int>();
375376
program.add_argument("--save-rate")
376-
.default_value<size_t>(50)
377-
.help("How frequently to save quantized networks + weights");
377+
.default_value(50)
378+
.help("How frequently to save quantized networks + weights")
379+
.scan<'i', int>();
378380
program.add_argument("--ft-size")
379-
.default_value<size_t>(1024)
380-
.help("Number of neurons in the Feature Transformer");
381+
.default_value(1024)
382+
.help("Number of neurons in the Feature Transformer")
383+
.scan<'i', int>();
381384
program.add_argument("--lambda")
382-
.default_value<float>(0.0)
383-
.help("Ratio of evaluation score to use while training");
384-
program.add_argument("--lr").default_value<float>(1e-3).help(
385-
"The starting learning rate for the optimizer");
385+
.default_value(0.0f)
386+
.help("Ratio of evaluation score to use while training")
387+
.scan<'f', float>();
388+
program.add_argument("--lr")
389+
.default_value(0.001f)
390+
.help("The starting learning rate for the optimizer")
391+
.scan<'f', float>();
386392
program.add_argument("--batch-size")
387-
.default_value<int>(16384)
388-
.help("Number of positions in a mini-batch during training");
393+
.default_value(16384)
394+
.help("Number of positions in a mini-batch during training")
395+
.scan<'i', int>();
389396
program.add_argument("--lr-drop-epoch")
390-
.default_value<int>(500)
391-
.help("Epoch to execute an LR drop at");
397+
.default_value(500)
398+
.help("Epoch to execute an LR drop at")
399+
.scan<'i', int>();
392400
program.add_argument("--lr-drop-ratio")
393-
.default_value<float>(1.0 / 40.0)
394-
.help("How much to scale down LR when dropping");
401+
.default_value(0.025f)
402+
.help("How much to scale down LR when dropping")
403+
.scan<'f', float>();
395404

396405
try {
397406
program.parse_args(argc, argv);
@@ -426,14 +435,14 @@ int main(int argc, char* argv[]) {
426435
std::cout << "Loading a total of " << files.size() << " files with " << total_positions
427436
<< " total position(s)" << std::endl;
428437

429-
const int total_epochs = program.get<int>("--epochs");
430-
const size_t save_rate = program.get<size_t>("--save-rate");
431-
const size_t ft_size = program.get<size_t>("--ft-size");
432-
const float lambda = program.get<float>("--lambda");
433-
const float lr = program.get<float>("--lr");
434-
const int batch_size = program.get<int>("--batch-size");
435-
const int lr_drop_epoch = program.get<int>("--lr-drop-epoch");
436-
const float lr_drop_ratio = program.get<float>("--lr-drop-ratio");
438+
const int total_epochs = program.get<int>("--epochs");
439+
const int save_rate = program.get<int>("--save-rate");
440+
const int ft_size = program.get<int>("--ft-size");
441+
const float lambda = program.get<float>("--lambda");
442+
const float lr = program.get<float>("--lr");
443+
const int batch_size = program.get<int>("--batch-size");
444+
const int lr_drop_epoch = program.get<int>("--lr-drop-epoch");
445+
const float lr_drop_ratio = program.get<float>("--lr-drop-ratio");
437446

438447
std::cout << "Epochs: " << total_epochs << "\n"
439448
<< "Save Rate: " << save_rate << "\n"
@@ -447,7 +456,7 @@ int main(int argc, char* argv[]) {
447456
dataset::BatchLoader<chess::Position> loader {files, batch_size};
448457
loader.start();
449458

450-
BerserkModel model {ft_size, lambda, save_rate};
459+
BerserkModel model {static_cast<size_t>(ft_size), lambda, static_cast<size_t>(save_rate)};
451460
model.set_loss(MPE {2.5, true});
452461
model.set_lr_schedule(StepDecayLRSchedule {lr, lr_drop_ratio, lr_drop_epoch});
453462

0 commit comments

Comments
 (0)