Skip to content

Commit 0453f4e

Browse files
authored
Merge pull request #7 from rafid-dev/validation
2 parents 0706691 + 2360a51 commit 0453f4e

File tree

4 files changed

+126
-38
lines changed

4 files changed

+126
-38
lines changed

src/dataset/io.h

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include "dataset.h"
44

55
#include <filesystem>
6+
#include <fstream>
67
#include <iostream>
78

89
namespace dataset {
@@ -120,4 +121,43 @@ bool is_readable(const std::string& file) {
120121
return expected_size == size;
121122
}
122123

124+
/**
125+
* @brief Counts total positions across all datasets from provided file paths.
126+
*
127+
* Iterates through file paths, reads headers, and sums positions for total count.
128+
*
129+
* @param files Vector of dataset file paths.
130+
* @return Total positions across all datasets.
131+
*/
132+
uint64_t count_total_positions(const std::vector<std::string>& files) {
133+
uint64_t total_positions = 0;
134+
135+
// Iterate through each file path and read dataset headers to count positions
136+
for (const auto& path : files) {
137+
std::ifstream fin(path, std::ios::binary);
138+
DataSetHeader h {};
139+
fin.read(reinterpret_cast<char*>(&h), sizeof(DataSetHeader));
140+
total_positions += h.entry_count;
141+
}
142+
143+
return total_positions;
144+
}
145+
146+
/**
147+
* @brief Retrieves dataset file paths from the specified directory.
148+
*
149+
* Iterates through the directory and collects paths of dataset files.
150+
*
151+
* @param directory The directory containing dataset files.
152+
* @return Vector of dataset file paths.
153+
*/
154+
auto fetch_dataset_paths(const std::string& directory) {
155+
std::vector<std::string> files;
156+
for (const auto& entry : std::filesystem::directory_iterator(directory)) {
157+
const std::string path = entry.path().string();
158+
files.push_back(path);
159+
}
160+
return files;
161+
}
162+
123163
} // namespace dataset

src/main.cu

Lines changed: 48 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,22 @@
66

77
using namespace nn;
88
using namespace data;
9-
namespace fs = std::filesystem;
109

