1+ #include < cstddef>
12#include < iostream>
23#include < memory>
34#include < string>
45#include < algorithm>
56#include < iterator>
67#include < future>
78#include < mutex>
9+ #include < string_view>
810#include < thread>
911#include < deque>
1012#include < random>
13+ #include < variant>
1114
1215#include " lib/nnue_training_data_formats.h"
1316#include " lib/nnue_training_data_stream.h"
@@ -55,6 +58,8 @@ static Square orient_flip(Color color, Square sq) {
5558}
5659
5760struct HalfKP {
61+ static constexpr std::string_view NAME = " HalfKP" ;
62+
5863 static constexpr int NUM_SQ = 64 ;
5964 static constexpr int NUM_PT = 10 ;
6065 static constexpr int NUM_PLANES = (NUM_SQ * NUM_PT + 1 );
@@ -92,6 +97,8 @@ struct HalfKP {
9297};
9398
9499struct HalfKPFactorized {
100+ static constexpr std::string_view NAME = " HalfKP^" ;
101+
95102 // Factorized features
96103 static constexpr int K_INPUTS = HalfKP::NUM_SQ;
97104 static constexpr int PIECE_INPUTS = HalfKP::NUM_SQ * HalfKP::NUM_PT;
@@ -137,6 +144,8 @@ struct HalfKPFactorized {
137144};
138145
139146struct HalfKA {
147+ static constexpr std::string_view NAME = " HalfKA" ;
148+
140149 static constexpr int NUM_SQ = 64 ;
141150 static constexpr int NUM_PT = 12 ;
142151 static constexpr int NUM_PLANES = (NUM_SQ * NUM_PT + 1 );
@@ -170,6 +179,8 @@ struct HalfKA {
170179};
171180
172181struct HalfKAFactorized {
182+ static constexpr std::string_view NAME = " HalfKA^" ;
183+
173184 // Factorized features
174185 static constexpr int PIECE_INPUTS = HalfKA::NUM_SQ * HalfKA::NUM_PT;
175186 static constexpr int INPUTS = HalfKA::INPUTS + PIECE_INPUTS;
@@ -199,6 +210,8 @@ struct HalfKAFactorized {
199210};
200211
201212struct HalfKAv2 {
213+ static constexpr std::string_view NAME = " HalfKAv2" ;
214+
202215 static constexpr int NUM_SQ = 64 ;
203216 static constexpr int NUM_PT = 11 ;
204217 static constexpr int NUM_PLANES = NUM_SQ * NUM_PT;
@@ -234,6 +247,8 @@ struct HalfKAv2 {
234247};
235248
236249struct HalfKAv2Factorized {
250+ static constexpr std::string_view NAME = " HalfKAv2^" ;
251+
237252 // Factorized features
238253 static constexpr int NUM_PT = 12 ;
239254 static constexpr int PIECE_INPUTS = HalfKAv2::NUM_SQ * NUM_PT;
@@ -274,6 +289,8 @@ static Square orient_flip_2(Color color, Square sq, Square ksq) {
274289}
275290
276291struct HalfKAv2_hm {
292+ static constexpr std::string_view NAME = " HalfKAv2_hm" ;
293+
277294 static constexpr int NUM_SQ = 64 ;
278295 static constexpr int NUM_PT = 11 ;
279296 static constexpr int NUM_PLANES = NUM_SQ * NUM_PT;
@@ -325,6 +342,8 @@ struct HalfKAv2_hm {
325342};
326343
327344struct HalfKAv2_hmFactorized {
345+ static constexpr std::string_view NAME = " HalfKAv2_hm^" ;
346+
328347 // Factorized features
329348 static constexpr int NUM_PT = 12 ;
330349 static constexpr int PIECE_INPUTS = HalfKAv2_hm::NUM_SQ * NUM_PT;
@@ -399,6 +418,8 @@ constexpr auto threatoffsets = []() {
399418}();
400419
401420struct Full_Threats {
421+ static constexpr std::string_view NAME = " Full_Threats" ;
422+
402423 static constexpr int SQUARE_NB = 64 ;
403424 static constexpr int PIECE_NB = 12 ;
404425 static constexpr int COLOR_NB = 2 ;
@@ -570,6 +591,8 @@ struct Full_Threats {
570591};
571592
572593struct Full_ThreatsFactorized {
594+ static constexpr std::string_view NAME = " Full_Threats^" ;
595+
573596 // Factorized features
574597 static constexpr int PIECE_INPUTS = 768 ;
575598 static constexpr int INPUTS = 79856 + 22528 + 768 ;
@@ -605,12 +628,50 @@ struct FeatureSet {
605628 static constexpr int INPUTS = T::INPUTS;
606629 static constexpr int MAX_ACTIVE_FEATURES = T::MAX_ACTIVE_FEATURES;
607630
631+ static constexpr std::string_view NAME = T::NAME;
632+
608633 static std::pair<int , int >
609634 fill_features_sparse (const TrainingDataEntry& e, int * features, float * values, Color color) {
610635 return T::fill_features_sparse (e, features, values, color);
611636 }
612637};
613638
639+
640+ template <typename ... Ts>
641+ auto find_feature (std::string_view name) {
642+ using Variant = std::variant<std::monostate, Ts...>;
643+ using Factory = Variant (*)();
644+
645+ struct Entry {
646+ std::string_view name;
647+ Factory make;
648+ };
649+
650+ static constexpr Entry factories[] = {{Ts::NAME, +[]() -> Variant { return Ts{}; }}...};
651+
652+ for (auto & f : factories)
653+ {
654+ if (name == f.name )
655+ return f.make ();
656+ }
657+
658+ return Variant{std::monostate{}};
659+ }
660+
661+ auto get_feature (std::string_view name) {
662+ return find_feature<HalfKP, //
663+ HalfKPFactorized, //
664+ HalfKA, //
665+ HalfKAFactorized, //
666+ HalfKAv2, //
667+ HalfKAv2Factorized, //
668+ HalfKAv2_hm, //
669+ HalfKAv2_hmFactorized, //
670+ Full_Threats, //
671+ Full_ThreatsFactorized //
672+ >(name);
673+ }
674+
614675struct SparseBatch {
615676 static constexpr bool IS_BATCH = true ;
616677
@@ -1185,49 +1246,21 @@ EXPORT SparseBatch* get_sparse_batch_from_fens(const char* feature_set_c,
11851246 e.result = results[i];
11861247 }
11871248
1188- std::string_view feature_set (feature_set_c);
1189- if (feature_set == " HalfKP" )
1190- {
1191- return new SparseBatch (FeatureSet<HalfKP>{}, entries);
1192- }
1193- else if (feature_set == " HalfKP^" )
1194- {
1195- return new SparseBatch (FeatureSet<HalfKPFactorized>{}, entries);
1196- }
1197- else if (feature_set == " HalfKA" )
1198- {
1199- return new SparseBatch (FeatureSet<HalfKA>{}, entries);
1200- }
1201- else if (feature_set == " HalfKA^" )
1202- {
1203- return new SparseBatch (FeatureSet<HalfKAFactorized>{}, entries);
1204- }
1205- else if (feature_set == " HalfKAv2" )
1206- {
1207- return new SparseBatch (FeatureSet<HalfKAv2>{}, entries);
1208- }
1209- else if (feature_set == " HalfKAv2^" )
1210- {
1211- return new SparseBatch (FeatureSet<HalfKAv2Factorized>{}, entries);
1212- }
1213- else if (feature_set == " HalfKAv2_hm" )
1214- {
1215- return new SparseBatch (FeatureSet<HalfKAv2_hm>{}, entries);
1216- }
1217- else if (feature_set == " HalfKAv2_hm^" )
1218- {
1219- return new SparseBatch (FeatureSet<HalfKAv2_hmFactorized>{}, entries);
1220- }
1221- else if (feature_set == " Full_Threats" )
1222- {
1223- return new SparseBatch (FeatureSet<Full_Threats>{}, entries);
1224- }
1225- else if (feature_set == " Full_Threats^" )
1226- {
1227- return new SparseBatch (FeatureSet<Full_ThreatsFactorized>{}, entries);
1228- }
1229- fprintf (stderr, " Unknown feature_set %s\n " , feature_set_c);
1230- return nullptr ;
1249+ auto feature_variant = get_feature (feature_set_c);
1250+
1251+ return std::visit (
1252+ [&](const auto fs) -> SparseBatch* {
1253+ using F = std::decay_t <decltype (fs)>;
1254+ if constexpr (std::is_same_v<F, std::monostate>)
1255+ {
1256+ return nullptr ;
1257+ }
1258+ else
1259+ {
1260+ return new SparseBatch (FeatureSet<decltype (fs)>{}, entries);
1261+ }
1262+ },
1263+ feature_variant);
12311264}
12321265
12331266// changing the signature needs matching changes in data_loader/_native.py
@@ -1256,59 +1289,22 @@ EXPORT Stream<SparseBatch>* CDECL create_sparse_batch_stream(const char*
12561289 auto skipPredicate = make_skip_predicate (config);
12571290 auto filenames_vec = std::vector<std::string>(filenames, filenames + num_files);
12581291
1259- std::string_view feature_set (feature_set_c);
1260- if (feature_set == " HalfKP" )
1261- {
1262- return new FeaturedBatchStream<FeatureSet<HalfKP>, SparseBatch>(
1263- concurrency, filenames_vec, batch_size, cyclic, skipPredicate);
1264- }
1265- else if (feature_set == " HalfKP^" )
1266- {
1267- return new FeaturedBatchStream<FeatureSet<HalfKPFactorized>, SparseBatch>(
1268- concurrency, filenames_vec, batch_size, cyclic, skipPredicate);
1269- }
1270- else if (feature_set == " HalfKA" )
1271- {
1272- return new FeaturedBatchStream<FeatureSet<HalfKA>, SparseBatch>(
1273- concurrency, filenames_vec, batch_size, cyclic, skipPredicate);
1274- }
1275- else if (feature_set == " HalfKA^" )
1276- {
1277- return new FeaturedBatchStream<FeatureSet<HalfKAFactorized>, SparseBatch>(
1278- concurrency, filenames_vec, batch_size, cyclic, skipPredicate);
1279- }
1280- else if (feature_set == " HalfKAv2" )
1281- {
1282- return new FeaturedBatchStream<FeatureSet<HalfKAv2>, SparseBatch>(
1283- concurrency, filenames_vec, batch_size, cyclic, skipPredicate);
1284- }
1285- else if (feature_set == " HalfKAv2^" )
1286- {
1287- return new FeaturedBatchStream<FeatureSet<HalfKAv2Factorized>, SparseBatch>(
1288- concurrency, filenames_vec, batch_size, cyclic, skipPredicate);
1289- }
1290- else if (feature_set == " HalfKAv2_hm" )
1291- {
1292- return new FeaturedBatchStream<FeatureSet<HalfKAv2_hm>, SparseBatch>(
1293- concurrency, filenames_vec, batch_size, cyclic, skipPredicate);
1294- }
1295- else if (feature_set == " HalfKAv2_hm^" )
1296- {
1297- return new FeaturedBatchStream<FeatureSet<HalfKAv2_hmFactorized>, SparseBatch>(
1298- concurrency, filenames_vec, batch_size, cyclic, skipPredicate);
1299- }
1300- else if (feature_set == " Full_Threats" )
1301- {
1302- return new FeaturedBatchStream<FeatureSet<Full_Threats>, SparseBatch>(
1303- concurrency, filenames_vec, batch_size, cyclic, skipPredicate);
1304- }
1305- else if (feature_set == " Full_Threats^" )
1306- {
1307- return new FeaturedBatchStream<FeatureSet<Full_ThreatsFactorized>, SparseBatch>(
1308- concurrency, filenames_vec, batch_size, cyclic, skipPredicate);
1309- }
1310- fprintf (stderr, " Unknown feature_set %s\n " , feature_set_c);
1311- return nullptr ;
1292+ auto feature_variant = get_feature (feature_set_c);
1293+
1294+ return std::visit (
1295+ [&](const auto fs) -> Stream<SparseBatch>* {
1296+ using F = std::decay_t <decltype (fs)>;
1297+ if constexpr (std::is_same_v<F, std::monostate>)
1298+ {
1299+ return nullptr ;
1300+ }
1301+ else
1302+ {
1303+ return new FeaturedBatchStream<FeatureSet<decltype (fs)>, SparseBatch>(
1304+ concurrency, filenames_vec, batch_size, cyclic, skipPredicate);
1305+ }
1306+ },
1307+ feature_variant);
13121308}
13131309
13141310EXPORT void CDECL destroy_sparse_batch_stream (Stream<SparseBatch>* stream) { delete stream; }
@@ -1415,4 +1411,4 @@ int main(int argc, char** argv) {
14151411 return 0 ;
14161412}
14171413
1418- #endif
1414+ #endif
0 commit comments