Skip to content

Commit a61de12

Browse files
committed
Fix OpenMP + remove validation data + print total fens
1 parent f276bfa commit a61de12

File tree

2 files changed

+19
-36
lines changed

2 files changed

+19
-36
lines changed

CMakeLists.txt

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -17,9 +17,10 @@ set_property(GLOBAL PROPERTY RULE_LAUNCH_LINK "${CMAKE_COMMAND} -E time")
1717

1818
set_target_properties(Grapheus PROPERTIES CUDA_SEPARABLE_COMPILATION ON)
1919

20-
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler -march=native")
20+
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler -march=native -fopenmp")
2121

2222
target_link_libraries(Grapheus ${CUDA_LIBRARIES})
2323
target_link_libraries(Grapheus ${CUDA_CUBLAS_LIBRARIES})
2424
target_link_libraries(Grapheus ${CUDA_cusparse_LIBRARY})
25-
target_link_libraries(Grapheus ${CMAKE_THREAD_LIBS_INIT})
25+
target_link_libraries(Grapheus ${CMAKE_THREAD_LIBS_INIT})
26+
target_link_libraries(Grapheus ${OpenMP_CXX_LIBRARIES})

src/main.cu

Lines changed: 16 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -25,10 +25,8 @@ struct ChessModel : nn::Model {
2525

2626
// train function
2727
void train(dataset::BatchLoader<chess::Position>& loader,
28-
dataset::BatchLoader<chess::Position>& validation_loader,
2928
int epochs = 1500,
30-
int epoch_size = 1e8,
31-
int validation_size = 1e7) {
29+
int epoch_size = 1e8) {
3230
this->compile(loader.batch_size);
3331

3432
Timer t {};
@@ -37,7 +35,6 @@ struct ChessModel : nn::Model {
3735

3836
uint64_t prev_print_tm = 0;
3937
float total_epoch_loss = 0;
40-
float total_validation_loss = 0;
4138

4239
for (int b = 1; b <= epoch_size / loader.batch_size; b++) {
4340
auto* ds = loader.next();
@@ -67,20 +64,7 @@ struct ChessModel : nn::Model {
6764
std::cout << std::endl;
6865

6966
float epoch_loss = total_epoch_loss / (epoch_size / loader.batch_size);
70-
71-
for (int b = 1; b <= validation_size / validation_loader.batch_size; b++) {
72-
auto* ds = validation_loader.next();
73-
setup_inputs_and_outputs(ds, 0.5);
74-
75-
total_validation_loss += loss();
76-
}
77-
78-
float validation_loss =
79-
total_validation_loss / (validation_size / validation_loader.batch_size);
80-
printf("ep = [%4d], valid_loss = [%1.8f]", i, validation_loss);
81-
std::cout << std::endl;
82-
83-
next_epoch(epoch_loss, validation_loss);
67+
next_epoch(epoch_loss, 0.0);
8468
}
8569
}
8670

@@ -391,35 +375,34 @@ int main(int argc, char* argv[]) {
391375
std::exit(1);
392376
}
393377

394-
395378
math::seed(0);
396379

397380
init();
398381

399382
std::vector<std::string> files {};
400-
std::vector<std::string> validation_files {};
401383

402384
for (const auto& entry : fs::directory_iterator(program.get("data"))) {
403385
const std::string path = entry.path().string();
404-
if (path.find("valid") != std::string::npos) {
405-
std::cout << "Specifying " << path << " as validation data!" << std::endl;
406-
validation_files.push_back(path);
407-
} else {
408-
std::cout << "Specifying " << path << " as training data!" << std::endl;
409-
files.push_back(path);
410-
}
386+
files.push_back(path);
387+
}
388+
389+
uint64_t total_positions = 0;
390+
for (const auto& file_path : files) {
391+
FILE* fin = fopen(file_path.c_str(), "rb");
392+
393+
dataset::DataSetHeader h {};
394+
fread(&h, sizeof(dataset::DataSetHeader), 1, fin);
395+
396+
total_positions += h.entry_count;
397+
fclose(fin);
411398
}
412399

413-
if (validation_files.empty())
414-
validation_files.push_back(files.at(0));
400+
std::cout << "Loading a total of " << files.size() << " files with " << total_positions << " total position(s)" << std::endl;
415401

416402
const int batch_size = 16384;
417403
dataset::BatchLoader<chess::Position> loader {files, batch_size};
418404
loader.start();
419405

420-
dataset::BatchLoader<chess::Position> validation_loader {validation_files, batch_size};
421-
validation_loader.start();
422-
423406
BerserkModel model {};
424407
model.set_loss(MPE {2.5, true});
425408
model.set_lr_schedule(StepDecayLRSchedule {1e-3, 1.0 / 40.0, 500});
@@ -433,10 +416,9 @@ int main(int argc, char* argv[]) {
433416
model.load_weights(*previous);
434417
}
435418

436-
model.train(loader, validation_loader, 1000);
419+
model.train(loader, 1000);
437420

438421
loader.kill();
439-
validation_loader.kill();
440422

441423
close();
442424
return 0;

0 commit comments

Comments
 (0)