1110
int main(int argc, char* argv[]) {
1211
argparse::ArgumentParser program("Grapheus");
1312

1413
program.add_argument("data").required().help("Directory containing training files");
14+
program.add_argument("--val-data").required().help("Directory containing validation files");
1515
program.add_argument("--output").required().help("Output directory for network files");
1616
program.add_argument("--resume").help("Weights file to resume from");
1717
program.add_argument("--epochs")
1818
.default_value(1000)
1919
.help("Total number of epochs to train for")
2020
.scan<'i', int>();
21+
program.add_argument("--epoch-size")
22+
.default_value(100000000)
23+
.help("Total positions in each epoch")
24+
.scan<'i', int>();
2125
program.add_argument("--save-rate")
2226
.default_value(50)
2327
.help("How frequently to save quantized networks + weights")
@@ -59,28 +63,43 @@ int main(int argc, char* argv[]) {
5963

6064
init();
6165

62-
std::vector<std::string> files {};
63-
64-
for (const auto& entry : fs::directory_iterator(program.get("data"))) {
65-
const std::string path = entry.path().string();
66-
files.push_back(path);
66+
// Fetch training dataset paths
67+
std::vector<std::string> train_files = dataset::fetch_dataset_paths(program.get("data"));
68+
69+
// Fetch validation dataset paths
70+
std::vector<std::string> val_files = dataset::fetch_dataset_paths(program.get("--val-data"));
71+
72+
// Print training dataset file list if files are found
73+
if (!train_files.empty()) {
74+
std::cout << "Training Dataset Files:" << std::endl;
75+
for (const auto& file : train_files) {
76+
std::cout << file << std::endl;
77+
}
78+
std::cout << "Total training files: " << train_files.size() << std::endl;
79+
std::cout << "Total training positions: " << dataset::count_total_positions(train_files)
80+
<< std::endl
81+
<< std::endl;
82+
} else {
83+
std::cout << "No training files found in " << program.get("data") << std::endl << std::endl;
84+
exit(0);
6785
}
6886

69-
uint64_t total_positions = 0;
70-
for (const auto& file_path : files) {
71-
FILE* fin = fopen(file_path.c_str(), "rb");
72-
73-
dataset::DataSetHeader h {};
74-
fread(&h, sizeof(dataset::DataSetHeader), 1, fin);
75-
76-
total_positions += h.entry_count;
77-
fclose(fin);
87+
// Print validation dataset file list if files are found
88+
if (!val_files.empty()) {
89+
std::cout << "Validation Dataset Files:" << std::endl;
90+
for (const auto& file : val_files) {
91+
std::cout << file << std::endl;
92+
}
93+
std::cout << "Total validation files: " << val_files.size() << std::endl;
94+
std::cout << "Total validation positions: " << dataset::count_total_positions(val_files)
95+
<< std::endl;
96+
} else {
97+
std::cout << "No validation files found in " << program.get("--val-data") << std::endl;
98+
exit(0);
7899
}
79100

80-
std::cout << "Loading a total of " << files.size() << " files with " << total_positions
81-
<< " total position(s)" << std::endl;
82-
83101
const int total_epochs = program.get<int>("--epochs");
102+
const int epoch_size = program.get<int>("--epoch-size");
84103
const int save_rate = program.get<int>("--save-rate");
85104
const int ft_size = program.get<int>("--ft-size");
86105
const float lambda = program.get<float>("--lambda");
@@ -90,6 +109,7 @@ int main(int argc, char* argv[]) {
90109
const float lr_drop_ratio = program.get<float>("--lr-drop-ratio");
91110

92111
std::cout << "Epochs: " << total_epochs << "\n"
112+
<< "Epochs Size: " << epoch_size << "\n"
93113
<< "Save Rate: " << save_rate << "\n"
94114
<< "FT Size: " << ft_size << "\n"
95115
<< "Lambda: " << lambda << "\n"
@@ -98,8 +118,13 @@ int main(int argc, char* argv[]) {
98118
<< "LR Drop @ " << lr_drop_epoch << "\n"
99119
<< "LR Drop R " << lr_drop_ratio << std::endl;
100120

101-
dataset::BatchLoader<chess::Position> loader {files, batch_size};
102-
loader.start();
121+
using BatchLoader = dataset::BatchLoader<chess::Position>;
122+
123+
BatchLoader train_loader {train_files, batch_size};
124+
BatchLoader val_loader {val_files, batch_size};
125+
126+
train_loader.start();
127+
val_loader.start();
103128

104129
model::BerserkModel model {static_cast<size_t>(ft_size), lambda, static_cast<size_t>(save_rate)};
105130
model.set_loss(MPE {2.5, true});
@@ -117,9 +142,10 @@ int main(int argc, char* argv[]) {
117142
std::cout << "Loaded weights from previous " << *previous << std::endl;
118143
}
119144

120-
model.train(loader, total_epochs);
145+
model.train(train_loader, val_loader, total_epochs, epoch_size);
121146

122-
loader.kill();
147+
train_loader.kill();
148+
val_loader.kill();
123149

124150
close();
125151
return 0;

src/models/berserk.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ struct BerserkModel : ChessModel {
106106

107107
auto& target = m_loss->target;
108108

109-
#pragma omp parallel for schedule(static) num_threads(16)
109+
#pragma omp parallel for schedule(static) num_threads(6)
110110
for (int b = 0; b < positions->header.entry_count; b++) {
111111
chess::Position* pos = &positions->positions[b];
112112
// fill in the inputs and target values

src/models/chessmodel.h

Lines changed: 37 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -38,24 +38,35 @@ struct ChessModel : Model {
3838
*/
3939
virtual void setup_inputs_and_outputs(dataset::DataSet<chess::Position>* positions) = 0;
4040

41-
/// @brief Train the model with the given batch loader.
42-
/// @param loader Batch loader for dataset input.
43-
/// @param epochs Number of training epochs.
44-
/// @param epoch_size Size of each training epoch.
45-
void train(dataset::BatchLoader<chess::Position>& loader,
46-
int epochs = 1500,
47-
int epoch_size = 1e8) {
48-
this->compile(loader.batch_size);
41+
using BatchLoader = dataset::BatchLoader<chess::Position>;
42+
43+
/**
44+
* @brief Trains the model using the provided train and validation loaders for a specified number
45+
* of epochs.
46+
*
47+
* @param train_loader The batch loader for training data.
48+
* @param val_loader The batch loader for validation data.
49+
* @param epochs Number of training epochs (default: 1500).
50+
* @param epoch_size Number of batches per epoch (default: 1e8).
51+
*/
52+
void train(BatchLoader& train_loader,
53+
BatchLoader& val_loader,
54+
int epochs = 1500,
55+
int epoch_size = 1e8) {
56+
57+
this->compile(train_loader.batch_size);
4958

5059
Timer t {};
5160
for (int i = 1; i <= epochs; i++) {
5261
t.start();
5362

5463
uint64_t prev_print_tm = 0;
5564
float total_epoch_loss = 0;
65+
float total_val_loss = 0;
5666

57-
for (int b = 1; b <= epoch_size / loader.batch_size; b++) {
58-
auto* ds = loader.next();
67+
// Training phase
68+
for (int b = 1; b <= epoch_size / train_loader.batch_size; b++) {
69+
auto* ds = train_loader.next();
5970
setup_inputs_and_outputs(ds);
6071

6172
float batch_loss = batch();
@@ -64,7 +75,7 @@ struct ChessModel : Model {
6475

6576
t.stop();
6677
uint64_t elapsed = t.elapsed();
67-
if (elapsed - prev_print_tm > 1000 || b == epoch_size / loader.batch_size) {
78+
if (elapsed - prev_print_tm > 1000 || b == epoch_size / train_loader.batch_size) {
6879
prev_print_tm = elapsed;
6980

7081
printf("\rep = [%4d], epoch_loss = [%1.8f], batch = [%5d], batch_loss = [%1.8f], "
@@ -79,10 +90,21 @@ struct ChessModel : Model {
7990
}
8091
}
8192

82-
std::cout << std::endl;
93+
// Validation phase
94+
for (int b = 1; b <= epoch_size / val_loader.batch_size; b++) {
95+
auto* ds = val_loader.next();
96+
setup_inputs_and_outputs(ds);
97+
98+
float val_batch_loss = loss();
99+
total_val_loss += val_batch_loss;
100+
}
101+
102+
float epoch_loss = total_epoch_loss / (epoch_size / train_loader.batch_size);
103+
float val_loss = total_val_loss / (epoch_size / val_loader.batch_size);
83104

84-
float epoch_loss = total_epoch_loss / (epoch_size / loader.batch_size);
85-
next_epoch(epoch_loss, 0.0);
105+
printf(", val_loss = [%1.8f]", val_loss);
106+
next_epoch(epoch_loss, val_loss);
107+
std::cout << std::endl;
86108
}
87109
}
88110

@@ -236,4 +258,4 @@ struct ChessModel : Model {
236258
}
237259
}
238260
};
239-
}
261+
} // namespace model

0 commit comments

Comments
 (0)