Commit edb1f08

More CLI args
1 parent f545f05 commit edb1f08

2 files changed (+20, -6 lines)

src/main.cu

Lines changed: 18 additions & 6 deletions
@@ -225,7 +225,7 @@ struct BerserkModel : ChessModel {
     const size_t n_l2 = 32;
     const size_t n_out = 1;

-    BerserkModel(size_t n_ft, float lambda)
+    BerserkModel(size_t n_ft, float lambda, size_t save_rate)
         : ChessModel(lambda) {

         in1 = add<SparseInput>(n_features, 32);

@@ -259,9 +259,10 @@ struct BerserkModel : ChessModel {
                                1e-8,
                                5 * 16384));

+        set_save_frequency(save_rate);
         add_quantization(Quantizer {
             "quant",
-            50,
+            save_rate,
             QuantizerEntry<int16_t>(&ft->weights.values, quant_one, true),
             QuantizerEntry<int16_t>(&ft->bias.values, quant_one),
             QuantizerEntry<int8_t>(&l1->weights.values, quant_two),

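The constructor now routes save_rate into both the model's save frequency and the Quantizer's output interval, which was previously hard-coded to 50. A minimal sketch of the setter called above, assuming it simply stores the value in the m_save_frequency member that Model::next_epoch() checks in src/nn/model/model.h below; the exact declaration is a guess, not shown in this commit:

    // Assumed shape only: the commit shows the call site and the member,
    // but not the setter itself.
    void set_save_frequency(size_t every_n_epochs) { m_save_frequency = every_n_epochs; }
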
@@ -368,13 +369,20 @@ int main(int argc, char* argv[]) {
     program.add_argument("data").required().help("Directory containing training files");
     program.add_argument("--output").required().help("Output directory for network files");
     program.add_argument("--resume").help("Weights file to resume from");
+    program.add_argument("--epochs")
+        .default_value<int>(1000)
+        .help("Total number of epochs to train for");
+    program.add_argument("--save-rate")
+        .default_value<size_t>(50)
+        .help("How frequently to save quantized networks + weights");
     program.add_argument("--ft-size")
         .default_value<size_t>(1024)
         .help("Number of neurons in the Feature Transformer");
     program.add_argument("--lambda")
         .default_value<float>(0.0)
         .help("Ratio of evaluation score to use while training");
-    program.add_argument("--lr").default_value<float>(1e-3).help("Initial learning rate");
+    program.add_argument("--lr").default_value<float>(1e-3).help(
+        "The starting learning rate for the optimizer");
     program.add_argument("--batch-size")
         .default_value<int>(16384)
         .help("Number of positions in a mini-batch during training");

@@ -418,14 +426,18 @@ int main(int argc, char* argv[]) {
     std::cout << "Loading a total of " << files.size() << " files with " << total_positions
               << " total position(s)" << std::endl;

+    const int total_epochs = program.get<int>("--epochs");
+    const size_t save_rate = program.get<size_t>("--save-rate");
     const size_t ft_size = program.get<size_t>("--ft-size");
     const float lambda = program.get<float>("--lambda");
     const float lr = program.get<float>("--lr");
     const int batch_size = program.get<int>("--batch-size");
     const int lr_drop_epoch = program.get<int>("--lr-drop-epoch");
     const float lr_drop_ratio = program.get<float>("--lr-drop-ratio");

-    std::cout << "FT Size: " << ft_size << "\n"
+    std::cout << "Epochs: " << total_epochs << "\n"
+              << "Save Rate: " << save_rate << "\n"
+              << "FT Size: " << ft_size << "\n"
               << "Lambda: " << lambda << "\n"
               << "LR: " << lr << "\n"
               << "Batch: " << batch_size << "\n"

@@ -435,7 +447,7 @@ int main(int argc, char* argv[]) {
     dataset::BatchLoader<chess::Position> loader {files, batch_size};
     loader.start();

-    BerserkModel model {ft_size, lambda};
+    BerserkModel model {ft_size, lambda, save_rate};
     model.set_loss(MPE {2.5, true});
     model.set_lr_schedule(StepDecayLRSchedule {lr, lr_drop_ratio, lr_drop_epoch});

@@ -451,7 +463,7 @@ int main(int argc, char* argv[]) {
         std::cout << "Loaded weights from previous " << *previous << std::endl;
     }

-    model.train(loader, 1000);
+    model.train(loader, total_epochs);

     loader.kill();

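Taken together, the main.cu changes make the training length and the checkpoint cadence runtime options instead of the previously hard-coded 1000 epochs and save-every-50. A hypothetical invocation exercising the two new flags, with a placeholder binary name and paths and every other option left at its default:

    ./trainer /data/berserk --output /nets/run1 --epochs 600 --save-rate 25
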

src/nn/model/model.h

Lines changed: 2 additions & 0 deletions
@@ -203,10 +203,12 @@ struct Model {
     void next_epoch(float epoch_loss, float validation_loss = 0.0) {
         // quantitize weights
         quantize();
+        quantize("latest.net");
         write_epoch_result(epoch_loss, validation_loss);
         // save weights
         if (m_epoch % m_save_frequency == 0)
             save_weights(this->m_path / "weights" / (std::to_string(m_epoch) + ".state"));
+        save_weights(this->m_path / "weights" / "latest.state");

         m_epoch++;
     }

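These two lines add a rolling checkpoint on top of the existing periodic one: numbered .state files are still written only when m_epoch is a multiple of m_save_frequency (now driven by --save-rate), while latest.net and latest.state are refreshed on every call, assuming quantize("latest.net") writes unconditionally when given an explicit name (that overload's body is not part of this commit). A minimal standalone sketch of the resulting cadence, using hypothetical values and printf stand-ins for the real quantize/save_weights calls:

    // Not the project's code: prints which artifacts the schedule above would
    // produce each epoch for --epochs 12 and --save-rate 5.
    #include <cstdio>

    int main() {
        const int total_epochs = 12; // stand-in for --epochs
        const int save_rate    = 5;  // stand-in for --save-rate
        for (int epoch = 1; epoch <= total_epochs; ++epoch) {
            std::printf("epoch %2d: refresh latest.net + latest.state", epoch);
            if (epoch % save_rate == 0)
                std::printf("  (also write %d.state)", epoch);
            std::printf("\n");
        }
        return 0;
    }
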
