Skip to content

Commit d8b7c01

Browse files
committed
Simplify feature creation in loader
1 parent 6b2e13a commit d8b7c01

File tree

1 file changed

+93
-97
lines changed

1 file changed

+93
-97
lines changed

training_data_loader.cpp

Lines changed: 93 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
1+
#include <cstddef>
12
#include <iostream>
23
#include <memory>
34
#include <string>
45
#include <algorithm>
56
#include <iterator>
67
#include <future>
78
#include <mutex>
9+
#include <string_view>
810
#include <thread>
911
#include <deque>
1012
#include <random>
13+
#include <variant>
1114

1215
#include "lib/nnue_training_data_formats.h"
1316
#include "lib/nnue_training_data_stream.h"
@@ -55,6 +58,8 @@ static Square orient_flip(Color color, Square sq) {
5558
}
5659

5760
struct HalfKP {
61+
static constexpr std::string_view NAME = "HalfKP";
62+
5863
static constexpr int NUM_SQ = 64;
5964
static constexpr int NUM_PT = 10;
6065
static constexpr int NUM_PLANES = (NUM_SQ * NUM_PT + 1);
@@ -92,6 +97,8 @@ struct HalfKP {
9297
};
9398

9499
struct HalfKPFactorized {
100+
static constexpr std::string_view NAME = "HalfKP^";
101+
95102
// Factorized features
96103
static constexpr int K_INPUTS = HalfKP::NUM_SQ;
97104
static constexpr int PIECE_INPUTS = HalfKP::NUM_SQ * HalfKP::NUM_PT;
@@ -137,6 +144,8 @@ struct HalfKPFactorized {
137144
};
138145

139146
struct HalfKA {
147+
static constexpr std::string_view NAME = "HalfKA";
148+
140149
static constexpr int NUM_SQ = 64;
141150
static constexpr int NUM_PT = 12;
142151
static constexpr int NUM_PLANES = (NUM_SQ * NUM_PT + 1);
@@ -170,6 +179,8 @@ struct HalfKA {
170179
};
171180

172181
struct HalfKAFactorized {
182+
static constexpr std::string_view NAME = "HalfKA^";
183+
173184
// Factorized features
174185
static constexpr int PIECE_INPUTS = HalfKA::NUM_SQ * HalfKA::NUM_PT;
175186
static constexpr int INPUTS = HalfKA::INPUTS + PIECE_INPUTS;
@@ -199,6 +210,8 @@ struct HalfKAFactorized {
199210
};
200211

201212
struct HalfKAv2 {
213+
static constexpr std::string_view NAME = "HalfKAv2";
214+
202215
static constexpr int NUM_SQ = 64;
203216
static constexpr int NUM_PT = 11;
204217
static constexpr int NUM_PLANES = NUM_SQ * NUM_PT;
@@ -234,6 +247,8 @@ struct HalfKAv2 {
234247
};
235248

236249
struct HalfKAv2Factorized {
250+
static constexpr std::string_view NAME = "HalfKAv2^";
251+
237252
// Factorized features
238253
static constexpr int NUM_PT = 12;
239254
static constexpr int PIECE_INPUTS = HalfKAv2::NUM_SQ * NUM_PT;
@@ -274,6 +289,8 @@ static Square orient_flip_2(Color color, Square sq, Square ksq) {
274289
}
275290

276291
struct HalfKAv2_hm {
292+
static constexpr std::string_view NAME = "HalfKAv2_hm";
293+
277294
static constexpr int NUM_SQ = 64;
278295
static constexpr int NUM_PT = 11;
279296
static constexpr int NUM_PLANES = NUM_SQ * NUM_PT;
@@ -325,6 +342,8 @@ struct HalfKAv2_hm {
325342
};
326343

327344
struct HalfKAv2_hmFactorized {
345+
static constexpr std::string_view NAME = "HalfKAv2_hm^";
346+
328347
// Factorized features
329348
static constexpr int NUM_PT = 12;
330349
static constexpr int PIECE_INPUTS = HalfKAv2_hm::NUM_SQ * NUM_PT;
@@ -399,6 +418,8 @@ constexpr auto threatoffsets = []() {
399418
}();
400419

401420
struct Full_Threats {
421+
static constexpr std::string_view NAME = "Full_Threats";
422+
402423
static constexpr int SQUARE_NB = 64;
403424
static constexpr int PIECE_NB = 12;
404425
static constexpr int COLOR_NB = 2;
@@ -570,6 +591,8 @@ struct Full_Threats {
570591
};
571592

572593
struct Full_ThreatsFactorized {
594+
static constexpr std::string_view NAME = "Full_Threats^";
595+
573596
// Factorized features
574597
static constexpr int PIECE_INPUTS = 768;
575598
static constexpr int INPUTS = 79856 + 22528 + 768;
@@ -605,12 +628,50 @@ struct FeatureSet {
605628
static constexpr int INPUTS = T::INPUTS;
606629
static constexpr int MAX_ACTIVE_FEATURES = T::MAX_ACTIVE_FEATURES;
607630

631+
static constexpr std::string_view NAME = T::NAME;
632+
608633
static std::pair<int, int>
609634
fill_features_sparse(const TrainingDataEntry& e, int* features, float* values, Color color) {
610635
return T::fill_features_sparse(e, features, values, color);
611636
}
612637
};
613638

639+
640+
template<typename... Ts>
641+
auto find_feature(std::string_view name) {
642+
using Variant = std::variant<std::monostate, Ts...>;
643+
using Factory = Variant (*)();
644+
645+
struct Entry {
646+
std::string_view name;
647+
Factory make;
648+
};
649+
650+
static constexpr Entry factories[] = {{Ts::NAME, +[]() -> Variant { return Ts{}; }}...};
651+
652+
for (auto& f : factories)
653+
{
654+
if (name == f.name)
655+
return f.make();
656+
}
657+
658+
return Variant{std::monostate{}};
659+
}
660+
661+
auto get_feature(std::string_view name) {
662+
return find_feature<HalfKP, //
663+
HalfKPFactorized, //
664+
HalfKA, //
665+
HalfKAFactorized, //
666+
HalfKAv2, //
667+
HalfKAv2Factorized, //
668+
HalfKAv2_hm, //
669+
HalfKAv2_hmFactorized, //
670+
Full_Threats, //
671+
Full_ThreatsFactorized //
672+
>(name);
673+
}
674+
614675
struct SparseBatch {
615676
static constexpr bool IS_BATCH = true;
616677

@@ -1185,49 +1246,21 @@ EXPORT SparseBatch* get_sparse_batch_from_fens(const char* feature_set_c,
11851246
e.result = results[i];
11861247
}
11871248

1188-
std::string_view feature_set(feature_set_c);
1189-
if (feature_set == "HalfKP")
1190-
{
1191-
return new SparseBatch(FeatureSet<HalfKP>{}, entries);
1192-
}
1193-
else if (feature_set == "HalfKP^")
1194-
{
1195-
return new SparseBatch(FeatureSet<HalfKPFactorized>{}, entries);
1196-
}
1197-
else if (feature_set == "HalfKA")
1198-
{
1199-
return new SparseBatch(FeatureSet<HalfKA>{}, entries);
1200-
}
1201-
else if (feature_set == "HalfKA^")
1202-
{
1203-
return new SparseBatch(FeatureSet<HalfKAFactorized>{}, entries);
1204-
}
1205-
else if (feature_set == "HalfKAv2")
1206-
{
1207-
return new SparseBatch(FeatureSet<HalfKAv2>{}, entries);
1208-
}
1209-
else if (feature_set == "HalfKAv2^")
1210-
{
1211-
return new SparseBatch(FeatureSet<HalfKAv2Factorized>{}, entries);
1212-
}
1213-
else if (feature_set == "HalfKAv2_hm")
1214-
{
1215-
return new SparseBatch(FeatureSet<HalfKAv2_hm>{}, entries);
1216-
}
1217-
else if (feature_set == "HalfKAv2_hm^")
1218-
{
1219-
return new SparseBatch(FeatureSet<HalfKAv2_hmFactorized>{}, entries);
1220-
}
1221-
else if (feature_set == "Full_Threats")
1222-
{
1223-
return new SparseBatch(FeatureSet<Full_Threats>{}, entries);
1224-
}
1225-
else if (feature_set == "Full_Threats^")
1226-
{
1227-
return new SparseBatch(FeatureSet<Full_ThreatsFactorized>{}, entries);
1228-
}
1229-
fprintf(stderr, "Unknown feature_set %s\n", feature_set_c);
1230-
return nullptr;
1249+
auto feature_variant = get_feature(feature_set_c);
1250+
1251+
return std::visit(
1252+
[&](const auto fs) -> SparseBatch* {
1253+
using F = std::decay_t<decltype(fs)>;
1254+
if constexpr (std::is_same_v<F, std::monostate>)
1255+
{
1256+
return nullptr;
1257+
}
1258+
else
1259+
{
1260+
return new SparseBatch(FeatureSet<decltype(fs)>{}, entries);
1261+
}
1262+
},
1263+
feature_variant);
12311264
}
12321265

12331266
// changing the signature needs matching changes in data_loader/_native.py
@@ -1256,59 +1289,22 @@ EXPORT Stream<SparseBatch>* CDECL create_sparse_batch_stream(const char*
12561289
auto skipPredicate = make_skip_predicate(config);
12571290
auto filenames_vec = std::vector<std::string>(filenames, filenames + num_files);
12581291

1259-
std::string_view feature_set(feature_set_c);
1260-
if (feature_set == "HalfKP")
1261-
{
1262-
return new FeaturedBatchStream<FeatureSet<HalfKP>, SparseBatch>(
1263-
concurrency, filenames_vec, batch_size, cyclic, skipPredicate);
1264-
}
1265-
else if (feature_set == "HalfKP^")
1266-
{
1267-
return new FeaturedBatchStream<FeatureSet<HalfKPFactorized>, SparseBatch>(
1268-
concurrency, filenames_vec, batch_size, cyclic, skipPredicate);
1269-
}
1270-
else if (feature_set == "HalfKA")
1271-
{
1272-
return new FeaturedBatchStream<FeatureSet<HalfKA>, SparseBatch>(
1273-
concurrency, filenames_vec, batch_size, cyclic, skipPredicate);
1274-
}
1275-
else if (feature_set == "HalfKA^")
1276-
{
1277-
return new FeaturedBatchStream<FeatureSet<HalfKAFactorized>, SparseBatch>(
1278-
concurrency, filenames_vec, batch_size, cyclic, skipPredicate);
1279-
}
1280-
else if (feature_set == "HalfKAv2")
1281-
{
1282-
return new FeaturedBatchStream<FeatureSet<HalfKAv2>, SparseBatch>(
1283-
concurrency, filenames_vec, batch_size, cyclic, skipPredicate);
1284-
}
1285-
else if (feature_set == "HalfKAv2^")
1286-
{
1287-
return new FeaturedBatchStream<FeatureSet<HalfKAv2Factorized>, SparseBatch>(
1288-
concurrency, filenames_vec, batch_size, cyclic, skipPredicate);
1289-
}
1290-
else if (feature_set == "HalfKAv2_hm")
1291-
{
1292-
return new FeaturedBatchStream<FeatureSet<HalfKAv2_hm>, SparseBatch>(
1293-
concurrency, filenames_vec, batch_size, cyclic, skipPredicate);
1294-
}
1295-
else if (feature_set == "HalfKAv2_hm^")
1296-
{
1297-
return new FeaturedBatchStream<FeatureSet<HalfKAv2_hmFactorized>, SparseBatch>(
1298-
concurrency, filenames_vec, batch_size, cyclic, skipPredicate);
1299-
}
1300-
else if (feature_set == "Full_Threats")
1301-
{
1302-
return new FeaturedBatchStream<FeatureSet<Full_Threats>, SparseBatch>(
1303-
concurrency, filenames_vec, batch_size, cyclic, skipPredicate);
1304-
}
1305-
else if (feature_set == "Full_Threats^")
1306-
{
1307-
return new FeaturedBatchStream<FeatureSet<Full_ThreatsFactorized>, SparseBatch>(
1308-
concurrency, filenames_vec, batch_size, cyclic, skipPredicate);
1309-
}
1310-
fprintf(stderr, "Unknown feature_set %s\n", feature_set_c);
1311-
return nullptr;
1292+
auto feature_variant = get_feature(feature_set_c);
1293+
1294+
return std::visit(
1295+
[&](const auto fs) -> Stream<SparseBatch>* {
1296+
using F = std::decay_t<decltype(fs)>;
1297+
if constexpr (std::is_same_v<F, std::monostate>)
1298+
{
1299+
return nullptr;
1300+
}
1301+
else
1302+
{
1303+
return new FeaturedBatchStream<FeatureSet<decltype(fs)>, SparseBatch>(
1304+
concurrency, filenames_vec, batch_size, cyclic, skipPredicate);
1305+
}
1306+
},
1307+
feature_variant);
13121308
}
13131309

13141310
EXPORT void CDECL destroy_sparse_batch_stream(Stream<SparseBatch>* stream) { delete stream; }
@@ -1415,4 +1411,4 @@ int main(int argc, char** argv) {
14151411
return 0;
14161412
}
14171413

1418-
#endif
1414+
#endif

0 commit comments

Comments
 (0)