Skip to content

Commit 9b6fb16

Browse files
committed
Make validation optional
1 parent 2360a51 commit 9b6fb16

File tree

2 files changed

+37
-25
lines changed

2 files changed

+37
-25
lines changed

src/main.cu

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ int main(int argc, char* argv[]) {
1111
argparse::ArgumentParser program("Grapheus");
1212

1313
program.add_argument("data").required().help("Directory containing training files");
14-
program.add_argument("--val-data").required().help("Directory containing validation files");
14+
program.add_argument("--val-data").help("Directory containing validation files");
1515
program.add_argument("--output").required().help("Output directory for network files");
1616
program.add_argument("--resume").help("Weights file to resume from");
1717
program.add_argument("--epochs")
@@ -66,9 +66,6 @@ int main(int argc, char* argv[]) {
6666
// Fetch training dataset paths
6767
std::vector<std::string> train_files = dataset::fetch_dataset_paths(program.get("data"));
6868

69-
// Fetch validation dataset paths
70-
std::vector<std::string> val_files = dataset::fetch_dataset_paths(program.get("--val-data"));
71-
7269
// Print training dataset file list if files are found
7370
if (!train_files.empty()) {
7471
std::cout << "Training Dataset Files:" << std::endl;
@@ -84,6 +81,13 @@ int main(int argc, char* argv[]) {
8481
exit(0);
8582
}
8683

84+
// Fetch validation dataset paths
85+
std::vector<std::string> val_files;
86+
87+
if (program.present("--val-data")) {
88+
val_files = dataset::fetch_dataset_paths(program.get("--val-data"));
89+
}
90+
8791
// Print validation dataset file list if files are found
8892
if (!val_files.empty()) {
8993
std::cout << "Validation Dataset Files:" << std::endl;
@@ -93,9 +97,6 @@ int main(int argc, char* argv[]) {
9397
std::cout << "Total validation files: " << val_files.size() << std::endl;
9498
std::cout << "Total validation positions: " << dataset::count_total_positions(val_files)
9599
<< std::endl;
96-
} else {
97-
std::cout << "No validation files found in " << program.get("--val-data") << std::endl;
98-
exit(0);
99100
}
100101

101102
const int total_epochs = program.get<int>("--epochs");
@@ -121,10 +122,13 @@ int main(int argc, char* argv[]) {
121122
using BatchLoader = dataset::BatchLoader<chess::Position>;
122123

123124
BatchLoader train_loader {train_files, batch_size};
124-
BatchLoader val_loader {val_files, batch_size};
125-
126125
train_loader.start();
127-
val_loader.start();
126+
127+
std::optional<BatchLoader> val_loader;
128+
if (val_files.size() > 0) {
129+
val_loader.emplace(val_files, batch_size);
130+
val_loader->start();
131+
}
128132

129133
model::BerserkModel model {static_cast<size_t>(ft_size), lambda, static_cast<size_t>(save_rate)};
130134
model.set_loss(MPE {2.5, true});
@@ -145,7 +149,7 @@ int main(int argc, char* argv[]) {
145149
model.train(train_loader, val_loader, total_epochs, epoch_size);
146150

147151
train_loader.kill();
148-
val_loader.kill();
152+
val_loader->kill();
149153

150154
close();
151155
return 0;

src/models/chessmodel.h

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
#include "../operations/operations.h"
1212

1313
#include <algorithm>
14+
#include <optional>
15+
1416
namespace model {
1517

1618
using namespace nn;
@@ -45,14 +47,14 @@ struct ChessModel : Model {
4547
* of epochs.
4648
*
4749
* @param train_loader The batch loader for training data.
48-
* @param val_loader The batch loader for validation data.
50+
* @param val_loader The batch loader for validation data (optional).
4951
* @param epochs Number of training epochs (default: 1500).
5052
* @param epoch_size Number of samples per epoch; the batch count is epoch_size / batch_size (default: 1e8).
5153
*/
52-
void train(BatchLoader& train_loader,
53-
BatchLoader& val_loader,
54-
int epochs = 1500,
55-
int epoch_size = 1e8) {
54+
void train(BatchLoader& train_loader,
55+
std::optional<BatchLoader>& val_loader,
56+
int epochs = 1500,
57+
int epoch_size = 1e8) {
5658

5759
this->compile(train_loader.batch_size);
5860

@@ -90,19 +92,25 @@ struct ChessModel : Model {
9092
}
9193
}
9294

93-
// Validation phase
94-
for (int b = 1; b <= epoch_size / val_loader.batch_size; b++) {
95-
auto* ds = val_loader.next();
96-
setup_inputs_and_outputs(ds);
95+
int val_epoch_size = epoch_size / 10;
9796

98-
float val_batch_loss = loss();
99-
total_val_loss += val_batch_loss;
97+
// Validation phase (if validation loader is provided)
98+
if (val_loader.has_value()) {
99+
for (int b = 1; b <= val_epoch_size / val_loader->batch_size; b++) {
100+
auto* ds = val_loader->next();
101+
setup_inputs_and_outputs(ds);
102+
103+
float val_batch_loss = loss();
104+
total_val_loss += val_batch_loss;
105+
}
100106
}
101107

102-
float epoch_loss = total_epoch_loss / (epoch_size / train_loader.batch_size);
103-
float val_loss = total_val_loss / (epoch_size / val_loader.batch_size);
108+
float epoch_loss = total_epoch_loss / (val_epoch_size / train_loader.batch_size);
109+
float val_loss = (val_loader.has_value())
110+
? total_val_loss / (val_epoch_size / val_loader->batch_size)
111+
: 0;
104112

105-
printf(", val_loss = [%1.8f]", val_loss);
113+
printf(", val_loss = [%1.8f] ", val_loss);
106114
next_epoch(epoch_loss, val_loss);
107115
std::cout << std::endl;
108116
}

0 commit comments

Comments
 (0)