Skip to content

Commit 230015f

Browse files
committed
Add validation size
1 parent 9b6fb16 commit 230015f

File tree

2 files changed

+19
-14
lines changed

2 files changed

+19
-14
lines changed

src/main.cu

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,11 @@ int main(int argc, char* argv[]) {
2222
.default_value(100000000)
2323
.help("Total positions in each epoch")
2424
.scan<'i', int>();
25+
26+
program.add_argument("--val--size")
27+
.default_value(10000000)
28+
.help("Total positions for each validation epoch")
29+
.scan<'i', int>();
2530
program.add_argument("--save-rate")
2631
.default_value(50)
2732
.help("How frequently to save quantized networks + weights")
@@ -99,15 +104,16 @@ int main(int argc, char* argv[]) {
99104
<< std::endl;
100105
}
101106

102-
const int total_epochs = program.get<int>("--epochs");
103-
const int epoch_size = program.get<int>("--epoch-size");
104-
const int save_rate = program.get<int>("--save-rate");
105-
const int ft_size = program.get<int>("--ft-size");
106-
const float lambda = program.get<float>("--lambda");
107-
const float lr = program.get<float>("--lr");
108-
const int batch_size = program.get<int>("--batch-size");
109-
const int lr_drop_epoch = program.get<int>("--lr-drop-epoch");
110-
const float lr_drop_ratio = program.get<float>("--lr-drop-ratio");
107+
const int total_epochs = program.get<int>("--epochs");
108+
const int epoch_size = program.get<int>("--epoch-size");
109+
const int val_epoch_size = program.get<int>("--val-size");
110+
const int save_rate = program.get<int>("--save-rate");
111+
const int ft_size = program.get<int>("--ft-size");
112+
const float lambda = program.get<float>("--lambda");
113+
const float lr = program.get<float>("--lr");
114+
const int batch_size = program.get<int>("--batch-size");
115+
const int lr_drop_epoch = program.get<int>("--lr-drop-epoch");
116+
const float lr_drop_ratio = program.get<float>("--lr-drop-ratio");
111117

112118
std::cout << "Epochs: " << total_epochs << "\n"
113119
<< "Epochs Size: " << epoch_size << "\n"
@@ -146,7 +152,7 @@ int main(int argc, char* argv[]) {
146152
std::cout << "Loaded weights from previous " << *previous << std::endl;
147153
}
148154

149-
model.train(train_loader, val_loader, total_epochs, epoch_size);
155+
model.train(train_loader, val_loader, total_epochs, epoch_size, val_epoch_size);
150156

151157
train_loader.kill();
152158
val_loader->kill();

src/models/chessmodel.h

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,9 @@ struct ChessModel : Model {
5353
*/
5454
void train(BatchLoader& train_loader,
5555
std::optional<BatchLoader>& val_loader,
56-
int epochs = 1500,
57-
int epoch_size = 1e8) {
56+
int epochs = 1500,
57+
int epoch_size = 1e8,
58+
int val_epoch_size = 1e7) {
5859

5960
this->compile(train_loader.batch_size);
6061

@@ -92,8 +93,6 @@ struct ChessModel : Model {
9293
}
9394
}
9495

95-
int val_epoch_size = epoch_size / 10;
96-
9796
// Validation phase (if validation loader is provided)
9897
if (val_loader.has_value()) {
9998
for (int b = 1; b <= val_epoch_size / val_loader->batch_size; b++) {

0 commit comments

Comments
 (0)