Skip to content

Commit d3717ba

Browse files
authored
Merge pull request #8 from rafid-dev/validation
Make validation optional
2 parents 0453f4e + 230015f commit d3717ba

File tree

2 files changed

+52
-35
lines changed

2 files changed

+52
-35
lines changed

src/main.cu

Lines changed: 31 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ int main(int argc, char* argv[]) {
1111
argparse::ArgumentParser program("Grapheus");
1212

1313
program.add_argument("data").required().help("Directory containing training files");
14-
program.add_argument("--val-data").required().help("Directory containing validation files");
14+
program.add_argument("--val-data").help("Directory containing validation files");
1515
program.add_argument("--output").required().help("Output directory for network files");
1616
program.add_argument("--resume").help("Weights file to resume from");
1717
program.add_argument("--epochs")
@@ -22,6 +22,11 @@ int main(int argc, char* argv[]) {
2222
.default_value(100000000)
2323
.help("Total positions in each epoch")
2424
.scan<'i', int>();
25+
26+
program.add_argument("--val--size")
27+
.default_value(10000000)
28+
.help("Total positions for each validation epoch")
29+
.scan<'i', int>();
2530
program.add_argument("--save-rate")
2631
.default_value(50)
2732
.help("How frequently to save quantized networks + weights")
@@ -66,9 +71,6 @@ int main(int argc, char* argv[]) {
6671
// Fetch training dataset paths
6772
std::vector<std::string> train_files = dataset::fetch_dataset_paths(program.get("data"));
6873

69-
// Fetch validation dataset paths
70-
std::vector<std::string> val_files = dataset::fetch_dataset_paths(program.get("--val-data"));
71-
7274
// Print training dataset file list if files are found
7375
if (!train_files.empty()) {
7476
std::cout << "Training Dataset Files:" << std::endl;
@@ -84,6 +86,13 @@ int main(int argc, char* argv[]) {
8486
exit(0);
8587
}
8688

89+
// Fetch validation dataset paths
90+
std::vector<std::string> val_files;
91+
92+
if (program.present("--val-data")) {
93+
val_files = dataset::fetch_dataset_paths(program.get("--val-data"));
94+
}
95+
8796
// Print validation dataset file list if files are found
8897
if (!val_files.empty()) {
8998
std::cout << "Validation Dataset Files:" << std::endl;
@@ -93,20 +102,18 @@ int main(int argc, char* argv[]) {
93102
std::cout << "Total validation files: " << val_files.size() << std::endl;
94103
std::cout << "Total validation positions: " << dataset::count_total_positions(val_files)
95104
<< std::endl;
96-
} else {
97-
std::cout << "No validation files found in " << program.get("--val-data") << std::endl;
98-
exit(0);
99105
}
100106

101-
const int total_epochs = program.get<int>("--epochs");
102-
const int epoch_size = program.get<int>("--epoch-size");
103-
const int save_rate = program.get<int>("--save-rate");
104-
const int ft_size = program.get<int>("--ft-size");
105-
const float lambda = program.get<float>("--lambda");
106-
const float lr = program.get<float>("--lr");
107-
const int batch_size = program.get<int>("--batch-size");
108-
const int lr_drop_epoch = program.get<int>("--lr-drop-epoch");
109-
const float lr_drop_ratio = program.get<float>("--lr-drop-ratio");
107+
const int total_epochs = program.get<int>("--epochs");
108+
const int epoch_size = program.get<int>("--epoch-size");
109+
const int val_epoch_size = program.get<int>("--val-size");
110+
const int save_rate = program.get<int>("--save-rate");
111+
const int ft_size = program.get<int>("--ft-size");
112+
const float lambda = program.get<float>("--lambda");
113+
const float lr = program.get<float>("--lr");
114+
const int batch_size = program.get<int>("--batch-size");
115+
const int lr_drop_epoch = program.get<int>("--lr-drop-epoch");
116+
const float lr_drop_ratio = program.get<float>("--lr-drop-ratio");
110117

111118
std::cout << "Epochs: " << total_epochs << "\n"
112119
<< "Epochs Size: " << epoch_size << "\n"
@@ -121,10 +128,13 @@ int main(int argc, char* argv[]) {
121128
using BatchLoader = dataset::BatchLoader<chess::Position>;
122129

123130
BatchLoader train_loader {train_files, batch_size};
124-
BatchLoader val_loader {val_files, batch_size};
125-
126131
train_loader.start();
127-
val_loader.start();
132+
133+
std::optional<BatchLoader> val_loader;
134+
if (val_files.size() > 0) {
135+
val_loader.emplace(val_files, batch_size);
136+
val_loader->start();
137+
}
128138

129139
model::BerserkModel model {static_cast<size_t>(ft_size), lambda, static_cast<size_t>(save_rate)};
130140
model.set_loss(MPE {2.5, true});
@@ -142,10 +152,10 @@ int main(int argc, char* argv[]) {
142152
std::cout << "Loaded weights from previous " << *previous << std::endl;
143153
}
144154

145-
model.train(train_loader, val_loader, total_epochs, epoch_size);
155+
model.train(train_loader, val_loader, total_epochs, epoch_size, val_epoch_size);
146156

147157
train_loader.kill();
148-
val_loader.kill();
158+
val_loader->kill();
149159

150160
close();
151161
return 0;

src/models/chessmodel.h

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
#include "../operations/operations.h"
1212

1313
#include <algorithm>
14+
#include <optional>
15+
1416
namespace model {
1517

1618
using namespace nn;
@@ -45,14 +47,15 @@ struct ChessModel : Model {
4547
* of epochs.
4648
*
4749
* @param train_loader The batch loader for training data.
48-
* @param val_loader The batch loader for validation data.
50+
* @param val_loader The batch loader for validation data (optional).
4951
* @param epochs Number of training epochs (default: 1500).
5052
* @param epoch_size Number of batches per epoch (default: 1e8).
5153
*/
52-
void train(BatchLoader& train_loader,
53-
BatchLoader& val_loader,
54-
int epochs = 1500,
55-
int epoch_size = 1e8) {
54+
void train(BatchLoader& train_loader,
55+
std::optional<BatchLoader>& val_loader,
56+
int epochs = 1500,
57+
int epoch_size = 1e8,
58+
int val_epoch_size = 1e7) {
5659

5760
this->compile(train_loader.batch_size);
5861

@@ -90,19 +93,23 @@ struct ChessModel : Model {
9093
}
9194
}
9295

93-
// Validation phase
94-
for (int b = 1; b <= epoch_size / val_loader.batch_size; b++) {
95-
auto* ds = val_loader.next();
96-
setup_inputs_and_outputs(ds);
96+
// Validation phase (if validation loader is provided)
97+
if (val_loader.has_value()) {
98+
for (int b = 1; b <= val_epoch_size / val_loader->batch_size; b++) {
99+
auto* ds = val_loader->next();
100+
setup_inputs_and_outputs(ds);
97101

98-
float val_batch_loss = loss();
99-
total_val_loss += val_batch_loss;
102+
float val_batch_loss = loss();
103+
total_val_loss += val_batch_loss;
104+
}
100105
}
101106

102-
float epoch_loss = total_epoch_loss / (epoch_size / train_loader.batch_size);
103-
float val_loss = total_val_loss / (epoch_size / val_loader.batch_size);
107+
float epoch_loss = total_epoch_loss / (val_epoch_size / train_loader.batch_size);
108+
float val_loss = (val_loader.has_value())
109+
? total_val_loss / (val_epoch_size / val_loader->batch_size)
110+
: 0;
104111

105-
printf(", val_loss = [%1.8f]", val_loss);
112+
printf(", val_loss = [%1.8f] ", val_loss);
106113
next_epoch(epoch_loss, val_loss);
107114
std::cout << std::endl;
108115
}

0 commit comments

Comments
 (0)